• resume不严格加载model、避免某些层维度不一致导致错误


    默认的,我们最常用的resume方式:

    1. if args.resume:
    2. checkpoint = torch.load(resume_path, map_location='cpu')
    3. model_without_ddp.load_state_dict(checkpoint['model'])
    4. print("Resume checkpoint %s" % resume_path)
    5. if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'evaluate') and args.evaluate):
    6. optimizer.load_state_dict(checkpoint['optimizer'])
    7. args.start_epoch = checkpoint['epoch'] + 1
    8. if 'scaler' in checkpoint:
    9. loss_scaler.load_state_dict(checkpoint['scaler'])
    10. print("With optim & sched!")
    11. del checkpoint

    在resume模型的时候,可能会遇到某些层是没有的,或者你改变了某些层的维度,从而导致model_state_dict()错误,所以此时的解决办法为:忽略这些层,不加载它们

    情况一:

    如果是你的model出现了某些新的维度,但是resume model中并没有

    直接使用 strict 参数置为 False 即可:

    model_without_ddp.load_state_dict(checkpoint['model'], strict=False)

    情况二:

    此时你的model的某些层的参数发生变化了,但是resume的model还是原来的model,那么就忽略这些维度不匹配的层,只加载维度相同的层:

    1)分布式保存的ckpt进行resume:

    1. if args.resume:
    2. checkpoint = torch.load(args.resume, map_location='cpu')
    3. model_state_dict = model.state_dict()
    4. # 过滤掉尺寸不匹配的参数(为了训练不同rep的图像生成模型)
    5. filtered_state_dict = {}
    6. for k, v in checkpoint['model'].items():
    7. model_key = 'module.' + k
    8. # 【在分布式训练或使用 DataParallel 时,模型的状态字典中的参数名称通常会带有 module. 前缀】
    9. if model_key in model_state_dict and model_state_dict[model_key].shape == v.shape:
    10. filtered_state_dict[model_key[7:]] = v
    11. else:
    12. print(f"Skipping parameter {k} due to size mismatch: checkpoint shape {v.shape} vs model shape {model_state_dict[model_key].shape}")
    13. # 加载过滤后的状态字典
    14. model_without_ddp.load_state_dict(filtered_state_dict, strict=False)
    15. del checkpoint

    2)正常保存的ckpt进行resume(一般使用这个就可以了):

    1. if args.resume:
    2. checkpoint = torch.load(args.resume, map_location='cpu')
    3. model_state_dict = model_without_ddp.state_dict()
    4. # 过滤掉尺寸不匹配的参数(为了训练不同rep的图像生成模型)
    5. filtered_state_dict = {}
    6. for k, v in checkpoint['model'].items():
    7. if k in model_state_dict and model_state_dict[k].shape == v.shape:
    8. filtered_state_dict[k] = v
    9. else:
    10. # print(f"Skipping parameter {k} due to size mismatch: checkpoint shape {v.shape} vs model shape {model_state_dict[k].shape}")
    11. print("Error")
    12. model_params = list(model_without_ddp.parameters())
    13. ema_params = copy.deepcopy(model_params)
    14. # 加载过滤后的状态字典
    15. model_without_ddp.load_state_dict(filtered_state_dict, strict=False)
    16. del checkpoint

    上面没有进行优化器的resume,是因为对维度不匹配的情况,再resume 优化器很麻烦,感觉意义也不大

  • 相关阅读:
    gazebo中给机器人添加16线激光雷达跑LIO-SAM
    下一代实时数据库:Apache Doris 【一】简介
    MT3030 天梯赛
    Nacos注册中心8-Server端(处理注册请求)
    C#/.NET/.NET Core优秀项目和框架精选(2023年10月更新,项目分类已整理完成欢迎大家踊跃提交PR一起完善让优秀的项目和框架不被埋没)
    Java入门基础
    安装TimeGen波形绘图软件
    cuda编程之共享内存的bank冲突
    通达OA通用版V12的表单js定制开发,良好实践总结-持续更新
    Docker 安装 Nginx 容器 (完整详细版)
  • 原文地址:https://blog.csdn.net/weixin_43135178/article/details/140015945