• 高版本transformers-4.24中的坑


    最近遇到一个很奇怪的BUG,好早之前写的一个Bert文本分类模型,拿给别人用的时候,发现不灵了,原本90多的acc,什么都没修改,再测一次发现只剩30多了,检查了一番之后,很快我发现他的transformers版本是4.24,而我一直用的是4.9,没有更新。

    于是我试着分析问题出在哪里,然后就遇到了这个坑。首先这是我模型的基础结构,很简单,就是一个Encoder模型加一层分类器:

    class BertClassifier(torch.nn.Module):
        def __init__(self, bert_model, num_classes):
            super(BertClassifier, self).__init__()
            self.bert = bert_model
            self.dropout = torch.nn.Dropout(0.2)
            self.dense = torch.nn.Linear(768, num_classes)
            
        def forward(
            self,
            input_ids=None,
            token_type_ids=None,
            attention_mask=None,
            labels=None,
        ):
            bert_out = self.bert(input_ids, token_type_ids, attention_mask, output_attentions=False)
            # print(list(self.bert.encoder.layer[0].attention.self.query.parameters()))
            # print(bert_out)
            sequence_output = bert_out.last_hidden_state
            print(sequence_output)
            sequence_output = self.dropout(sequence_output)
            pool_output = torch.mean(sequence_output, axis=1)
            
            logits = self.dense(pool_output)
            # print(logits)
            
            loss = None
            loss_fct = torch.nn.CrossEntropyLoss()
            
            if labels is not None:
                # labels = label.long()
                loss = loss_fct(logits, labels.view(-1))
            
            return loss if loss is not None else logits
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33

    为了分析问题出在哪里,我把类里的代码全都拿出来,逐行运行,发现最终的logits和正确的logits(在4.9版本的环境里执行的结果)是一致的,这就很奇怪了,但是我实例化模型,再用模型forward出来的结果却是错误的:

    # 这个结果计算出来是对的
    sequence_output = bert_cls_model.bert(**inputs).last_hidden_state
    sequence_output = bert_cls_model.dropout(sequence_output)
    pool_output = torch.mean(sequence_output, axis=1)
    logits = bert_cls_model.dense(pool_output)
    print(logits)
    
    # 这样计算出来是错的
    logits = bert_cls_model(**inputs)
    print(logits)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    于是我又在模型类的定义里打印了各个阶段的结果,如上第一段代码中的print,发现从bert_out的打印结果来看全都是错的。

    更进一步地,为了确认是不是模型加载权重的时候出现了问题(比如加载权重后的模型被重新初始化了),我又在模型定义代码里打印了模型的参数值,确认参数值也是没有问题的。这就让我感到有些匪夷所思了。

    我又按照同样的对比方法,在模型里边打印一次,单独拿出来打印一次,试着找出问题所在,这次是从一开始embedding开始,结果发现在模型内部和外部打印embedding的结果是一致的:

    # 这样打印的结果是正确的
    bert_cls_model.bert.embeddings(input_ids=inputs['input_ids'], token_type_ids=inputs['token_type_ids'])
    
    # 在模型的forward方法里打印embedding的结果同样是正确的
    
    • 1
    • 2
    • 3
    • 4

    更奇怪的是,我将embedding的结果输入给encoder手动计算,出来的sequence_out就变成正确的了:

    class BertClassifier(torch.nn.Module):
        def __init__(self, bert_model, num_classes):
            super(BertClassifier, self).__init__()
            self.bert = bert_model
            self.dropout = torch.nn.Dropout(0.2)
            self.dense = torch.nn.Linear(768, num_classes)
            
        def forward(
            self,
            input_ids=None,
            token_type_ids=None,
            attention_mask=None,
            labels=None,
        ):
        	# 直接调用self.bert计算出来结果是错误的
            # bert_out = self.bert(input_ids, token_type_ids, attention_mask, output_attentions=False)
    
    		# 手动以此调用embedding和encoder,就算出来的结果就是正确的了
            embedding_res = self.bert.embeddings(input_ids=input_ids, token_type_ids=token_type_ids)
            encoder_out = self.bert.encoder(embedding_res)
            sequence_output = encoder_out[0]
    
            sequence_output = self.dropout(sequence_output)
            pool_output = torch.mean(sequence_output, axis=1)
            
            logits = self.dense(pool_output)
            # print(logits)
            
            loss = None
            loss_fct = torch.nn.CrossEntropyLoss()
            
            if labels is not None:
                # labels = label.long()
                loss = loss_fct(logits, labels.view(-1))
            
            return loss if loss is not None else logits
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36

    最后我又额外检查了一遍两个版本源码的差别,也没有发现什么端倪,感觉修改的地方都是些写法的差异,不应该有能够造成这个问题的地方。

    解决的话,目前就是把transformers的版本降下来,或者像最后这样手动执行计算,还没有发现真正出问题的地方在哪里,如果有哪位也遇到这个问题并且有效解决了的话,还请在评论区指出,谢谢。

  • 相关阅读:
    嵌入式分享合集109
    人工智能AI绘画,Stable Diffusion保姆级教程,小白也可以掌握SD使用
    常驻巨噬细胞诱导的纤维化在胰腺炎性损伤和PDAC中具有不同的作用
    c++实现组播和广播的发送和接收端
    CPU中的MESI协议(Intel)
    2023秋招笔试算法Python3题解
    【仿牛客网笔记】Spring Boot实践,开发社区登录模块-显示登录信息
    aspnetcore使用websocket实时更新商品信息
    java计算机毕业设计物流公司停车位管理源程序+mysql+系统+lw文档+远程调试
    云原生学习笔记-1-docker
  • 原文地址:https://blog.csdn.net/weixin_44826203/article/details/128180977