• 基于小波分析与深度学习的脑电信号分类(matlab)


    原理

    通过小波变换对运动想象信号进行特征提取,生成时频图像作为神经网络的输入。

    实现

    使用BCI竞赛2008–Graz dataset A中的A01受试者的数据作为数据集。

    采样

    采样频率250Hz,每个数据前三个为伪迹参考信号,后6个为EEG信号集,
    每一个划成48次,四个任务,每个任务12次;
    每次任务大概8s,每次大概从3-6s(750-1500点)为运动想像时间,
    拟采集770-1462点,每个样本512个点,采样间隔20个点,采集25个样本;

    %% 采样信号
    clear;
    fs = 250;     % 采样频率 250Hz
    Na = 96735;   %
    Nt = 48;      % 一个EEG信号集划成48次数据,即四个任务、每个任务12次
    Ns = 25;      % 样本数 25个
    Np = 20;      % 采样间隔20个点
    N  = 256;
    %%
    x00 = load('A01T');
    %%
    for k=1:6  % 后6个为EEG信号集,即data {1,4} {1,5} {1,6} {1,7} {1,8} {1,9}
        x01 = x00.data{1, k+3}.X;    % EEG信号
        y01 = x00.data{1, k+3}.y;    % 类别
        t = x00.data{1, k+3}.trial;  % 试验(trials),包含伪迹
        t(Nt+1) = Na;
        %figure
        for i = 1:Nt
            x0 = x01(t(i):t(i+1), :);
            %subplot(6,8,i);
            %plot(x0(:,1));xlim([0 2100]);ylim([-100 100]);
            for j = 1:Ns
                x1 = x0(750+Np*(j-1):750+Np*(j-1)+N-1, 1:22);
                x2 = (x1-min(x1(:)))/(max(x1(:))-min(x1(:)));  % 最大最小归一化
                XTr(:, :, 1, 1200*(k-1)+25*(i-1)+j) = x2;
                YTr(1, 1200*(k-1)+25*(i-1)+j) = categorical(y01(i));
            end
            clear x0; % 每次迭代x0的长度会发生变化
        end
    end
    
    save SubA_Train XTr YTr;
    

    请添加图片描述

    小波变换

    %% 小波变换
    clear
    load SubA_Train;
    %%
    
    id=[8 10 12];  % 选三个电极,
    parfor i=1:length(XTr)
        for j=1:3 
            x = XTr(:,id(j),1,i);
            x1 = abs(cwt(x));  % 小波变换
            XTrft(:,:,j,i) = (x1-min(x1(:)))/(max(x1(:))-min(x1(:)));   % 归一化
        end     
    end
    
    save SubA_TF_Train XTrft YTr;
    
    %% 可视化一个样本为彩色图片
    size(XTrft(:,:,:,1))  % 51×256×3
    categories(YTr)  % 查看类别数
    
    figure;
    imshow(XTrft(:,:,:,1))
    

    时频图样张:
    请添加图片描述

    把时频图保存到本地文件夹

    • 图片的尺寸为51×256×3
    • 4个文件夹(0、1、2、3),每个文件夹的图片为同一类信号
    %% 转成图片格式,先新建一个images文件夹,然后在images里面新建4个文件夹,分别为0、1、2、3.
    load SubA_TF_Train
    for i = 1:7200
        k = double(string(YTr(1,i)))-1;  % label
        imwrite(XTrft(:,:,:,i),['images\',num2str(k)','\',num2str(i),'.jpg'])  % 保存为图片
    end
    

    训练和评估

    利用deepNetworkDesigner搭建网络,导出到工作区,训练。需要注意的是,网络的输出层为4类。可以采用典型的网络,例如Googlenet、resnet等。

    clear;
    
    %% 导入数据集
    imdsTrain = imageDatastore("images","IncludeSubfolders",true,"LabelSource","foldernames");
    [imdsTrain, imdsValidation] = splitEachLabel(imdsTrain,0.8,"randomized");
    
    % 调整图像大小以匹配网络输入层
    % inputsize = [256 256 3];
    inputsize = [51 256 3];
    augimdsTrain = augmentedImageDatastore(inputsize,imdsTrain);
    augimdsValidation = augmentedImageDatastore(inputsize,imdsValidation);
    
    %% 网络结构alexnet
    % Net = alexnet;
    % Net = googlenet;
    % Net = inceptionresnetv2;
    deepNetworkDesigner
    
    %% 训练网络
    miniBatchSize = 128;
    learnRate = 0.0001;
    valFrequency = floor(0.8*7200.0/miniBatchSize);
    options = trainingOptions('adam', ...
        'InitialLearnRate',learnRate, ...
        'MaxEpochs',20, ...
        'MiniBatchSize',miniBatchSize, ...
        'Shuffle','every-epoch', ...
        'Plots','training-progress', ...
        'Verbose',false, ...
        'ValidationData',augimdsValidation, ...
        'ValidationFrequency',valFrequency, ...
        'LearnRateSchedule','piecewise', ...
        'LearnRateDropFactor',0.1, ...
        'LearnRateDropPeriod',5);
    trainedNet = trainNetwork(augimdsTrain, lgraph_1, options);
    
    %% 评估
    % 准确率
    % 训练集
    [YPred,probs] = classify(trainedNet,augimdsTrain);
    accuracy = mean(YPred == imdsTrain.Labels)
    disp("training acc: " + accuracy*100 + "%")
    % 验证集
    [YPred,probs] = classify(trainedNet,augimdsValidation);
    accuracy = mean(YPred == imdsValidation.Labels)
    disp("val acc: " + accuracy*100 + "%")
    
    % 混淆矩阵
    figure('Units','normalized','Position',[0.2 0.2 0.4 0.4]);
    cm = confusionchart(imdsValidation.Labels,YPred);
    cm.Title = 'Confusion Matrix for Validation Data';
    cm.ColumnSummary = 'column-normalized';
    cm.RowSummary = 'row-normalized';
    

    结果

    • alexnet
      • training acc: 99.3924%
      • val acc: 84.7917%
    • resnet
      • training acc: 100%
      • val acc: 85.4861%
    • simplenet(我自己搭建的网络)
      • training acc: 98.7674%
      • val acc: 90.0694%

    请添加图片描述请添加图片描述

    python版本的数据集

    • https://github.com/bregydoc/bcidatasetIV2a
  • 相关阅读:
    爬取小说章节,并制作成词云进行宣传
    基本初等函数
    如何应对量化策略的失效
    如何将几张图片转换为GIF动图?
    Unity实现设计模式——备忘录模式
    VS Code配置c++环境
    测试部门来了个00后卷王之王,老油条感叹真干不过,但是...
    VBA实战(11) - 工作表(Sheet) 操作汇总
    基于JavaSwing开发书店管理系统+论文 毕业设计 课程设计 大作业
    动态规划-01背包问题新解(c)
  • 原文地址:https://blog.csdn.net/qq_50258800/article/details/127045577