• resume不严格加载model、避免某些层维度不一致导致错误


    默认的,我们最常用的resume方式:

    1. if args.resume:
    2. checkpoint = torch.load(resume_path, map_location='cpu')
    3. model_without_ddp.load_state_dict(checkpoint['model'])
    4. print("Resume checkpoint %s" % resume_path)
    5. if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'evaluate') and args.evaluate):
    6. optimizer.load_state_dict(checkpoint['optimizer'])
    7. args.start_epoch = checkpoint['epoch'] + 1
    8. if 'scaler' in checkpoint:
    9. loss_scaler.load_state_dict(checkpoint['scaler'])
    10. print("With optim & sched!")
    11. del checkpoint

    在resume模型的时候,可能会遇到某些层是没有的,或者你改变了某些层的维度,从而导致model_state_dict()错误,所以此时的解决办法为:忽略这些层,不加载它们

    情况一:

    如果是你的model出现了某些新的维度,但是resume model中并没有

    直接使用 strict 参数置为 False 即可:

    model_without_ddp.load_state_dict(checkpoint['model'], strict=False)

    情况二:

    此时你的model的某些层的参数发生变化了,但是resume的model还是原来的model,那么就忽略这些维度不匹配的层,只加载维度相同的层:

    1)分布式保存的ckpt进行resume:

    1. if args.resume:
    2. checkpoint = torch.load(args.resume, map_location='cpu')
    3. model_state_dict = model.state_dict()
    4. # 过滤掉尺寸不匹配的参数(为了训练不同rep的图像生成模型)
    5. filtered_state_dict = {}
    6. for k, v in checkpoint['model'].items():
    7. model_key = 'module.' + k
    8. # 【在分布式训练或使用 DataParallel 时,模型的状态字典中的参数名称通常会带有 module. 前缀】
    9. if model_key in model_state_dict and model_state_dict[model_key].shape == v.shape:
    10. filtered_state_dict[model_key[7:]] = v
    11. else:
    12. print(f"Skipping parameter {k} due to size mismatch: checkpoint shape {v.shape} vs model shape {model_state_dict[model_key].shape}")
    13. # 加载过滤后的状态字典
    14. model_without_ddp.load_state_dict(filtered_state_dict, strict=False)
    15. del checkpoint

    2)正常保存的ckpt进行resume(一般使用这个就可以了):

    1. if args.resume:
    2. checkpoint = torch.load(args.resume, map_location='cpu')
    3. model_state_dict = model_without_ddp.state_dict()
    4. # 过滤掉尺寸不匹配的参数(为了训练不同rep的图像生成模型)
    5. filtered_state_dict = {}
    6. for k, v in checkpoint['model'].items():
    7. if k in model_state_dict and model_state_dict[k].shape == v.shape:
    8. filtered_state_dict[k] = v
    9. else:
    10. # print(f"Skipping parameter {k} due to size mismatch: checkpoint shape {v.shape} vs model shape {model_state_dict[k].shape}")
    11. print("Error")
    12. model_params = list(model_without_ddp.parameters())
    13. ema_params = copy.deepcopy(model_params)
    14. # 加载过滤后的状态字典
    15. model_without_ddp.load_state_dict(filtered_state_dict, strict=False)
    16. del checkpoint

    上面没有进行优化器的resume,是因为对维度不匹配的情况,再resume 优化器很麻烦,感觉意义也不大

  • 相关阅读:
    计算机毕业设计基于Springboot+vue口腔牙科诊所管理系统
    blender安装cats-blender-plugin-0-19-0插件,导入pmx三维模型
    vpn概述总结
    如何在 C# 程序中注入恶意 DLL?
    谣言检测(PLAN)——《Interpretable Rumor Detection in Microblogs by Attending to User Interactions》
    (附源码)ssm考试题库管理系统 毕业设计 069043
    vue实战入门后台篇六:springboot+mybatis实现网站后台-前端登录功能对接
    git 重置到某次提交
    通关算法题之 ⌈二叉树⌋ 上
    第04章 经典卷积神经网络模型
  • 原文地址:https://blog.csdn.net/weixin_43135178/article/details/140015945