By default, the most common way to resume:

```python
if args.resume:
    checkpoint = torch.load(resume_path, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])
    print("Resume checkpoint %s" % resume_path)

    # Skip optimizer/epoch state in evaluation-only runs
    if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'evaluate') and args.evaluate):
        optimizer.load_state_dict(checkpoint['optimizer'])
        args.start_epoch = checkpoint['epoch'] + 1
        if 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])
        print("With optim & sched!")

    del checkpoint
```
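For reference, a minimal sketch of the save side that produces the checkpoint layout this resume path expects. The key names (`model`, `optimizer`, `epoch`, `scaler`) come from the loading code above; the output path is illustrative:

```python
checkpoint = {
    'model': model_without_ddp.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': epoch,
    'scaler': loss_scaler.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pth')  # illustrative path
```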
When resuming, the checkpoint may be missing some layers, or you may have changed the dimensions of some layers, causing `load_state_dict()` to fail. The fix is to skip those layers and simply not load them.

If your model has new parameters that the resumed checkpoint does not contain, setting the `strict` argument to `False` is enough:

```python
model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
```
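It is worth logging what actually got skipped: with `strict=False`, `load_state_dict` returns a named tuple of `missing_keys` (parameters your model has but the checkpoint lacks) and `unexpected_keys` (the reverse), so silent skips don't go unnoticed:

```python
result = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
print("missing keys:", result.missing_keys)        # in the model, not in the checkpoint
print("unexpected keys:", result.unexpected_keys)  # in the checkpoint, not in the model
```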
If some layers in your model have changed shape while the checkpoint still holds the old shapes, `strict=False` alone is not enough (shape mismatches still raise a `RuntimeError`). Instead, filter the state dict first: skip the mismatched layers and load only those whose shapes agree:
```python
if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    model_state_dict = model.state_dict()

    # Filter out parameters whose shapes do not match
    # (e.g. when training image-generation models with a different rep/representation).
    # Under distributed training or DataParallel, the wrapped model's state-dict keys
    # carry a 'module.' prefix, so prepend it before looking keys up.
    filtered_state_dict = {}
    for k, v in checkpoint['model'].items():
        model_key = 'module.' + k
        if model_key in model_state_dict and model_state_dict[model_key].shape == v.shape:
            # Strip the prefix again: the filtered dict is loaded into the unwrapped model.
            filtered_state_dict[model_key[len('module.'):]] = v
        elif model_key in model_state_dict:
            print(f"Skipping parameter {k} due to size mismatch: "
                  f"checkpoint shape {v.shape} vs model shape {model_state_dict[model_key].shape}")
        else:
            print(f"Skipping parameter {k}: not present in the model")

    # Load the filtered state dict; strict=False tolerates the skipped keys.
    model_without_ddp.load_state_dict(filtered_state_dict, strict=False)

    del checkpoint
```
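The `'module.'` handling above can also be factored out. A small helper (hypothetical, not part of the original code) that normalizes checkpoint keys so the same filtering works whether or not the checkpoint was saved from a DDP/DataParallel-wrapped model:

```python
def strip_module_prefix(state_dict):
    """Remove a leading 'module.' from every key, if present.

    Checkpoints saved from a DDP/DataParallel-wrapped model carry this
    prefix; checkpoints saved from the bare model do not.
    """
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}
```

With such a helper, the shape check can always run against `model_without_ddp.state_dict()`, which is exactly what the variant below does.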
A variant that checks shapes directly against the unwrapped model's state dict (no `'module.'` prefix handling), and also re-initializes the EMA parameters from the resumed weights:

```python
import copy

if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    model_state_dict = model_without_ddp.state_dict()

    # Filter out parameters whose shapes do not match
    # (e.g. when training image-generation models with a different rep/representation).
    filtered_state_dict = {}
    for k, v in checkpoint['model'].items():
        if k in model_state_dict and model_state_dict[k].shape == v.shape:
            filtered_state_dict[k] = v
        elif k in model_state_dict:
            print(f"Skipping parameter {k} due to size mismatch: "
                  f"checkpoint shape {v.shape} vs model shape {model_state_dict[k].shape}")
        else:
            print(f"Skipping parameter {k}: not present in the model")

    # Load the filtered state dict; strict=False tolerates the skipped keys.
    model_without_ddp.load_state_dict(filtered_state_dict, strict=False)

    # Copy the EMA parameters *after* loading, so they start from the resumed weights.
    model_params = list(model_without_ddp.parameters())
    ema_params = copy.deepcopy(model_params)

    del checkpoint
```
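For context, `ema_params` above is typically updated after each optimizer step with an exponential moving average. A minimal sketch, assuming a decay hyperparameter of 0.999 (not part of the original code):

```python
import torch

@torch.no_grad()
def update_ema(ema_params, model_params, decay=0.999):
    """One EMA step: ema = decay * ema + (1 - decay) * param, in place."""
    for ema_p, p in zip(ema_params, model_params):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)
```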
The snippets above do not resume the optimizer: when shapes have changed, restoring the optimizer state is messy, and it doesn't seem worth much anyway.
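A quick look at the optimizer state dict shows why. This is a sketch, assuming a checkpoint that still contains an `'optimizer'` entry:

```python
# The per-parameter state in optimizer.state_dict() is keyed by integer
# position within param_groups, not by parameter name. Once the parameter
# set changes, those positions no longer line up with the checkpoint,
# which is what makes partial optimizer resume error-prone.
opt_state = checkpoint['optimizer']
print(list(opt_state['state'].keys())[:5])                    # e.g. [0, 1, 2, 3, 4]
print([len(g['params']) for g in opt_state['param_groups']])  # params per group
```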