• RuntimeError: Error(s) in loading state_dict for BASE_Transformer


    最近跑一个深度学习变化检测的项目BIT_CD,严格按照作者的说明页进行训练和测试,但是跑出来的模型就是无法正常工作,而用作者的预训练模型就正常工作,百思不得其解,根据错误,逐步调试,输出,总算是找到了问题的所在!

    其实这个问题如果对于老手,估计一下子就解决了,但是对于刚刚接触深度学习的新人,要发现并且解决这个问题确实需要费不少功夫。

    那么,这里就记录一下这个问题的解决过程!

    首先,错误的具体内容如下:

    1. Traceback (most recent call last):
    2. File "D:/Tools/MyScript/BIT_CD/eval_cd.py", line 58, in
    3. main()
    4. File "D:/Tools/MyScript/BIT_CD/eval_cd.py", line 54, in main
    5. model.eval_models(checkpoint_name=args.checkpoint_name)
    6. File "D:\Tools\MyScript\BIT_CD\models\evaluator.py", line 158, in eval_models
    7. self._load_checkpoint(checkpoint_name)
    8. File "D:\Tools\MyScript\BIT_CD\models\evaluator.py", line 70, in _load_checkpoint
    9. self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
    10. File "D:\MyItems\Anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1483, in load_state_dict
    11. self.__class__.__name__, "\n\t".join(error_msgs)))
    12. RuntimeError: Error(s) in loading state_dict for BASE_Transformer:
    13. size mismatch for transformer_decoder.layers.0.0.fn.fn.to_q.weight: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([64, 32]).
    14. size mismatch for transformer_decoder.layers.0.0.fn.fn.to_k.weight: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([64, 32]).
    15. size mismatch for transformer_decoder.layers.0.0.fn.fn.to_v.weight: copying a param with shape torch.Size([512, 32]) from checkpoint, the shape in current model is torch.Size([64, 32]).
    16. size mismatch for transformer_decoder.layers.0.0.fn.fn.to_out.0.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([32, 64]).

    根据页面提示的错误,可以知道是加载模型中的state_dict维度不匹配。

    于是,我尝试输出预训练模型,和我自己训练的模型的内容进行比较。

     结果,我一脸懵逼,看不出个所以然来,维度太高,括号太多,根本没法看。

    但是可以注意到,两个模型中都存在model_G_state_dict关键字。

    因为这样实在不好看维度,于是我换了一个输出方式,查看了mydict的类型,直接输出维度。

    1. mydict = checkpoint['model_G_state_dict']
    2. for k, v in mydict.items():
    3. print (k, '========' ,v.shape)

    此时问题一目了然,中间层的维度不对,跟报错内容一致。

    作者的预训练模型中间某个维度是64,32 而我训练出来的模型是512,32。

    根据出错的提示

    BASE_Transformer

    以及transformer_decoder.layers,继续寻找出错位置。

    经过整个工程搜索,来到代码位置:

    1. class BASE_Transformer(ResNet):
    2. """
    3. Resnet of 8 downsampling + BIT + bitemporal feature Differencing + a small CNN
    4. """
    5. def __init__(self, input_nc, output_nc, with_pos, resnet_stages_num=5,
    6. token_len=4, token_trans=True,
    7. enc_depth=1, dec_depth=1,
    8. dim_head=64, decoder_dim_head=64,
    9. tokenizer=True, if_upsample_2x=True,
    10. pool_mode='max', pool_size=2,
    11. backbone='resnet18',
    12. decoder_softmax=True, with_decoder_pos=None,
    13. with_decoder=True):
    14. super(BASE_Transformer, self).__init__(input_nc, output_nc,backbone=backbone,
    15. resnet_stages_num=resnet_stages_num,
    16. if_upsample_2x=if_upsample_2x,
    17. )

    断点调试,跳出,得到位置:

    self.net_G = define_G(args=args, gpu_ids=args.gpu_ids)

    输出这个self.net_G查看,可以发现此处的模型与预训练的模型已经不一致,那为何不一致呢?

    关键跳转函数:

    1. def define_G(args, init_type='normal', init_gain=0.02, gpu_ids=[]):
    2. if args.net_G == 'base_resnet18':
    3. net = ResNet(input_nc=3, output_nc=2, output_sigmoid=False)
    4. elif args.net_G == 'base_transformer_pos_s4':
    5. net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4,
    6. with_pos='learned')
    7. elif args.net_G == 'base_transformer_pos_s4_dd8':
    8. net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4,
    9. with_pos='learned', enc_depth=1, dec_depth=8)
    10. elif args.net_G == 'base_transformer_pos_s4_dd8_dedim8':
    11. net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4,
    12. with_pos='learned', enc_depth=1, dec_depth=8, decoder_dim_head=8)
    13. else:
    14. raise NotImplementedError('Generator model name [%s] is not recognized' % args.net_G)
    15. return init_net(net, init_type, init_gain, gpu_ids)

    最终定位到原因,原来训练模型和测试模型使用的模型不同!

    于是这就好办了,修改测试代码模型参数即可。

    1. parser.add_argument('--net_G', default='base_transformer_pos_s4_dd8_dedim8', type=str,
    2. help='base_resnet18 | base_transformer_pos_s4_dd8 | base_transformer_pos_s4_dd8_dedim8|')

    至此,成功训练并测试模型,问题解决!

    虽然是个并不复杂的问题,但是我花了一天半的时间才完全解决,解决的过程中经过了无数次的尝试,测试和分析,最终成功解决了问题!

     

    回过头看,这其实是一个很简单的问题,但是我仍然很高兴自己能够逐步分析,测试,最终找到正确的解决办法,而做到更深入地理解了代码。

    本次遇到的问题比较特别,在网上几乎找不到答案,但是找到了跟我遇到相同问题的人,不管怎么说,这都是一件值得纪念的事情。

  • 相关阅读:
    .NET 7 RC 2 发布,倒计时一个月发布正式版
    养殖废水总氮超标的解决方法
    LC1713. 得到子序列的最少操作次数(java - 动态规划)
    PHP 可用的函数
    稳定的排序算法:直接插入排序和冒泡排序 (c++实现)
    java后端:nacos 实现动态配置
    《杂七杂八》:Linux(Kali)终端快捷键
    java专题练习(抢红包)
    【GO for java programmers】面向Java开发者的GO编程3_go for java(1)
    【Unity细节】如何让组件失活而不是物体失活
  • 原文地址:https://blog.csdn.net/autumoonchina/article/details/126663542