Deep Learning Notes (4): Sentiment Classification with TextCNN and BiLSTM (weibo100k dataset)


    0 Preface

    Dataset: a Weibo sentiment dataset with about 120,000 samples and 2 labels.
    Environment: RTX 3060 Laptop GPU / AutoDL

    1 Data preparation

    1.1 Paths, constants and hyperparameters

    import torch
    from torch import nn

    # Paths
    DATASET_PATH = '../data/weibo/weibo_senti_100k.csv'
    USER_DICT = '../data/weibo/user_dict.txt'
    
    # Constants
    DEVICE = 'cuda:0' if torch.cuda.is_available() else "cpu"
    loss_func = nn.CrossEntropyLoss()
    loss_list, accuracy_list = [], []
    
    # Hyperparameters
    MAX_LEN = 200  # maximum sentence length
    BATCH_SIZE = 128  # batch size
    EMBEDDING_SIZE = 600  # embedding dimension
    WINDOWS_SIZE = (2, 3, 4)  # convolution window sizes (TextCNN)
    FEATURE_SIZE = 200  # number of feature maps
    N_CLASSES = 2  # number of classes
    EPOCHS = 10  # number of training epochs
    

    1.2 Loading the dataset

    import pandas as pd


    # Exploratory data analysis
    def eda():
        # Load the dataset
        dataset = pd.read_csv(DATASET_PATH)
        # Split the reviews into six chunks so cleaning/tokenization can run in batches
        data1 = dataset['review'].iloc[:20000].values.tolist()
        data2 = dataset['review'].iloc[20000:40000].values.tolist()
        data3 = dataset['review'].iloc[40000:60000].values.tolist()
        data4 = dataset['review'].iloc[60000:80000].values.tolist()
        data5 = dataset['review'].iloc[80000:100000].values.tolist()
        data6 = dataset['review'].iloc[100000:].values.tolist()
        datas = [data1, data2, data3, data4, data5, data6]
        labels = dataset['label']
        print(labels.value_counts())  # label distribution
        return datas, labels
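
    The manual slicing above can also be written more compactly; a sketch assuming numpy is available (eda_compact is a hypothetical helper, not part of the original code):

    import numpy as np

    def eda_compact(n_chunks=6):
        dataset = pd.read_csv(DATASET_PATH)
        print(dataset['label'].value_counts())  # label distribution
        # np.array_split also copes with a length that is not divisible by n_chunks
        datas = [chunk.tolist() for chunk in np.array_split(dataset['review'].values, n_chunks)]
        return datas, dataset['label']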
    

    Difficulty
    The data volume is large and removing low-frequency words is slow (the algorithm is roughly O(n^2)).
    Cause
    Removing low-frequency words uses a nested for loop over a large corpus, so the time complexity is high.
    Solutions

    1. Lower min_threshold (little effect)
    2. Process the data in batches (clear improvement; used in this post)
    3. Multithreading (not mastered yet; reportedly not very effective either)
    4. Convert the lists to numpy arrays (little effect)
    5. numba acceleration (little effect; reason unclear)

    (A set-based lookup is sketched right after this list as a further option.)
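
    One more option, not used in the post: membership tests against a Python set are O(1) on average, so converting delete_list to a set and rebuilding each token list avoids the quadratic scan. A minimal sketch (remove_words_fast is a hypothetical name, not from the original code):

    def remove_words_fast(corpus, delete_list):
        delete_set = set(delete_list)  # O(1) average-case membership test
        # Rebuild each token list instead of calling list.remove() repeatedly
        return [[seg for seg in seg_list if seg not in delete_set] for seg_list in corpus]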

    2 Text cleaning

    Exploratory analysis showed several kinds of text that should be cleaned up.
    For example, look at the following samples:

    corpus = [
            '[鼓掌]//@权金城崔洪峰:扩散@权金城彭涌 @权金城-崔成哲 //@思想聚焦:转发微博',
            'UP!虽然你很不和谐//@风言疯语LaiN胖子:为啥你不关注别人,却要别人关注你?学名人啊?[嘻嘻] //@ponponxu:转发微博。',
            '[鼓掌]  //@金朝顺:帅哥美女如云~恭喜开课!@魏英俊 @solonso  //@洪璐葫芦:恭喜了!//@昆晏:转发微博 转发微博',
            '#轻松一刻# 笑成狗了!主人太有才了![哈哈] #哈哈#',
            '#约惠海航 圆梦飞翔#【惠享直减】购票购票购票,直减直减直减[打哈欠] #测试#这不冲突,也很科学,来海航官网购票,管够,管实惠【惠享直减】http://t.cn/zRpYB9r [嘻嘻] 每天500个名额,20元的直减,ok的赶快来[赞] 今天第二波15点开始~',
            '激动人心的时刻[心]#微动日照#传播大赛大奖ipad实图奉上!感谢@日照市旅游局官方微博 的好活动.东方太阳城给了我太多惊喜,美食霸占味蕾,美景俘获视觉[爱你]仙山兔耳鳗鱼香螺,故地重游也仍有遗憾.日照,美就一个字,我还会再来的@日照旅游王立新@日照旅游-日出先照当属日照@日照旅游咨询网@山海美景',
            '#昆航动态#2010年11月6日,在昆明市创业投资引导基金推介暨颁奖晚宴上,2010年11月昆明航空有限公司董事长王清民(图中左五)从昆明市委常委、副市长刘光溪手上接过#2010泛亚地区最具投资潜质十强企业#证书和奖杯。[鼓掌] 昆明航空成为500多家报名企业中唯一一家获奖的航空企业。[礼花] http://sinaurl.cn/h4QFmF',
            '回复@夜里梵高:君亭的家门向每个游子敞开!欢迎回家![鼓掌] //@夜里梵高:我想回家!哈哈哈[亲亲] //@杭州君亭湖滨酒店:君亭,你在杭州的另一个家',
            '... 。。。 !!! !!! ??? ??? ?! 。。。。 !!!!!!  '
        ]
    
    1. Topics
    2. "转发微博" (forwarded-post markers)
    3. 回复@username (reply prefixes)
    4. @username followed by a space (mentions)
    5. Punctuation with implicit emotion, e.g. 。。。 ??? !!! (Chinese punctuation carries a lot of nuance)
    6. URLs
    7. Dates
      These pieces of information are irrelevant to sentiment analysis, so they are removed with regular-expression matching.
      Finally, English characters and Chinese/English punctuation are stripped from the text.
    import string
    import jieba
    import re
    from zhon import hanzi
    from tqdm import tqdm
    
    
    class WeiBoTextCleaner:
        def __init__(self, corpus):
            self.corpus = corpus
            self.new_corpus = []
    
        def extract_topic(self, sent):
            """
            Remove topic markers
            #*# 
            【*】
            :return:
            """
            pattern1 = re.compile('【([^】]+)】')
            pattern2 = re.compile('#([^#]+)#')
            sent = re.sub(pattern1, '', sent)
            sent = re.sub(pattern2, '', sent)
            return sent
    
        def extract_forward(self, sent):
            """
            Remove forwarded-post markers
            转发微博
            :return:
            """
            pattern = re.compile('转发微博')
            return re.sub(pattern, '', sent)
    
        def extract_reply(self, sent):
            """
            Remove reply prefixes
            回复@username:
            :return:
            """
            pattern = re.compile('回复@[a-zA-Z\u4e00-\u9fa5_0-9-]+')
            return re.sub(pattern, '', sent)
    
        def extract_username(self, sent):
            """
            Remove @username mentions
            @username (followed by a space)
            :return:
            """
            pattern = re.compile('@[a-zA-Z\u4e00-\u9fa5_0-9-]+')
            return re.sub(pattern, '', sent)
    
        def extract_emotional_punctuation(self, sent):
            """
            Replace emotionally loaded punctuation with placeholder tokens
            ... !!! ??? ?!
            。。。 !!! ??? ?!
            :return:
            """
            pattern1 = re.compile('。{3,}')
            pattern2 = re.compile(r'\.{3,}')
            pattern3 = re.compile('!{3,}')
            pattern4 = re.compile('!{3,}')
            pattern5 = re.compile(r'\?{3,}')
            pattern6 = re.compile('?{3,}')
            pattern7 = re.compile(r'\?!')
            pattern8 = re.compile(r'?!')
            sent = re.sub(pattern1, '自定义一', sent)
            sent = re.sub(pattern2, '自定义一', sent)
            sent = re.sub(pattern3, '自定义二', sent)
            sent = re.sub(pattern4, '自定义二', sent)
            sent = re.sub(pattern5, '自定义三', sent)
            sent = re.sub(pattern6, '自定义三', sent)
            sent = re.sub(pattern7, '自定义四', sent)
            sent = re.sub(pattern8, '自定义四', sent)
            return sent
    
        def extract_weblink(self, sent):
            """
            Remove URLs
            http://*
            :return:
            """
            pattern = re.compile('http://[0-9a-zA-Z./]+')
            return re.sub(pattern, '', sent)
    
        def extract_time(self, sent):
            """
            Remove dates
            *年*月*日
            *年*月
            *月*日
            :return:
            """
            pattern1 = re.compile(r'\d{4}年\d{1,2}月\d{1,2}日')
            pattern2 = re.compile(r'\d{4}年\d{1,2}月')
            pattern3 = re.compile(r'\d{1,2}月\d{1,2}日')
            sent = re.sub(pattern1, '', sent)
            sent = re.sub(pattern2, '', sent)
            sent = re.sub(pattern3, '', sent)
            return sent
    
        def clear_character(self, sent):
            """
            Remove remaining invalid characters
            :param sent:
            :return:
            """
            pattern1 = re.compile('[a-zA-Z0-9]')  # English letters and digits
            pattern2 = re.compile(r'[^\s1234567890::' + '\u4e00-\u9fa5]+')  # 表情和其他字符
            pattern3 = re.compile('[%s]+' % re.escape(string.punctuation + hanzi.punctuation))  # punctuation (ASCII and Chinese)
            sent = re.sub(pattern1, '', sent)
            sent = re.sub(pattern2, '', sent)
            sent = re.sub(pattern3, '', sent)
            sent = ''.join(sent.split())  # remove whitespace
            return sent
    
        def execute(self):
            for sentence in tqdm(self.corpus):
                sentence = self.extract_forward(sentence)
                sentence = self.extract_reply(sentence)
                sentence = self.extract_username(sentence)
                sentence = self.extract_topic(sentence)
                sentence = self.extract_weblink(sentence)
                sentence = self.extract_time(sentence)
                sentence = self.extract_emotional_punctuation(sentence)
                sentence = self.clear_character(sentence)
                self.new_corpus.append(sentence)
            return self.new_corpus
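
    Applied to the sample corpus above, the cleaner is used the same way as in section 7:

    cleaner = WeiBoTextCleaner(corpus)
    cleaned_corpus = cleaner.execute()
    for sentence in cleaned_corpus[:3]:
        print(sentence)  # topics, mentions, URLs and forward markers have been stripped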
    

    3 Tokenization

    import pickle
    import time
    from collections import Counter


    # Remove low-frequency words
    def remove_words(corpus, delete_list):
        for i, seg_list in enumerate(tqdm(corpus)):
            # Rebuild each token list; removing items while iterating would skip tokens
            corpus[i] = [seg for seg in seg_list if seg not in delete_list]
        return corpus
    
    
    # Tokenization
    def tokenizer(corpus, min_threshold, i):
        t1 = time.time()
        # Load the user dictionary
        jieba.load_userdict(USER_DICT)
        corpus = list(map(jieba.lcut, corpus))
    
        # Remove low-frequency words
        print('Removing low-frequency words')
        word_list = []
        for seg_list in tqdm(corpus):
            word_list.extend(seg_list)
        counter = Counter(word_list)
        delete_list = []  # words to remove
        for k, v in counter.items():
            if v < min_threshold:
                delete_list.append(k)
        print(f'Total tokens: {len(word_list)}')
        print(f'Low-frequency words to remove: {len(delete_list)}')
    
        corpus = remove_words(corpus, delete_list)
        print(len(corpus))
    
        print('Pickling the token lists')
        with open(f'../data/weibo/corpus{i}.pkl', 'wb') as f:
            pickle.dump(corpus, f)
    
        t2 = time.time()
        print(f'Elapsed: {t2 - t1} seconds')
    
    
    # Merge the pickled batches
    def combine_pkl(paths: list):
        sentences = []
        for path in paths:
            with open(path, 'rb') as f:
                sentences.extend(pickle.load(f))
        return sentences
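
    Each cleaned batch is tokenized and pickled on its own, and the pickles are merged afterwards; a usage sketch of that flow (the same calls appear in execute() in section 7):

    for i, data in enumerate(datas):                # datas comes from eda()
        cleaned = WeiBoTextCleaner(data).execute()  # clean one batch
        tokenizer(cleaned, 25, i + 1)               # min_threshold=25, batch index starts at 1
    sentences = combine_pkl([f'../data/weibo/corpus{i}.pkl' for i in range(1, 7)])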
    

    4 Utilities and text vectorization

    Utilities

    import torch
    from tqdm import tqdm
    from sklearn.metrics import accuracy_score
    
    
    # Build word2index
    def compute_word2index(sentences, word2index):
        for sentence in sentences:
            for word in sentence:
                if word not in word2index:
                    word2index[word] = len(word2index)  # word2index maps each word to an index
        return word2index
    
    
    # Convert a sentence to a fixed-length list of indices
    def compute_sent2index(sentence, max_len, word2index):
        sent2index = [word2index.get(word, 0) for word in sentence]
        if len(sentence) < max_len:
            sent2index += (max_len - len(sentence)) * [0]  # pad with PAD (index 0)
        else:
            sent2index = sent2index[:max_len]  # truncate long sentences
        return sent2index
    
    
    # Text representation
    def text_embedding(sentences, max_len):
        # Build the vocabulary and the padded index sequences
        word2index = {"PAD": 0}
        word2index = compute_word2index(sentences, word2index)
        sent2indexs = []
        for sent in tqdm(sentences):
            sent2indexs.append(compute_sent2index(sent, max_len, word2index))
        return word2index, sent2indexs
    
    
    # Compute accuracy
    def get_accuracy(model, datas, labels):
        with torch.no_grad():
            out = torch.softmax(model(datas), dim=1, dtype=torch.float32)
        predictions = torch.max(input=out, dim=1)[1]  # index of the largest probability
        y_predict = predictions.to('cpu').data.numpy()
        y_true = labels.to('cpu').data.numpy()
        accuracy = accuracy_score(y_true, y_predict)
        return accuracy
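
    A quick sanity check of text_embedding on a toy corpus (the toy tokens below are only for illustration):

    toy_corpus = [['今天', '很', '开心'], ['糟糕', '的', '一天']]
    w2i, vectors = text_embedding(toy_corpus, max_len=5)
    print(w2i)      # {'PAD': 0, '今天': 1, '很': 2, '开心': 3, '糟糕': 4, '的': 5, '一天': 6}
    print(vectors)  # [[1, 2, 3, 0, 0], [4, 5, 6, 0, 0]]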
    

    5 Model construction

    TextCNN: see the previous article in this series.

    BiLSTM
    Network structure

    Code

    import torch
    from torch import nn, optim
    
    
    class BiLSTM(nn.Module):
        def __init__(self, num_embeddings, embedding_dim, hidden_size, num_layers, num_classes, device):
            super(BiLSTM, self).__init__()
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.device = device
            # Embedding layer
            self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
            # Bidirectional LSTM
            self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
                                bidirectional=True)
            # Dropout layer (defined but not applied in forward below)
            self.dropout = nn.Dropout(p=0.5)
            # Fully connected layer (hidden_size * 2 because of the two directions)
            self.fc = nn.Linear(in_features=hidden_size * 2, out_features=num_classes)
    
        def forward(self, x):
            x = self.embed(x)  # [batch_size, max_len, embedding_dim]
            # Initial hidden/cell states: num_layers * 2 because the LSTM is bidirectional
            h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(self.device)
            c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(self.device)
            out, (h_n, c_n) = self.lstm(x, (h0, c0))
            output_fw = h_n[-2, :, :]  # last hidden state of the forward direction
            output_bw = h_n[-1, :, :]  # last hidden state of the backward direction
            out = torch.concat([output_fw, output_bw], dim=1)  # [batch_size, hidden_size * 2]
            x = self.fc(out)
            return x
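
    A quick shape check with dummy inputs (all sizes here are placeholders, not the ones used for training):

    model = BiLSTM(num_embeddings=1000, embedding_dim=128, hidden_size=128,
                   num_layers=2, num_classes=2, device='cpu')
    dummy = torch.randint(0, 1000, (4, 50))  # a batch of 4 index sequences of length 50
    print(model(dummy).shape)                # torch.Size([4, 2])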
    
    

    6 Training and evaluation

    import numpy as np
    import matplotlib.pyplot as plt


    # Training loop
    def train(model, dataloader, optimizer, epoch):
        model.train()  # training mode
        for i, (datas, labels) in enumerate(dataloader):
            # Move tensors to the target device
            datas = datas.to(DEVICE)
            labels = labels.to(DEVICE)
            # Forward pass
            out = model(datas)
            # Compute the loss
            loss = loss_func(out, labels)
            # Zero the gradients
            optimizer.zero_grad()
            # Backpropagation
            loss.backward()
            # Update parameters
            optimizer.step()
            # Log the loss and accuracy periodically
            if i % 300 == 0:
                loss_list.append(loss.item())
                accuracy = get_accuracy(model, datas, labels)
                accuracy_list.append(accuracy)
                print('Train Epoch:%d Loss:%0.6f Accuracy:%0.6f' % (epoch, loss.item(), accuracy))
    
    
    # Plot the training curves
    def plot_curve(epochs, accuracy_list, loss_list, model_name):
        # Average the logged values within each epoch
        accuracy_array = np.array(accuracy_list).reshape(epochs, -1)
        accuracy_array = np.mean(accuracy_array, axis=1)
        loss_array = np.array(loss_list).reshape(epochs, -1)
        loss_array = np.mean(loss_array, axis=1)
    
        # Loss curve
        plt.rcParams['figure.figsize'] = (16, 8)
        plt.subplots(1, 2)
        plt.subplot(1, 2, 1)
        plt.plot(range(epochs), loss_array)
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.title('Loss Curve')
        # Accuracy curve
        plt.subplot(1, 2, 2)
        plt.plot(range(epochs), accuracy_array)
        plt.xlabel('epoch')
        plt.ylabel('accuracy')
        plt.title('Accuracy Curve')
        plt.savefig(f'../figure/weibo_{model_name}.png')
    
    

    TextCNN
    (training curve figure)

    BiLSTM
    (training curve figure)
    Both models were trained for only a few epochs and each reaches about 98.5% accuracy on the training set.

    7 Putting it all together

    def execute():
        # EDA
        datas, labels = eda()
        # Text cleaning
        for i, data in enumerate(datas):
            print(f'Cleaning batch {i + 1}')
            cleaner = WeiBoTextCleaner(data)
            corpus = cleaner.execute()
            # Tokenization
            tokenizer(corpus, 25, i + 1)
        # Merge the pickled batches
        paths = [f'../data/weibo/corpus{i}.pkl' for i in range(1, 7)]
        sentences = combine_pkl(paths)
        print(len(sentences))
        # Text representation
        print('Building text representation')
        word2index, sent2index = text_embedding(sentences, MAX_LEN)
    
        # Wrap into a Dataset/DataLoader
        train_dataset = MyDataSet(sent2index, labels)
        dataloader_train = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
        # Build the model
        vocab_size = len(word2index)
        # TextCNN
        # model = TextCNN(vocab_size=vocab_size, embedding_dim=EMBEDDING_SIZE, windows_size=WINDOWS_SIZE,
        #                 max_len=MAX_LEN, feature_size=FEATURE_SIZE, n_class=N_CLASSES).to(DEVICE)
        # optimizer = optim.Adam(model.parameters(), lr=0.001)
        print('GPU_Allocated:%d' % torch.cuda.memory_allocated())
    
        # BiLSTM
        model = BiLSTM(num_embeddings=vocab_size, embedding_dim=MAX_LEN, hidden_size=MAX_LEN,
                       num_layers=2, num_classes=N_CLASSES, device=DEVICE).to(DEVICE)
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        print('GPU_Allocated:%d' % torch.cuda.memory_allocated())
    
        # Training
        for i in range(EPOCHS):
            print(f'{i + 1}/{EPOCHS}')
            train(model, dataloader_train, optimizer, i + 1)
        # Save the model
        torch.save(model.state_dict(), '../model/bilstm_weibo.pkl')
        # Plot the curves
        plot_curve(EPOCHS, accuracy_list, loss_list, 'BiLSTM')
    
    
    if __name__ == '__main__':
        execute()
        # test_model(64826)
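
    MyDataSet (and the DataLoader/optim imports) comes from the previous post and is not shown here; a minimal sketch of what such a wrapper presumably looks like (an assumption, not the original definition):

    from torch.utils.data import Dataset

    class MyDataSet(Dataset):
        """Assumed Dataset wrapper pairing each index sequence with its label."""
        def __init__(self, sent2index, labels):
            self.datas = torch.LongTensor(sent2index)
            self.labels = torch.LongTensor(labels.values)  # labels is a pandas Series here

        def __len__(self):
            return len(self.labels)

        def __getitem__(self, idx):
            return self.datas[idx], self.labels[idx]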
    

    8 Interactive testing

    import warnings


    # Interactive testing
    def test_model(vocab_size):
        # Load the trained TextCNN model
        model = TextCNN(vocab_size=vocab_size, embedding_dim=EMBEDDING_SIZE, windows_size=WINDOWS_SIZE,
                        max_len=MAX_LEN, feature_size=FEATURE_SIZE, n_class=N_CLASSES).to(DEVICE)
        model.load_state_dict(torch.load('../model/textcnn_weibo.pkl'))
        model.eval()
        warnings.filterwarnings(action='ignore')
    
        while True:
            sentence = input('Enter a Weibo post to classify: ')
            data = [sentence]
    
            # Clean and tokenize the input
            cleaner = WeiBoTextCleaner(data)
            corpus = cleaner.execute()
            jieba.load_userdict(USER_DICT)
            corpus = list(map(jieba.lcut, corpus))
            # Note: this rebuilds word2index from the single input sentence,
            # so the indices only line up with training if the training vocabulary is reused
            word2index, sent2index = text_embedding(corpus, MAX_LEN)
            datas = sent2index
            datas = torch.LongTensor(datas).to(DEVICE)
    
            # Predict
            out = model(datas)
            out = torch.softmax(out, dim=1, dtype=torch.float32)
            predictions = torch.max(input=out, dim=1)[1]
            y_predict = predictions.to('cpu').data.numpy()
    
            if y_predict[0] == 1:
                print('Positive')
            else:
                print('Negative')
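
    For the predictions to be meaningful, the word2index built during training should be persisted and reused at test time rather than rebuilt per sentence. A minimal sketch, assuming the mapping is pickled next to the model (the word2index.pkl path is hypothetical):

    # At training time (e.g. at the end of execute()):
    with open('../model/word2index.pkl', 'wb') as f:
        pickle.dump(word2index, f)

    # At test time, reuse it instead of calling text_embedding on the single sentence:
    with open('../model/word2index.pkl', 'rb') as f:
        word2index = pickle.load(f)
    sent2index = [compute_sent2index(sent, MAX_LEN, word2index) for sent in corpus]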
    

    (screenshots of the interactive test output)

    Original article: https://blog.csdn.net/m0_46275020/article/details/126474891