• configs


    configs 部分

    ```python
    import os  # 导入os模块,用于系统级操作

    emotion = ["Valence"]  # 定义情绪列表,只包含情绪维度"Valence"

    # 配置参数字典
    config = {
        "extract_class_label": 1,  # 是否提取类别标签
        "extract_continuous_label": 1,  # 是否提取连续标签
        "extract_eeg": 1,  # 是否提取EEG数据
        "eeg_folder": "eeg",  # 存放EEG数据的文件夹名称
        "eeg_config": {  # EEG数据处理的详细配置
            "sampling_frequency": 256,  # 采样频率
            "window_sec": 2,  # 窗口长度(秒)
            "hop_sec": 0.25,  # 跳跃长度(秒)
            "buffer_sec": 5,  # 缓冲区长度(秒)
            "num_electrodes": 32,  # 电极数量
            "interest_bands": [(0.3, 4), (4, 8), (8, 12), (12, 18), (18, 30), (30, 45)],  # 感兴趣频段
            "f_trans_interest_bands": [(0.1, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)],  # 感兴趣频段的过渡频率
            "channel_slice": {'eeg': slice(0, 32), 'ecg': slice(32, 35), 'misc': slice(35, -1)},  # 通道切片
            "features": ["eeg_bandpower"],  # 特征
            "filter_type": 'cheby2',  # 滤波器类型
            "filter_order": 4  # 滤波器阶数
        },
        "save_npy": 1,  # 是否保存为.npy格式的数据
        "npy_folder": "compacted_48",  # 存放.npy数据的文件夹名称
        "dataset_name": "mahnob",  # 数据集的名称
        "emotion_list": emotion,  # 情绪列表
        "root_directory": r"D:\DingYi\Dataset\MAHNOB-O",  # 原始数据集的根目录路径
        "output_root_directory": r"D:\DingYi\Dataset\MAHNOB-P-R",  # 处理后数据的输出根目录路径
        "raw_data_folder": "Sessions",  # 原始数据存放的文件夹名称
        "multiplier": {  # 不同数据类型的倍增因子
            "video": 16,
            "eeg_raw": 1,
            "eeg_bandpower": 1,
            "eeg_DE": 1,
            "eeg_RP": 1,
            "eeg_Hjorth": 1,
            "continuous_label": 1
        },
        "feature_dimension": {  # 不同特征的维度信息
            "eeg_raw": (16384,),
            "eeg_bandpower": (192,),
            "eeg_DE": (192,),
            "eeg_RP": (192,),
            "eeg_Hjorth": (96,),
            "continuous_label": (1,),
            "class_label": (1,)
        },
        "max_epoch": 15,  # 最大的训练周期数
        "min_epoch": 0,  # 最小的训练周期数
        "model_name": "2d1d",  # 模型的名称
        "backbone": {  # 模型的骨干网络配置
            "state_dict": "res50_ir_0.887",
            "mode": "ir"
        },
        "early_stopping": 10,  # 提前停止训练的步数
        "load_best_at_each_epoch": 1,  # 是否在每个周期加载最佳模型
        "time_delay": 0,  # 时间延迟
        "metrics": ["rmse", "pcc", "ccc"],  # 评估指标
        "save_plot": 0  # 是否保存图形结果
    }
    ```

    这段代码是一个Python字典,包含了各种配置参数,用于处理和分析一个名为MAHNOB的数据集,主要用于情绪识别研究。以下是每行代码的解释:

    1. `import os`: 导入Python的os模块,用于操作文件路径等系统级操作。

    2. `emotion = ["Valence"]`: 定义一个情绪列表,只包含情绪维度"Valence"。

    3. `config = { ... }`: 定义一个名为config的字典,包含了各种配置参数。

    4. `"extract_class_label": 1`: 是否提取类别标签,这里设为1表示是。

    5. `"extract_continuous_label": 1`: 是否提取连续标签,这里设为1表示是。

    6. `"extract_eeg": 1`: 是否提取EEG数据,这里设为1表示是。

    7. `"eeg_folder": "eeg"`: 存放EEG数据的文件夹名称。

    8. `"eeg_config": { ... }`: EEG数据处理的详细配置,包括采样频率、窗口长度、跳跃长度、通道数量等参数。

    9. `"save_npy": 1`: 是否保存为.npy格式的数据,这里设为1表示是。

    10. `"npy_folder": "compacted_48"`: 存放.npy数据的文件夹名称。

    11. `"dataset_name": "mahnob"`: 数据集的名称。

    12. `"emotion_list": emotion`: 情绪列表,使用了之前定义的emotion变量。

    13. `"root_directory": r"D:\DingYi\Dataset\MAHNOB-O"`: 原始数据集的根目录路径。

    14. `"output_root_directory": r"D:\DingYi\Dataset\MAHNOB-P-R"`: 处理后数据的输出根目录路径。

    15. `"raw_data_folder": "Sessions"`: 原始数据存放的文件夹名称。

    16. `"multiplier": { ... }`: 不同数据类型的倍增因子,用于数据增强或者调整数据量。

    17. `"feature_dimension": { ... }`: 不同特征的维度信息,用于数据处理和模型输入。

    18. `"max_epoch": 15`: 最大的训练周期数。

    19. `"min_epoch": 0`: 最小的训练周期数。

    20. `"model_name": "2d1d"`: 模型的名称,这里只是命名用途,实际上没有使用。

    21. `"backbone": { ... }`: 模型的骨干网络配置,包括状态字典和模式。

    22. `"early_stopping": 10`: 提前停止训练的步数。

    23. `"load_best_at_each_epoch": 1`: 是否在每个周期加载最佳模型。

    24. `"time_delay": 0`: 时间延迟,用于连续标签在数据点中的移动。

    25. `"metrics": ["rmse", "pcc", "ccc"]`: 评估指标,包括均方根误差、皮尔逊相关系数和一致性相关系数。

    26. `"save_plot": 0`: 是否保存图形结果,这里设为0表示否。

    这些配置参数用于设置数据预处理、模型训练和评估过程中的各种选项和参数,确保流程能够顺利进行和有效执行。

    from base.preprocessing import GenericDataPreprocessing  # 导入基础数据预处理类
    from base.utils import expand_index_by_multiplier, load_pickle, save_to_pickle, get_filename_from_a_folder_given_extension, ensure_dir  # 导入一些辅助函数和工具
    from base.label_config import *  # 导入标签配置

    import os  # 导入os模块,用于系统级操作
    import scipy.io as sio  # 导入scipy.io模块,用于读取.mat文件

    import pandas as pd  # 导入pandas库,用于数据处理和分析
    import numpy as np  # 导入numpy库,用于数值计算

    import xml.etree.ElementTree as et  # 导入xml.etree.ElementTree模块,用于解析XML文件

    generate_dataset.py

    from base.preprocessing import GenericDataPreprocessing  # 导入基础数据预处理类
    from base.utils import expand_index_by_multiplier, load_pickle, save_to_pickle, get_filename_from_a_folder_given_extension, ensure_dir  # 导入一些辅助函数和工具
    from base.label_config import *  # 导入标签配置

    import os  # 导入os模块,用于系统级操作
    import scipy.io as sio  # 导入scipy.io模块,用于读取.mat文件

    import pandas as pd  # 导入pandas库,用于数据处理和分析
    import numpy as np  # 导入numpy库,用于数值计算

    import xml.etree.ElementTree as et  # 导入xml.etree.ElementTree模块,用于解析XML文件


    class Preprocessing(GenericDataPreprocessing):
        def __init__(self, config):
            super().__init__(config)

        def generate_iterator(self):
            # 生成迭代器,返回按照文件名排序的文件路径列表
            path = os.path.join(self.config['root_directory'], self.config['raw_data_folder'])
            iterator = [os.path.join(path, file) for file in sorted(os.listdir(path), key=float)]
            return iterator

        def generate_per_trial_info_dict(self):
            # 生成每个试验的信息字典
            per_trial_info_path = os.path.join(self.config['output_root_directory'], "processing_records.pkl")
            if os.path.isfile(per_trial_info_path):
                per_trial_info = load_pickle(per_trial_info_path)
            else:
                per_trial_info = {}
                pointer = 0

                sub_trial_having_continuous_label = self.get_sub_trial_info_for_continuously_labeled()
                all_continuous_labels = self.read_all_continuous_label()

                iterator = self.generate_iterator()

                for idx, file in enumerate(iterator):
                    kwargs = {}
                    this_trial = {}
                    print(file)

                    time_stamp_file = get_filename_from_a_folder_given_extension(file, "tsv", "All-Data")[0]
                    video_trim_range = self.read_start_end_from_mahnob_tsv(time_stamp_file)
                    if video_trim_range is not None:
                        this_trial['video_trim_range'] = self.read_start_end_from_mahnob_tsv(time_stamp_file)
                    else:
                        this_trial['discard'] = 1
                        continue

                    this_trial['has_continuous_label'] = 0
                    session = int(file.split(os.sep)[-1])
                    subject_no, trial_no = session // 130 + 1, session % 130

                    if subject_no == sub_trial_having_continuous_label[pointer][0] and trial_no == sub_trial_having_continuous_label[pointer][1]:
                        this_trial['has_continuous_label'] = 1

                    this_trial['continuous_label'] = None
                    this_trial['annotated_index'] = None
                    annotated_index = np.arange(this_trial['video_trim_range'][0][1])
                    if this_trial['has_continuous_label']:
                        raw_continuous_label = all_continuous_labels[pointer]
                        this_trial['continuous_label'] = raw_continuous_label
                        annotated_index = self.process_continuous_label(raw_continuous_label)
                        this_trial['annotated_index'] = annotated_index
                        pointer += 1

                    this_trial['has_eeg'] = 1
                    eeg_path = get_filename_from_a_folder_given_extension(file, "bdf")
                    if len(eeg_path) == 1:
                        this_trial['eeg_path'] = eeg_path[0].split(os.sep)
                    else:
                        this_trial['eeg_path'] = None
                        this_trial['has_eeg'] = 0

                    this_trial['audio_path'] = ""

                    this_trial['subject_no'] = subject_no
                    this_trial['trial_no'] = trial_no
                    this_trial['trial'] = "P{}-T{}".format(str(subject_no), str(trial_no))

                    this_trial['target_fps'] = 64

                    kwargs['feature'] = "video"
                    kwargs['has_continuous_label'] = this_trial['has_continuous_label']
                    this_trial['video_annotated_index'] = self.get_annotated_index(annotated_index, **kwargs)

                    this_trial['class_label'] = get_filename_from_a_folder_given_extension(file, "xml")[0]
                    per_trial_info[idx] = this_trial

            ensure_dir(per_trial_info_path)
            save_to_pickle(per_trial_info_path, per_trial_info)
            self.per_trial_info = per_trial_info

        def generate_dataset_info(self):
            # 生成数据集信息
            class_label = {}
            for idx, record in self.per_trial_info.items():
                self.dataset_info['trial'].append(record['processing_record']['trial'])
                self.dataset_info['trial_no'].append(record['trial_no'])
                self.dataset_info['subject_no'].append(record['subject_no'])
                self.dataset_info['has_continuous_label'].append(record['has_continuous_label'])
                self.dataset_info['has_eeg'].append(record['has_eeg'])

                if record['has_continuous_label']:
                    self.dataset_info['length'].append(len(record['continuous_label']))
                else:
                    self.dataset_info['length'].append(len(record['video_annotated_index']) // 16)

                if self.config['extract_class_label']:
                    class_label.update({record['processing_record']['trial']: self.extract_class_label_fn(record)})

            self.dataset_info['multiplier'] = self.config['multiplier']
            self.dataset_info['data_folder'] = self.config['npy_folder']

            path = os.path.join(self.config['output_root_directory'], 'dataset_info.pkl')
            save_to_pickle(path, self.dataset_info)

            if self.config['extract_class_label']:
                path = os.path.join(self.config['output_root_directory'], 'class_label.pkl')
                save_to_pickle(path, class_label)

        def extract_class_label_fn(self, record):
            # 提取类别标签
            class_label = {}
            if record['has_eeg']:
                xml_file = et.parse(record['class_label']).getroot()
                felt_emotion = xml_file.find('.').attrib['feltEmo']
                felt_arousal = xml_file.find('.').attrib['feltArsl']
                felt_valence = xml_file.find('.').attrib['feltVlnc']

                arousal = 0 if float(felt_arousal) <= 5 else 1
                valence = 0 if float(felt_valence) <= 5 else 1

                class_label = {
                    "Arousal": arousal,
                    "Valence": valence,
                    "Arousal_3cls": arousal_class_to_number[emotion_tag_to_arousal_class[number_to_emotion_tag_dict[felt_emotion]]],
                    "Valence_3cls": valence_class_to_number[emotion_tag_to_valence_class[number_to_emotion_tag_dict[felt_emotion]]]
                }

            return class_label

        def extract_continuous_label_fn(self, idx, npy_folder):
            # 提取连续标签
            if self.per_trial_info[idx]["has_continuous_label"]:
                raw_continuous_label = self.per_trial_info[idx]['continuous_label']

                if self.config['save_npy']:
                    filename = os.path.join(npy_folder, "continuous_label.npy")
                    if not os.path.isfile(filename):
                        ensure_dir(filename)
                        np.save(filename, raw_continuous_label)

        def load_continuous_label(self, path, **kwargs):
            # 加载连续标签
            cols = [emotion.lower() for emotion in self.config['emotion_list']]

            if os.path.isfile(path):
                continuous_label = pd.read_csv(path, sep=";",
                                               skipinitialspace=True, usecols=cols,
                                               index_col=False).values.squeeze()
            else:
                continuous_label = 0

            return continuous_label

        def get_annotated_index(self, annotated_index, **kwargs):
            # 获取标注索引
            feature = kwargs['feature']
            multiplier = self.config['multiplier'][feature]

            if kwargs['has_continuous_label']:
                annotated_index = expand_index_by_multiplier(annotated_index, multiplier)
            else:
                pass

            return annotated_index

        def get_sub_trial_info_for_continuously_labeled(self):
            # 获取连续标签的子试验信息
            label_file = os.path.join(self.config['root_directory'], "lable_continous_Mahnob.mat")
            mat_content = sio.loadmat(label_file)
            sub_trial_having_continuous_label = mat_content['trials_included']

            return sub_trial_having_continuous_label

        @staticmethod
        def read_start_end_from_mahnob_tsv(tsv_file):
            # 从Mahnob的tsv文件中读取起始和结束时间
            if os.path.isfile(tsv_file):
                data = pd.read_csv(tsv_file, sep='\t', skiprows=23)
                end = data[data['Event'] == 'MovieEnd'].index[0]
                start_end = [(0, end)]
            else:
                start_end = None
            return start_end

        def read_all_continuous_label(self):
            # 读取所有连续标签
            label_file = os.path.join(self.config['root_directory'], "lable_continous_Mahnob.mat")
            mat_content = sio.loadmat(label_file)
            annotation_cell = np.squeeze(mat_content['labels'])

            label_list = []
            for index in range(len(annotation_cell)):
                label_list.append(annotation_cell[index].T)
            return label_list

        @staticmethod
        def init_dataset_info():
            # 初始化数据集信息
            dataset_info = {
                "trial": [],
                "subject_no": [],
                "trial_no": [],
                "length": [],
                "has_continuous_label": [],
                "has_eeg": [],
            }
            return dataset_info


    if __name__ == "__main__":
        from configs import config

        pre = Preprocessing(config)
        pre.generate_per_trial_info_dict()
        pre.prepare_data()

    这段代码定义了一个名为Preprocessing的类,继承自GenericDataPreprocessing类,用于数据预处理。它包含了一些方法和函数,用于生成每个试验的信息字典、生成数据集信息、提取类别标签、提取连续标签等操作。在if __name__ == "__main__":部分,创建了Preprocessing对象,并调用了相关方法进行数据预处理。

    main.py

    from base.preprocessing import GenericDataPreprocessing  # 导入自定义的GenericDataPreprocessing类
    from base.utils import expand_index_by_multiplier, load_pickle, save_to_pickle, get_filename_from_a_folder_given_extension, ensure_dir  # 导入一些辅助函数和工具
    from base.label_config import *  # 导入标签配置

    import os  # 导入os模块,用于文件和目录操作
    import scipy.io as sio  # 导入scipy.io模块,用于读取MATLAB文件

    import pandas as pd  # 导入pandas库,用于数据处理
    import numpy as np  # 导入numpy库,用于数值计算

    import xml.etree.ElementTree as et  # 导入xml.etree.ElementTree模块,用于解析XML文件


    class Preprocessing(GenericDataPreprocessing):
        def __init__(self, config):
            super().__init__(config)

        def generate_iterator(self):
            path = os.path.join(self.config['root_directory'], self.config['raw_data_folder'])
            iterator = [os.path.join(path, file) for file in sorted(os.listdir(path), key=float)]
            return iterator

        def generate_per_trial_info_dict(self):
            # 生成每个试验的信息字典

            per_trial_info_path = os.path.join(self.config['output_root_directory'], "processing_records.pkl")
            if os.path.isfile(per_trial_info_path):
                per_trial_info = load_pickle(per_trial_info_path)
            else:
                per_trial_info = {}
                pointer = 0

                sub_trial_having_continuous_label = self.get_sub_trial_info_for_continuously_labeled()
                all_continuous_labels = self.read_all_continuous_label()

                iterator = self.generate_iterator()

                for idx, file in enumerate(iterator):
                    kwargs = {}
                    this_trial = {}
                    print(file)

                    time_stamp_file = get_filename_from_a_folder_given_extension(file, "tsv", "All-Data")[0]
                    video_trim_range = self.read_start_end_from_mahnob_tsv(time_stamp_file)
                    if video_trim_range is not None:
                        this_trial['video_trim_range'] = self.read_start_end_from_mahnob_tsv(time_stamp_file)
                    else:
                        this_trial['discard'] = 1
                        continue

                    this_trial['has_continuous_label'] = 0
                    session = int(file.split(os.sep)[-1])
                    subject_no, trial_no = session // 130 + 1, session % 130

                    if subject_no == sub_trial_having_continuous_label[pointer][0] and trial_no == sub_trial_having_continuous_label[pointer][1]:
                        this_trial['has_continuous_label'] = 1

                    this_trial['continuous_label'] = None
                    this_trial['annotated_index'] = None
                    annotated_index = np.arange(this_trial['video_trim_range'][0][1])
                    if this_trial['has_continuous_label']:
                        raw_continuous_label = all_continuous_labels[pointer]
                        this_trial['continuous_label'] = raw_continuous_label
                        annotated_index = self.process_continuous_label(raw_continuous_label)
                        this_trial['annotated_index'] = annotated_index
                        pointer += 1

                    this_trial['has_eeg'] =  1
                    eeg_path = get_filename_from_a_folder_given_extension(file, "bdf")
                    if len(eeg_path) == 1:
                        this_trial['eeg_path'] = eeg_path[0].split(os.sep)
                    else:
                        this_trial['eeg_path'] = None
                        this_trial['has_eeg'] = 0

                    this_trial['audio_path'] = ""

                    this_trial['subject_no'] = subject_no
                    this_trial['trial_no'] = trial_no
                    this_trial['trial'] = "P{}-T{}".format(str(subject_no), str(trial_no))

                    this_trial['target_fps'] = 64

                    kwargs['feature'] = "video"
                    kwargs['has_continuous_label'] = this_trial['has_continuous_label']
                    this_trial['video_annotated_index'] = self.get_annotated_index(annotated_index, **kwargs)

                    this_trial['class_label'] = get_filename_from_a_folder_given_extension(file, "xml")[0]
                    per_trial_info[idx] = this_trial

            ensure_dir(per_trial_info_path)
            save_to_pickle(per_trial_info_path, per_trial_info)
            self.per_trial_info = per_trial_info

        def generate_dataset_info(self):
            # 生成数据集信息

            class_label = {}
            for idx, record in self.per_trial_info.items():
                self.dataset_info['trial'].append(record['processing_record']['trial'])
                self.dataset_info['trial_no'].append(record['trial_no'])
                self.dataset_info['subject_no'].append(record['subject_no'])
                self.dataset_info['has_continuous_label'].append(record['has_continuous_label'])
                self.dataset_info['has_eeg'].append(record['has_eeg'])

                if record['has_continuous_label']:
                    self.dataset_info['length'].append(len(record['continuous_label']))
                else:
                    self.dataset_info['length'].append(len(record['video_annotated_index']) // 16)

                if self.config['extract_class_label']:
                    class_label.update({record['processing_record']['trial']: self.extract_class_label_fn(record)})

            self.dataset_info['multiplier'] = self.config['multiplier']
            self.dataset_info['data_folder'] = self.config['npy_folder']

            path = os.path.join(self.config['output_root_directory'], 'dataset_info.pkl')
            save_to_pickle(path, self.dataset_info)

            if self.config['extract_class_label']:
                path = os.path.join(self.config['output_root_directory'], 'class_label.pkl')
                save_to_pickle(path, class_label)

        def extract_class_label_fn(self, record):
            # 提取类别标签的函数

            class_label = {}
            if record['has_eeg']:
                xml_file = et.parse(record['class_label']).getroot()
                felt_emotion = xml_file.find('.').attrib['feltEmo']
                felt_arousal = xml_file.find('.').attrib['feltArsl']
                felt_valence = xml_file.find('.').attrib['feltVlnc']

                arousal = 0 if float(felt_arousal) <= 5 else 1
                valence = 0 if float(felt_valence) <= 5 else 1

                class_label = {
                    "Arousal": arousal,
                    "Valence": valence,
                    "Arousal_3cls": arousal_class_to_number[emotion_tag_to_arousal_class[number_to_emotion_tag_dict[felt_emotion]]],
                    "Valence_3cls": valence_class_to_number[emotion_tag_to_valence_class[number_to_emotion_tag_dict[felt_emotion]]]
                }

            return class_label

        def extract_continuous_label_fn(self, idx, npy_folder):
            # 提取连续标签的函数

            if self.per_trial_info[idx]["has_continuous_label"]:
                raw_continuous_label = self.per_trial_info[idx]['continuous_label']

                if self.config['save_npy']:
                    filename = os.path.join(npy_folder, "continuous_label.npy")
                    if not os.path.isfile(filename):
                        ensure_dir(filename)
                        np.save(filename, raw_continuous_label)

        def load_continuous_label(self, path, **kwargs):
            # 加载连续标签

            cols = [emotion.lower() for emotion in self.config['emotion_list']]

            if os.path.isfile(path):
                continuous_label = pd.read_csv(path, sep=";",
                                               skipinitialspace=True, usecols=cols,
                                               index_col=False).values.squeeze()
            else:
                continuous_label = 0

            return continuous_label

        def get_annotated_index(self, annotated_index, **kwargs):
            # 获取标注索引

            feature = kwargs['feature']
            multiplier = self.config['multiplier'][feature]

            if kwargs['has_continuous_label']:
                annotated_index = expand_index_by_multiplier(annotated_index, multiplier)
            else:
                pass

            return annotated_index

        def get_sub_trial_info_for_continuously_labeled(self):
            # 获取具有连续标签的子试验信息

            label_file = os.path.join(self.config['root_directory'], "lable_continous_Mahnob.mat")
            mat_content = sio.loadmat(label_file)
            sub_trial_having_continuous_label = mat_content['trials_included']

            return sub_trial_having_continuous_label

        @staticmethod
        def read_start_end_from_mahnob_tsv(tsv_file):
            # 从Mahnob TSV文件中读取起始和结束时间

            if os.path.isfile(tsv_file):
                data = pd.read_csv(tsv_file, sep='\t', skiprows=23)
                end = data[data['Event'] == 'MovieEnd'].index[0]
                start_end = [(0, end)]
            else:
                start_end = None
            return start_end

        def read_all_continuous_label(self):
            # 读取所有连续标签

            label_file = os.path.join(self.config['root_directory'], "lable_continous_Mahnob.mat")
            mat_content = sio.loadmat(label_file)
            annotation_cell = np.squeeze(mat_content['labels'])

            label_list = []
            for index in range(len(annotation_cell)):
                label_list.append(annotation_cell[index].T)
            return label_list

        @staticmethod
        def init_dataset_info():
            # 初始化数据集信息

            dataset_info = {
                "trial": [],
                "subject_no": [],
                "trial_no": [],
                "length": [],
                "has_continuous_label": [],
                "has_eeg": [],
            }
            return dataset_info


    if __name__ == "__main__":
        from configs import config  # 导入配置文件

        pre = Preprocessing(config)  # 创建Preprocessing对象,传入配置文件
        pre.generate_per_trial_info_dict()  # 生成每个试验的信息字典
        pre.prepare_data()  # 准备数据
    这段代码是一个数据预处理的类Preprocessing,继承自GenericDataPreprocessing。它包含了一些方法用于生成每个试验的信息字典、生成数据集信息、提取类别标签和连续标签等操作。在__main__函数中,创建了一个Preprocessing对象,并调用了相关方法进行数据预处理。

    加入其他数据集中

    ```python
    from model import MASA_TCN  # 从model模块中导入MASA_TCN模型

    data = torch.randn(1, 1, 192, 96)  # 生成一个随机张量作为输入数据,形状为(batch_size=1, cnn_channel=1, EEG_channel*feature=32*6, data_sequence=96)

    # 对于回归任务,输出形状为(batch_size, data_sequence, 1)。
    net = MASA_TCN(
            cnn1d_channels=[128, 128, 128],  # 1维卷积层的通道数列表
            cnn1d_kernel_size=[3, 5, 15],  # 1维卷积层的核大小列表
            cnn1d_dropout_rate=0.1,  # 1维卷积层的dropout率
            num_eeg_chan=32,  # EEG通道数
            freq=6,  # 特征频率
            output_dim=1,  # 输出维度
            early_fusion=True,  # 是否使用早期融合
            model_type='reg')  # 模型类型为回归
    preds = net(data)  # 对输入数据进行预测

    # 对于分类任务,输出形状为(batch_size, num_classes)。注意:output_dim应该是类别的数量。
    net = MASA_TCN(
            cnn1d_channels=[128, 128, 128],  # 1维卷积层的通道数列表
            cnn1d_kernel_size=[3, 5, 15],  # 1维卷积层的核大小列表
            cnn1d_dropout_rate=0.1,  # 1维卷积层的dropout率
            num_eeg_chan=32,  # EEG通道数
            freq=6,  # 特征频率
            output_dim=2,  # 输出维度
            early_fusion=True,  # 是否使用早期融合
            model_type='cls')  # 模型类型为分类
    preds = net(data)  # 对输入数据进行预测
    ```

    这段代码首先导入了MASA_TCN模型,然后创建了一个随机输入数据,并使用MASA_TCN模型进行了回归和分类任务的预测。注释已经在代码中添加了。

  • 相关阅读:
    android 性能优化
    macos知名的清理软件 cleanmymac和腾讯柠檬哪个好 cleanmymacx有必要买吗
    3.获取元素
    GIC/ITS代码分析(6)中断处理
    大数据平台搭建及集群规划
    FineReport中字符串常用处理函数
    Spring的前置知识
    Elasticsearch搜索引擎
    VS2010 Windows API 串口编程 (二)
    CentOS7中安装PostgreSQL
  • 原文地址:https://blog.csdn.net/fdgdf5535/article/details/139728182