开始之前
参考项目:赛博魔杖_STM32卷积神经网络 - 立创开源硬件平台
立创开源链接:MagicWand-基于魔杖的智能家具控制 - 立创开源硬件平台
Github:moneypiaorui/MagicWand: 基于魔杖控制的智能家居系统
使用手册
硬件的组装和固件烧录见硬件手册,使用手册从固件烧录完成开始
配置WIFI
第一次使用时魔杖无法连接WIFI,会开启自身热点,名称为ESP32_Hotspot
用电脑连接该热点后,按住win+R,输入cmd打开命令行;然后输入ipconfig,下拉到无线局域网适配器一栏,复制默认网关这项的ip地址
在浏览器输入该ip地址,打开wifi配置界面,输入你的wifi名称和密码,点击保存,信息就会保存到esp32的存储中,下次启动的时候自动连接该WIFI
重启esp32,进入路由器的管理界面,检查esp32是否连接到wifi,连接成功会出现以下信息;新的ip地址可以记住,在后续的mqtt配置和开发中会使用到
Homeassistant和MQTT安装
Homeassistant(以下简称HA)是一个非常好用的智能家居平台,可以将米家和Homekit集成其中一起控制,自身还能拓展很多功能,DIY自由度非常高
MQTT是一个用于物联网设备的通信协议,本项目中魔杖将采用MQTT协议将动作指令发送给HA,然后由HA来实现具体的家居控制
Linux系统推荐使用宝塔面板进行安装,在宝塔的docker商店(需要9.2及以上才有docker)中搜索Homeassistant和emqx一键安装,省心省力
Windows系统安装
安装docker desktop
详见【从零开始】Docker Desktop:听说你小子要玩我-阿里云开发者社区
安装后打开命令行输入docker --version
,能够运行代表安装成功,且docker desktop呈现如下界面,重点是左下角是绿色的running状态
docker安装HA
先在命令行输入docker pull homeassistant/home-assistant:latest
拉取最新HA镜像
然后输入docker run -d --restart always --name homeassistant -v /data/homeassistant/config:/config -e TZ=Asia/Shanghai -p 8123:8123 homeassistant/home-assistant:latest
启动HA,端口为8123
HA的配置文件在docker->file栏下的config文件夹中
docker安装MQTT
使用emqx免费版,安装命令见官网下载 EMQX 开源版,结束状态如下
MQTT相关配置
esp32配置MQTT
在配置WIFI的地方我们能够获取到esp32的ip地址,同样粘贴到浏览器中,填写MQTT的相关设置然后保存
服务器填写对应电脑的ip地址,同样可以通过路由器控制面板查看/通过ipconfig查看;端口默认1883;用户名密码默认EMQX不验证,随便写。
EMQX配置(可选)
用来开启EMQX的账密验证服务;开启后HA配置MQTT集成,ESP32设置MQTT服务器都要配置正确的账号密码;一般不用开启,直接跳过这一部分就行
宝塔安装完成后使用localhost:18083(如果是在本机装的话)进入EMQX控制面板,账号密码默认是admin和public
点击左侧的安全->客户端授权->创建,然后一路下一步直到创建完成
然后点击 用户管理->添加 来配置用户
HA相关配置
宝塔安装完成后使用localhost:8123(如果是在本机装的话)进入HA页面,后续具体配置流程见低成本玩转智能家庭(一)Home Assistant搭建和配置 - 知乎
进入HA的dashboard后参考下面的文章配置MQTT集成(注意用户名and密码是可选的,默认EMQX没有开启账密验证,用户名and密码可以随便填)
homeassistant配置MQTT集成以及传感器实体(STM32连接进入homeassistant)_mqtt实体-CSDN博客
HA自动化
启动esp32,如果wifi和mqtt都配置完成,进入HA的MQTT集成界面(设置->设备与服务->集成);点击MQTT集成下的设备,能够看到Magic Wind的设备,点进去
然后点击MQTT INFO,能够看到ESP32通过MQTT发现自动配置的所有trigger(动作类型)
点击自动化右边的+号,点击“使用设备作为触发条件”,然后根据HA自动化的相关规则配置
魔杖使用
到这里所有前置环境的配置就结束了,下面是魔杖的具体使用方法
拨动开关打开魔杖,等到红灯完全熄灭不再闪烁或者长亮,表示初始化完成,可以使用。
按下按钮开始动作录制,红灯亮起,松开按钮结束动作录制,红灯熄灭,然后ESP32会判断动作并通知HA。
默认内置的12种动作,使用的时候注意在魔杖停止移动的时候 开始/结束,移动速度不宜太快,幅度不宜太小
训练动作
数据采集
Esp32开启websocket服务在8080端口,用esp32内置的webui(/dashboard.html)进行数据采集,通过ws连接esp32。当esp32的电容触摸引脚(或者按钮)被按下时,设置isRecording=true,ws发送“start”信息,然后开始用ws流式传输传感器数据,采样频率设置为100HZ。松开电容引脚(或者按钮),isRecording=false,并发送“end”告知一次数据采集结束。
网页会记录下每次采集的数据,并提供检查、选择、删除的功能,确认无误后导出为csv,作为一个动作的训练数据。采集新的动作数据时F5刷新网页,然后采集。
将每种动作的所有数据录制在同一个csv中,命名为<动作名.csv>,训练时根据.csv前面的文件名标识动作类名
模型训练
所有动作的csv在PC使用python进行预处理,训练测试集划分,模型训练保存
将所有动作的csv复制到train/data文件夹下,然后在train文件夹下右键打开终端,配置python环境后(推荐使用anoconda),输入python ANN.py运行神经网络的模型训练。
训练中自行根据正确率和loss下降趋势选择合适的epoch数量(或者就用默认的50epochs)
训练结束后会将one_hot编码和torch.nn的权重保存成json格式方便ESP32使用
模型导出
- 将train/model文件夹下的one_hot_encoder.json和model_weights.json复制到main/data文件夹下,使用LittleFS插件烧录到ESP32的flash中(需配置esp32c2开发板)
- 或者将train/model文件夹下的one_hot_encoder.json和model_weights.json通过配置网页上传到flash中
LittleFS安装见Arduino IDE 2:安装 ESP32 LittleFS 上传器 |随机书教程
硬件手册
PCB
电路板设计见立创开源平台
PCB主控采用ESP32C2系列ESP8684-WROOM-01C模组,本来打算用C3系列的8685-WROOM-01,但是市面上没有存货,而且考虑到PCB设计只有20cm的宽度,只好采用C2的模组了
但是C3系列的8685-WROOM-01在该电路上可以无缝替换,方便后面的编译环节
3D外壳
采用别人开源项目的外壳赛博魔杖_STM32卷积神经网络 - 立创开源硬件平台
固件烧录
代码采用Arduino开发,记得设置文件路径以保证相关库能正常使用
固件上传采用两种方式:
- (不建议使用,详细方法见开发手册->开发环境)使用Arduino编译并上传:此方法较为复杂,因为Arduino原生没有提供ESP32C2的支持,详细解决方法见Arduino中不支持ESP32C2 – SZU_TIC,如果出现任何问题请通过博客主页联系我
- 如果使用C3系列的8685-WROOM-01替换的话则可以正常编译,而且不需要修改代码
- 此方式还需自行烧录flash文件系统Arduino IDE 2:安装 ESP32 LittleFS 上传器 |随机书教程
- 使用官方烧录工具工具|乐鑫科技,将github仓库中/main/build/esp32.esp32.esp32c2下编译好的固件main.ino.merged.bin烧录到芯片的0x0中,然后将FS.bin烧录到0x290000中,具体操作见下图
- 或者也可以直接将/main/build/下合并好的target.bin烧录到0x0的位置,但Github仓库中的target.bin不保证最新
两种方法烧录固件时需要同时按下EN和BOOT按钮,然后先松开EN再松开BOOT进入下载模式
固件更新
开启了OTA更新功能,在webUI中选择bin文件上传即可进行更新
点击项目->导出已编译二进制文件,将相关bin文件导出到build文件夹下
能够更新的bin文件分别是main.ino.bin(主程序)和LittleFS编译出的FS.bin文件(文件系统),其他例如bootloader.bin之类的不要上传(FS.bin暂时无法OTA,不知道为什么QAQ)
开发手册
运行流程
- setup
- 读取WIFI和MQTT设置
- 连接WIFI,启动http服务和websocket(以下简称ws)服务
- 设置http路由,包括/WIFIconfig(设置WIFI),/MQTTconfig(设置MQTT),/restart(重启ESP32)
- 初始化MPU6050
- 设置MQTT
- LOOP
- 按下按钮时
- 通过MPU6050采集6轴动作,存储为SensorData结构体,push_back到actionRecord中,采样频率为100hz;
- 将六轴数据广播到所有的ws连接上
- 松开按钮时
- 将所有SensorData通过线性插值转成20个frame的数据集actionInput
- 将actionInput拉成120维的输入数据,通过手搓的前向传播和relu激活函数实现神经网络,得到12维(动作数)的output
- 取output中最大值的索引,通过actions[index]获得预测类别
- 将预测类别通过MQTT返回给HA
- 按下按钮时
代码介绍
ESP32
websocket
主要用于向webUI传输陀螺仪数据
web服务器
设置以下路由
- server.on("/WIFIconfig", handleWIFI);//处理WIFI表单上传
- server.on("/MQTTconfig", handleMQTT);//处理MQTT配置表单上传
- server.on("/restart", handleRestart);//处理重启请求
- server.on("/upload", HTTP_POST, handleFileUpload, handleUploadForm); // 处理文件上传请求
- server.on("/update", HTTP_POST, handleFirmwareUpload, handleUpdateBin); // 处理 OTA 请求
MQTT
自动配置为HA设备的触发器(trigger),具体规则见官方文档MQTT Discovery- Home Assistant
void publishDeviceConfiguration() {
// 通过MQTT发现,自动配置触发器(trigger)
for (String action : actions) {
Serial.print(action);
Serial.print(",");
StaticJsonDocument<512> doc; // 在每次循环中创建新的 JSON 对象
doc["automation_type"] = "trigger";
doc["type"] = "action";
doc["subtype"] = action; // 设置 subtype
doc["payload"] = action; // 设置 payload
doc["topic"] = controlTopic;
JsonObject device = doc.createNestedObject("device");
device["identifiers"][0] = "magic_wind"; // 使用小写和下划线
device["name"] = "Magic Wind";
// 将 JSON 转换为字符串
char jsonBuffer[512];
serializeJson(doc, jsonBuffer);
// 创建唯一的主题或 object_id
String uniqueId = "homeassistant/device_automation/" + String("magic_wind/") + action + "/config";
mqttClient.publish(uniqueId.c_str(), jsonBuffer); // 使用保留消息
}
Serial.println("\nDevice configuration published");
}
前向传播
在esp32上编写relu,softmax,全连接层来手动实现实现ANN的前向传播
// 定义 ReLU 激活函数
float relu(float x) {
return (x > 0) ? x : 0;
}
// Softmax 函数
void softmax(float* output, int length) {
float sum = 0.0;
// 计算所有输出值的指数和
for (int i = 0; i < length; i++) {
sum += exp(output[i]);
}
// 将每个输出转化为概率
for (int i = 0; i < length; i++) {
output[i] = exp(output[i]) / sum;
}
}
// 线性插值函数
float linear_interpolation(float target_time, float time0, float time1, float value0, float value1) {
return value0 + (value1 - value0) * (target_time - time0) / (time1 - time0);
}
// 全连接层操作
float fc(float* input, float* weights, float* bias, int input_size) {
float output = 0;
for (int i = 0; i < input_size; i++) {
output += input[i] * weights[i];
}
return output + bias[0];
}
// 前向传播逻辑
void forward(float* input, float* output, float* fc1_weight, float* fc1_bias, float* fc3_weight, float* fc3_bias) {
float hidden[layer2];
// 第一层全连接 (fc1)
for (int i = 0; i < layer2; i++) {
hidden[i] = 0;
for (int j = 0; j < layer1; j++) {
hidden[i] += input[j] * fc1_weight[i * layer1 + j];
}
hidden[i] += fc1_bias[i];
hidden[i] = relu(hidden[i]); // ReLU 激活函数
}
// 第三层全连接 (fc3)
for (int i = 0; i < layer3; i++) {
output[i] = 0;
for (int j = 0; j < layer2; j++) {
output[i] += hidden[j] * fc3_weight[i * layer2 + j];
}
output[i] += fc3_bias[i];
}
}
从flash中的json文件加载权重和分类信息字典,使用ArduinoJson库读取,StaticJsonDocument后为内存空间大小需要手动设置,注意不要过大不然可能堆栈溢出导致无限重启
// 加载模型权重
void loadModelWeights(const char* filename) {
File file = LittleFS.open(filename, "r");
if (!file) {
Serial.println("Failed to open model-weights file");
return;
}
// 解析 JSON
StaticJsonDocument<121024> doc;
DeserializationError error = deserializeJson(doc, file);
if (error) {
Serial.println("Failed to read model-weights");
return;
}
// 从 JSON 中加载权重
for (int i = 0; i < layer2; i++) {
fc1_bias[i] = doc["fc1.bias"][i];
for (int j = 0; j < layer1; j++) {
fc1_weight[i * layer1 + j] = doc["fc1.weight"][i][j];
}
}
for (int i = 0; i < layer3; i++) {
fc3_bias[i] = doc["fc3.bias"][i];
for (int j = 0; j < layer2; j++) {
fc3_weight[i * layer2 + j] = doc["fc3.weight"][i][j];
}
}
file.close();
}
// 加载 OneHotEncoder 类别信息
void loadOneHotEncoder(const char* filename) {
File file = LittleFS.open(filename, "r");
if (!file) {
Serial.println("Failed to open OneHotEncoder file");
return;
}
// 解析 JSON
StaticJsonDocument<1024> doc;
DeserializationError error = deserializeJson(doc, file);
if (error) {
Serial.println("Failed to read OneHotEncoder");
return;
}
// 清空 actions vector,确保每次加载时都是干净的
actions.clear();
// 读取类别信息
for (JsonVariant category : doc.as<JsonArray>()[0].as<JsonArray>()) {
actions.push_back(category.as<String>());
}
file.close();
}
Python
数据预处理
数据预处理采用线性插值降维实现维度统一,然后拼接成6通道的(6*target_frame)vector输入ANN。代码在interpolate_data.py
# 线性插值
def interpolate_data(data, target_frames=100):
interpolated_data = []
# 根据 ID 分组
grouped = data.groupby('id')
for id_value, group in grouped:
# 获取时间和特征
time = group['time'].values
Ax = group['Ax'].values
Ay = group['Ay'].values
Az = group['Az'].values
gx = group['gx'].values
gy = group['gy'].values
gz = group['gz'].values
# 仅在数据点数量足够时进行插值
if len(time) > 1: # 确保有足够的数据点进行插值
target_time = np.linspace(time.min(), time.max(), target_frames)
# 进行插值
Ax_interp = np.interp(target_time, time, Ax)
Ay_interp = np.interp(target_time, time, Ay)
Az_interp = np.interp(target_time, time, Az)
gx_interp = np.interp(target_time, time, gx)
gy_interp = np.interp(target_time, time, gy)
gz_interp = np.interp(target_time, time, gz)
# 将插值结果添加到列表
for t, ax, ay, az, gxi, gyi, gzi in zip(target_time, Ax_interp, Ay_interp, Az_interp, gx_interp, gy_interp, gz_interp):
interpolated_data.append([id_value, t, ax, ay, az, gxi, gyi, gzi])
# 转换为 DataFrame
interpolated_df = pd.DataFrame(interpolated_data, columns=['id', 'time', 'Ax', 'Ay', 'Az', 'gx', 'gy', 'gz'])
return interpolated_df
网络架构
model.py中实现了两个网络架构,3层MLP网络和两层卷积+一层全连接的CNN;默认使用MLP网络因为ESP32的前向传播只实现了MLP部分,卷积逻辑没有实现
# 构建神经网络模型
class ActionClassifier(nn.Module):
def __init__(self,input_shape,output_classes):
super(ActionClassifier, self).__init__()
self.fc1 = nn.Linear(input_shape[1]*input_shape[2], 32) # 输入特征数
# self.fc2 = nn.Linear(32, 32)
self.fc3 = nn.Linear(32, output_classes) # 输出层节点数为CSV文件数量
def forward(self, x):
x = torch.flatten(x,start_dim=-2)
x = self.fc1(x)
x = torch.relu(x)
# x = torch.nn.functional.leaky_relu(self.fc1(x), negative_slope=0.01)
x = self.fc3(x)
return x
# CNN模型
class ActionClassifierCNN(nn.Module):
def __init__(self,input_shape,output_classes):
super(ActionClassifierCNN, self).__init__()
# 卷积层
self.conv1 = nn.Conv1d(in_channels=input_shape[2], out_channels=30, kernel_size=3, stride=3, padding=1)
self.conv2 = nn.Conv1d(in_channels=30, out_channels=15, kernel_size=3, stride=3, padding=1)
# 全连接层
self.fc1 = nn.Linear(15 * ((input_shape[1]-1)//9+1), output_classes)
self.dropout = nn.Dropout(0.3)
def forward(self, x):
x = x.permute(0, 2, 1)
x = torch.nn.functional.leaky_relu(self.conv1(x))
x = torch.nn.functional.leaky_relu(self.conv2(x))
x = torch.flatten(x,start_dim=-2)# 倒数第二个维度开始,批量和单个数据都可以运行
x = self.fc1(x)
x = self.dropout(x)
return x # 分类问题使用 softmax 输出
模型训练
训练代码在ANN.py中,可以调整参数如下
- target_frame:线性插值降维的目标frame,*6是真正降维后输入ANN的数据维度
- test_size:划分训练测试集的时候测试集的比例
- num_epochs:训练轮次(所有训练数据都通过nn进行一次反向传播称为一个epoch)
使用交叉熵损失函数,在降维维度为20的时候进行40个epoch训练的loss和accuracy曲线如下,可以发现在25个epoch后模型趋于收敛
模型部署
因为esp32上无法使用pytorch相关库,所以采取将权重json序列化保存后上传到esp32闪存中,同时保存传one-hot字典用于将ANN输出值softmax后解码
模型预测的时候使用softmax将model输出值映射到0-100%之间,取>80%的输出为预测结果,并通过one-hot字典解码成动作string后上报到MQTT服务器
模型优化
因为ESP32的内存只有328KB,为了避免内存堆栈溢出,需要对模型进行压缩
要对模型进行压缩和加速可以从降低线性插值维度和改变ANN结构来实现,下面主要尝试了降低线性插值维度。
设置frame从2到20,每次训练40个epoch,使用相同的训练测试集,得到loss和accuracy关于frame-epoch平面下的曲面图如下
可以发现插值维度10~20训练速度和accuracy差异不大,所以将插值维度降低为10,测得模型权重导出的json文件从90.6KB下降到49.2KB,压缩46%
开发环境
本项目使用Arduino2.0开发(不是1.0),相关依赖已放在Github仓库中,即根目录/libraries;点击main/main.ino自动打开项目文件夹,然后点击文件->首选项,修改项目文件夹为该项目根目录,即可在编译时自动查找依赖
在Arduino左侧开发板管理器中下载3.0.4版本的esp32开发板(注意版本要对),然后参考Arduino中不支持ESP32C2 – SZU_TIC开启ESP32C2的支持,在顶部选择对应串口和开发板
然后点击“工具”栏,配置flash Size和Flash Scheme(主要是OTA,app,文件系统等分区的配置),注意自己购买的的8684WROOM-1C后缀是H4,代表内置4MB flash
然后就可以正常编译上传了,data下的文件使用LittleFS Uploader上传,具体安装见Arduino IDE 2:安装 ESP32 LittleFS 上传器 |随机书教程
固件合并
点击项目->导出已编译二进制文件,将相关bin文件导出到build文件夹下
下面是Arduino编译时合并固件的命令,可以看到对应bin所在分区,按照这个配置flash_tool,其中boot_app0.bin可以忽略不合并,LittleFS固件地址和获取在下文介绍;配置完成点击combineBin进行合并
"C:\\Users\\24518\\AppData\\Local\\Arduino15\\packages\\esp32\\tools\\esptool_py\\4.6/esptool.exe" --chip esp32c2 merge_bin -o "C:\\Users\\24518\\AppData\\Local\\Temp\\arduino\\sketches\\BFDA5DF393641D20E0648E7CFC4A91E2/main.ino.merged.bin" --fill-flash-size 2MB --flash_mode keep --flash_freq keep --flash_size keep 0x0 "C:\\Users\\24518\\AppData\\Local\\Temp\\arduino\\sketches\\BFDA5DF393641D20E0648E7CFC4A91E2/main.ino.bootloader.bin" 0x8000 "C:\\Users\\24518\\AppData\\Local\\Temp\\arduino\\sketches\\BFDA5DF393641D20E0648E7CFC4A91E2/main.ino.partitions.bin" 0xe000 "C:\\Users\\24518\\AppData\\Local\\Arduino15\\packages\\esp32\\hardware\\esp32\\3.0.4/tools/partitions/boot_app0.bin" 0x10000 "C:\\Users\\24518\\AppData\\Local\\Temp\\arduino\\sketches\\BFDA5DF393641D20E0648E7CFC4A91E2/main.ino.bin"
在使用LittleFS上传的时候观察命令行可以发现有这么一段输出
Sketch Path: E:\Project\electronicDIY\MagicWand\main
Data Path: E:\Project\electronicDIY\MagicWand\main\data
Device: ESP32 series, model esp32c2
Using partition: default
Partitions: C:\Users\24518\AppData\Local\Arduino15\packages\esp32\hardware\esp32\3.0.4\tools\partitions\default.csv
Start: 0x290000
End: 0x3f0000
打开对应的CSV文件结构如下,存储flash中各种固件起始位置,而且根据开发板,flash Size和Flash Scheme设置不同,分区表也不同;对于该项目使用的esp32c2,4MBflash,default配置,文件系统地址为0x290000
# Name, Type, SubType, Offset, Size, Flags
nvs, data, nvs, 0x9000, 0x5000,
otadata, data, ota, 0xe000, 0x2000,
app0, app, ota_0, 0x10000, 0x140000,
app1, app, ota_1, 0x150000,0x140000,
spiffs, data, spiffs, 0x290000,0x160000,
coredump, data, coredump,0x3F0000,0x10000,
除此之外命令行还有一条信息
Command Line: C:\Users\24518\AppData\Local\Arduino15\packages\esp32\tools\esptool_py\4.6\esptool.exe --chip esp32c2 --port COM5 --baud 921600 --before default_reset --after hard_reset write_flash -z --flash_mode dio --flash_freq 60m --flash_size detect 2686976 C:\Users\24518\AppData\Local\Temp\tmp-40156-RVtpLUV4YNGT-.littlefs.bin
最后的地址就代表文件系统固件存放位置
Comments 4 条评论
博主 巫枫-
Warning: 通过Sakurairo获取IP地理位置失败:返回的数据不是json格式 in /www/wordpress/wwwroot/wp-content/themes/Sakurairo/inc/classes/IpLocation.php on line 58
中国 广西 Hechi
这个pcb为什么我插在电脑上,没有显示串口
博主 ChainPray
Warning: 通过Sakurairo获取IP地理位置失败:返回的数据不是json格式 in /www/wordpress/wwwroot/wp-content/themes/Sakurairo/inc/classes/IpLocation.php on line 58
中国 广东 深圳
@巫枫- 得下ch340串口芯片的驱动 https://www.wch.cn/downloads/CH341SER_EXE.html
博主 巫枫-
Warning: 通过Sakurairo获取IP地理位置失败:返回的数据不是json格式 in /www/wordpress/wwwroot/wp-content/themes/Sakurairo/inc/classes/IpLocation.php on line 58
中国 广西 Hechi
有相关的群吗,我不太会烧录固件,遇到了很多问题
博主 ChainPray
Warning: 通过Sakurairo获取IP地理位置失败:返回的数据不是json格式 in /www/wordpress/wwwroot/wp-content/themes/Sakurairo/inc/classes/IpLocation.php on line 58
中国 广东 深圳
@巫枫- 暂时没建群,你直接加我主页v或者QQ吧