• 基于CSP的运动想象EEG分类任务实战


    基于运动想象的公开数据集:Data set IVa (BCI Competition III)1
    数据描述参考前文:https://blog.csdn.net/qq_43811536/article/details/134224005?spm=1001.2014.3001.5501
    EEG 信号时频空域分析参考前文:https://blog.csdn.net/qq_43811536/article/details/134273470?spm=1001.2014.3001.5501
    基于CSP的运动想象 EEG 特征提取和可视化参考前文:https://blog.csdn.net/qq_43811536/article/details/134296308?spm=1001.2014.3001.5501
    CSP(Common Spatial Patterns)——EEG特征提取方法详解参考前文:https://blog.csdn.net/qq_43811536/article/details/134296840?spm=1001.2014.3001.5501

    本文使用公开数据集 Data set IVa 中的部分被试数据,数据已公开可以从网盘获取:
    链接:https://pan.quark.cn/s/5425ee5918f4
    提取码:hJFz



    1. 实验介绍

    本任务的实验数据来自一名健康受试者,代号al。受试者在视觉提示出现后3.5s内完成以下3个运动想象中的一个:(L)左手,(R)右手,(F)右脚。分类任务中的数据只包括了右手和右脚两类,共280个试次。实验过程中使用脑电帽记录了118个通道的EEG信号,电极位置如图1所示。采集到的EEG信号首先经过带通滤波(0.05-200Hz),再经过数字化和下采样,得到采样率为100Hz的信号。

    在这里插入图片描述

    图1 电极位置

    2. 运动想象分类

    基于CSP特征,我们使用LDA分类器进行分类,并进行十折交叉验证以评估性能。评价指标为测试集准确率,即分类正确的试次占总试次的比例。

    2.1 分类性能

    我们比较了不同的带通滤波器和时间窗的结果。

    • 图1中,横轴为时间窗相较于提示出现的起始时间。不同的折线代表了不同窗长。我们发现在3s的窗长能获得更高的分类准确率,时间窗从提示出现后0.5s开始效果更好,分类准确率达到1。
    • 图2展示了滤波器截止频率对于准确率的影响,可以看到低频截止频率在10-12Hz时准确率能达到1。
    • 我们还比较了LDA分类器与线性回归(LR)和随机森林(RF)方法的性能,结果如表1所示。LDA分类器的准确率高于LR和RF,但分类性能都较高。
    • 最后我们去掉提取CSP特征的模块,直接对原始信号使用LDA分类器,结果如图3所示。去除掉提取CSP模块后,分类准确率由1下降至0.6左右。

    在这里插入图片描述

    图1 时间窗搜索结果。横轴为时间窗相较于提示出现的起始时间。不同的折线代表了不同窗长。

    在这里插入图片描述

    图2 滤波器参数搜索结果。横轴为低频截止频率,带宽固定为16Hz。

    表1 不同分类器的分类结果
    方法准确率
    LDA1
    LR0.99±0.01
    RF0.99±0.01

    在这里插入图片描述

    图3 消融实验。直接对原始信号使用LDA性能较差。

    2.2 结论

    实验表明,右手和右脚运动想象的EEG差异集中于μ节律信号(8-15Hz)和β节律(18-24Hz),体现在C3和C4通道,即感觉运动区。使用CSP算法提取到的特征具有较高的线性可分性,使用LDA分类器可以实现准确率为1,能有效区分这两类运动想象。实验发现用于分类任务的时间窗范围和带通滤波范围对分类准确率具有较大影响,最优时间窗为提示出现后0.5s-3.5s,最优频带为12Hz-28Hz。


    3. 核心Python代码

    • 部分变量说明:
      • raw:由 mne.io.RawArray() 函数创建,代表原始EEG数据
      • epochs:由 mne.Epochs() 函数创建,代表一个事件(event)对应的所有数据,在该数据集中一个事件即 “右手”或者“脚”的想象运动
    # BP Filter
    l_fr, h_fr = 12.0, 28.0
    tMin, tMax = 0.5, 3.5
    
    # MNE object
    info = mne.create_info(
        ch_names=[i[0] for i in ch_name],
        sfreq=eeg_fs,
        ch_types='eeg')
    pos_dic = dict(zip(info.ch_names, ch_pos))
    montage = mne.channels.make_dig_montage(pos_dic)
    
    info.set_montage(montage)
    raw = mne.io.RawArray(eeg_data.T, info)
    # Apply band-pass filter
    raw.filter(l_fr, h_fr, fir_design="firwin", skip_by_annotation="edge")
    
    # Decoding
    
    
    events = np.vstack((cues_pos, np.zeros(len(cues_pos)), target_label[0, :])).T.astype(int)
    picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")
    
    # Epochs
    epochs = mne.Epochs(
        raw,
        events,
        events_id,
        tMin,
        tMax,
        proj=True,
        picks=picks,
        baseline=None,
        preload=True,
    )
    
    # Prepare data for training
    x = epochs.get_data()
    y = target_label[0, :]
    
    # ten-fold cross-validation
    cv = ShuffleSplit(10, test_size=test_r, random_state=42)
    
    # Classification with LDA on CSP features
    lda = LinearDiscriminantAnalysis()
    csp = CSP(n_components=10, reg=None, log=True, norm_trace=False)
    clf = Pipeline([("CSP", csp), ("LDA", lda)])
    
    from sklearn.metrics import accuracy_score
    
    train_x, test_x = x[:224], x[224:]
    train_y, test_y = y[:224], y[224:]
    
    clf.fit(train_x,train_y)
    
    pred1 = clf.predict(train_x)
    accuracy1 = accuracy_score(train_y,pred1)
    print('在训练集上的精确度: %.4f'%accuracy1)
    
    pred2 = clf.predict(test_x)
    accuracy2 = accuracy_score(test_y,pred2)
    print('在测试集上的精确度: %.4f'%accuracy2)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    # 模型比较
    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import SVC
    from sklearn.ensemble import RandomForestClassifier
    
    lda = LinearDiscriminantAnalysis()
    csp = CSP(n_components=10, reg=None, log=True, norm_trace=False)
    clf_lda = Pipeline([("CSP", csp), ("LDA", lda)])
    scores_lda = cross_val_score(clf_lda, x, y, cv=cv, n_jobs=None)
    
    lr = LogisticRegression()
    csp = CSP(n_components=10, reg=None, log=True, norm_trace=False)
    clf_lr = Pipeline([("CSP", csp), ("LR", lr)])
    scores_lr = cross_val_score(clf_lr, x, y, cv=cv, n_jobs=None)
    
    rfc = RandomForestClassifier()
    csp = CSP(n_components=10, reg=None, log=True, norm_trace=False)
    clf_rfc = Pipeline([("CSP", csp), ("RFC", rfc)])
    scores_rfc = cross_val_score(clf_rfc, x, y, cv=cv, n_jobs=None)
    print(scores_lda, scores_lr, 'scores_svc', scores_rfc)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    # Without CSP
    lda = LinearDiscriminantAnalysis()
    scores_lda_only = cross_val_score(lda, x.reshape(-1,118*301), y, cv=cv, n_jobs=None)
    print(scores_lda_only)
    
    plt.plot(scores_lda,'-o',linewidth=2)
    plt.plot(scores_lda_only,'-d',linewidth=2)
    plt.xlabel('Folds',fontsize=16)
    plt.ylabel('Accuracy',fontsize=16)
    plt.legend(['CSP+LDA','LDA'],fontsize=16)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.ylim([0,1.1])
    plt.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    1. https://bbci.de/competition/iii/desc_IVa.html ↩︎

  • 相关阅读:
    redis(4)-hiredis-API函数的调用
    VMware中安装centos无网络,配置教程
    XML入门介绍
    0-5V转4-20mA电路
    Office 2021 小型企业版商用办公软件评测:提升工作效率与协作能力的专业利器
    【触想智能】工控一体机与5G物联网技术结合是未来发展趋势
    【编程题】【Scratch三级】2019.12 判断奇偶数
    GPT对话代码库——HAL库下 USART 的配置及问题(STM32G431CBT6)
    【数据仓库基础(二)】数据仓库架构
    《微服务架构设计模式》第二章
  • 原文地址:https://blog.csdn.net/qq_43811536/article/details/134297508