• 深度学习-nlp系列(1)文本分类(TextCNN)pytorch


    TextCNN 模型是用来对文本进行分类。

    TextCNN 模型结构图

    TextCNN模型结构比较简单,其论文中整个模型的结构如下图所示:

     图1 Text CNN 模型结构图1

    对于论文中的模型图可能会看不懂,我们会对下面这张原理图进行讲解:

     图1 Text CNN 模型结构图2

    输入一句话:I like this movie very mush!,将其向量化,得到维度为5的矩阵,其 shape 为[1,7,5]。

    将其送入模型,先经过3个卷积,卷积核大小分别为(2,5),(3,5),(4,5)。得到的 feature_map 的 shape 为 [1, 6 ],[1, 5],[1,4]。

    将得到的 feature_map 经过最大池化,得到 feature_map 的 shape 为 [1, 2 ],[1, 2],[1,2]。

    将池化后的 feature_map 进行拼接,得到的 shape 为 [1,6],最后将其分为 2 分类。

    注:每个 shape 中的第一维度均为 batch_size。这里是以论文为主,所有为 1 ,实际不为 1。

    数据集处理

    数据集使用THUCNews中的train.txt、test.txt、dev.txt,为十分类问题。其中训练集一共有 180000 条,验证集一共有 10000 条,测试集一共有 10000 条。其类别为 finance、realty、stocks、education、science、society、politics、sports、game、entertainment 这 十个类别。

    对于输入数据,我们需要将文本转换成 embedding,即向量化。并且所有数据的长度是需要一致的。

    读取所有数据

    将数据处理成模型输入的格式 

    1. text = self.all_text[index][:self.max_len]
    2. label = int(self.all_label[index])
    3. text_idx = [self.word_2_index.get(i, 1) for i in text]
    4. text_idx = text_idx + [0] * (self.max_len - len(text_idx))
    5. text_idx = torch.tensor(text_idx).unsqueeze(dim=0)
    6. return text_idx, label

    模型准备

    1. class Block(nn.Module):
    2. def __init__(self, kernel_s, embeddin_num, max_len, hidden_num):
    3. super().__init__()
    4. # shape [batch * in_channel * max_len * emb_num]
    5. self.cnn = nn.Conv2d(in_channels=1, out_channels=hidden_num, kernel_size=(kernel_s, embeddin_num))
    6. self.act = nn.ReLU()
    7. self.mxp = nn.MaxPool1d(kernel_size=(max_len - kernel_s + 1))
    8. def forward(self, batch_emb): # shape [batch * in_channel * max_len * emb_num]
    9. c = self.cnn(batch_emb)
    10. a = self.act(c)
    11. a = a.squeeze(dim=-1)
    12. m = self.mxp(a)
    13. m = m.squeeze(dim=-1)
    14. return m
    15. class TextCNNModel(nn.Module):
    16. def __init__(self, emb_matrix, max_len, class_num, hidden_num):
    17. super().__init__()
    18. self.emb_num = emb_matrix.weight.shape[1]
    19. self.block1 = Block(2, self.emb_num, max_len, hidden_num)
    20. self.block2 = Block(3, self.emb_num, max_len, hidden_num)
    21. self.block3 = Block(4, self.emb_num, max_len, hidden_num)
    22. self.emb_matrix = emb_matrix
    23. self.classifier = nn.Linear(hidden_num * 3, class_num) # 2 * 3
    24. self.loss_fun = nn.CrossEntropyLoss()
    25. def forward(self, batch_idx): # shape torch.Size([batch_size, 1, max_len])
    26. batch_emb = self.emb_matrix(batch_idx) # shape torch.Size([batch_size, 1, max_len, embedding])
    27. b1_result = self.block1(batch_emb) # shape torch.Size([batch_size, 2])
    28. b2_result = self.block2(batch_emb) # shape torch.Size([batch_size, 2])
    29. b3_result = self.block3(batch_emb) # shape torch.Size([batch_size, 2])
    30. # 拼接
    31. feature = torch.cat([b1_result, b2_result, b3_result], dim=1) # shape torch.Size([batch_size, 6])
    32. pre = self.classifier(feature) # shape torch.Size([batch_size, class_num])
    33. return pre

    模型训练

    1. for epoch in range(args.epochs):
    2. model.train()
    3. loss_sum, count = 0, 0
    4. for batch_index, (batch_text, batch_label) in enumerate(train_loader):
    5. batch_text, batch_label = batch_text.to(device), batch_label.to(device)
    6. pred = model(batch_text)
    7. loss = loss_fn(pred, batch_label)
    8. opt.zero_grad()
    9. loss.backward()
    10. opt.step()

    训练结果

    模型预测

    源码获取

    TextCNN 文本分类

  • 相关阅读:
    DASCTF X CBCTF 2023|无畏者先行
    拼多多API接口大全
    web前端三大主流框架指的是什么
    Python基础 – 使用别人代码的模块机制
    【记录】Discuz!论坛防灌水防注册机,清理垃圾会员
    论文阅读 | RAFT: Recurrent All-Pairs Field Transforms for Optical Flow
    如何获取GC(垃圾回收器)的STW(暂停)时间?
    MySQL 主从复制
    【八股】计算机网络-HTTP和HTTPS的区别、HTTPS加密传输原理
    [idekCTF 2022]Paywall - LFI+伪协议+filter_chain
  • 原文地址:https://blog.csdn.net/qq_48764574/article/details/125757595