• 深度学习-nlp系列(2)文本分类(Bert)pytorch


    对于 Bert 来说,用于文本分类是最常见的,并且准确率也很高。本文将会对 bert 用于文本分类来做详细的介绍。

     

    预训练模型

    对于不同的数据,需要导入不同的预训练模型。

    预训练模型下载地址:Models - Hugging Face

    本文使用的是中文数据集,因此需要选择中文的预训练模型:bert-base-chinese at main

    Bert 模型主要结构

    BertModel 主要为 transformer encoder 结构,包含三个部分:

    1. embeddings,即BertEmbeddings类的实体,对应词嵌入;
    2. encoder,即BertEncoder类的实体;
    3. pooler,即BertPooler类的实体,这一部分是可选的。

    注意:BertModel 也可以配置为 Decoder

     图1 bert 模型初始化/结构 

    Bert文本分类模型常见做法为将bert最后一层输出的第一个token位置(CLS位置)当作句子的表示,后接全连接层进行分类。

    Bert 模型输入

    Bert 模型可以用于不同的场景,在文本分类,实体识别等场景的输入是不同的。 

    对于文本分类,其最主要的有两个参数:input_ids,attention_mask

     图2 bert 模型输入

    input_ids:经过 tokenizer 分词后的 subword 对应的下标列表;

    attention_mask:在 self-attention 过程中,这一块 mask 用于标记 subword 所处句子和 padding 的区别,将 padding 部分填充为 0;

    Bert 模型输出 

    该模型的输出也是有多个,但是只有一个是用于文本分类的

    bert的输出结果的其中四个维度:

    1、last_hidden_state:shape是(batch_size, sequence_length, hidden_size),hidden_size=768,它是模型最后一层输出的隐藏状态。

    2、pooler_output:shape是(batch_size, hidden_size),在通过用于辅助预训练任务的层进行进一步处理后,序列的第一个token(classification token)的最后一层的隐藏状态。例如。对于 BERT 系列模型,这会在通过线性层和 tanh 激活函数处理后返回分类标记。线性层权重在预训练期间从下一个句子预测(分类)目标进行训练。

    3、hidden_states:shape是(batch_size, sequence_length, hidden_size),这是输出的一个可选项,如果输出,

    需要指定`output_hidden_states=True` is passed or when `config.output_hidden_states=True`

    它的第一个元素是embedding,其余元素是各层的输出。每层输出的模型隐藏状态加上可选的初始嵌入输出。

    4、attentions:shape是(batch_size,num_heads,sequence_length,sequence_length),这也是输出的一个可选项,如果输出,

    需要指定`output_attentions=True` is passed or when `config.output_attentions=True`

    它的元素是每一层的注意力权重(softmax),用于计算self-attention heads的加权平均值。

    我们是微调模式,需要获取bert最后一个隐藏层的输出输入到下一个全连接层,所以取第一个维度,也就是hiden_outputs.pooler_output

     图3 bert 模型的输出

     数据处理

    数据格式

    读取所有数据

    1. # 读取文件
    2. all_data = open(file, "r", encoding="utf-8").read().split("\n")
    3. # 得到所有文本、所有标签、句子的最大长度
    4. texts, labels, max_length = [], [], []
    5. for data in all_data:
    6. if data:
    7. text, label = data.split("\t")
    8. max_length.append(len(text))
    9. texts.append(text)
    10. labels.append(label)

     将数据处理成模型输入的格式

    1. # 取出一条数据并截断长度
    2. text = self.all_text[index][:self.max_len]
    3. label = self.all_label[index]
    4. # 分词
    5. text_id = self.tokenizer.tokenize(text)
    6. # 加上起始标志
    7. text_id = ["[CLS]"] + text_id
    8. # 编码
    9. token_id = self.tokenizer.convert_tokens_to_ids(text_id)
    10. # 掩码 -》
    11. mask = [1] * len(token_id) + [0] * (self.max_len + 2 - len(token_id))
    12. # 编码后 -》长度一致
    13. token_ids = token_id + [0] * (self.max_len + 2 - len(token_id))
    14. # str -》 int
    15. label = int(label)
    16. # 转化成tensor
    17. token_ids = torch.tensor(token_ids)
    18. mask = torch.tensor(mask)
    19. label = torch.tensor(label)
    20. return (token_ids, mask), label

    模型准备

    1. class MyModel(nn.Module):
    2. def __init__(self):
    3. super(MyModel, self).__init__()
    4. self.args = parsers()
    5. self.device = "cuda:0" if self.args.device else "cpu"
    6. # 加载 bert 中文预训练模型
    7. self.bert = BertModel.from_pretrained(self.args.bert_pred)
    8. # 让 bert 模型进行微调(参数在训练过程中变化)
    9. for param in self.bert.parameters():
    10. param.requires_grad = True
    11. # 全连接层
    12. self.linear = nn.Linear(self.args.num_filters, self.args.class_num)
    13. def forward(self, x):
    14. input_ids, attention_mask = x[0].to(self.device), x[1].to(self.device)
    15. hidden_out = self.bert(input_ids, attention_mask=attention_mask,
    16. output_all_encoded_layers=False) # 控制是否输出所有encoder层的结果
    17. # shape (batch_size, hidden_size)
    18. pred = self.linear(hidden_out.pooler_output)
    19. # 返回预测结果
    20. return pred

    模型训练

    1. for epoch in range(args.epochs):
    2. loss_sum, count = 0, 0
    3. model.train()
    4. for batch_index, (batch_text, batch_label) in enumerate(train_dataloader):
    5. batch_label = batch_label.to(device)
    6. pred = model(batch_text)
    7. loss = loss_fn(pred, batch_label)
    8. opt.zero_grad()
    9. loss.backward()
    10. opt.step()
    11. loss_sum += loss
    12. count += 1
    13. # 打印内容
    14. if len(train_dataloader) - batch_index <= len(train_dataloader) % 1000 and count == len(train_dataloader) % 1000:
    15. msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"
    16. logging.info(msg.format(epoch + 1, batch_index + 1, loss_sum / count))
    17. loss_sum, count = 0.0, 0
    18. if batch_index % 1000 == 999:
    19. msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"
    20. logging.info(msg.format(epoch + 1, batch_index + 1, loss_sum / count))
    21. loss_sum, count = 0.0, 0

    训练结果

    模型预测  

    源码获取

    bert 文本分类

  • 相关阅读:
    K8S使用开源CEPH作为后端StorageClass
    【Java】使用`LinkedList`类来实现一个队列,并通过继承`AbstractQueue`或者实现`Queue`接口来实现自定义队列
    【Loadrunner】学习loadrunner——性能测试基础篇VUG的使用(二)
    ios-关联对象
    Springboot开发系统记录操作日志
    springcloud+nacos+gateway+oauth2小聚会
    高效巧用这19条MySQL优化
    在Vs-Code中配置“@”路径提示的插件
    面试网络-0x01 http中的GET和POST区别?
    《重构:改善既有代码的设计》读书笔记(上)
  • 原文地址:https://blog.csdn.net/qq_48764574/article/details/126068667