It really does work, and someone has even written a paper showing why it works.
First, look at normal model loading (not loading a pretrained model): the same model is saved, then loaded back into the same model definition; a minimal sketch follows below.
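A minimal round-trip sketch of that save/load pattern (the file name checkpoint.pth is just a placeholder):

import torch
import torchvision

# save only the parameter dictionary (the state_dict), not the whole module
model = torchvision.models.resnet50()
torch.save(model.state_dict(), 'checkpoint.pth')

# load it back into an identically defined model; strict loading requires every key to match
model2 = torchvision.models.resnet50()
state = torch.load('checkpoint.pth', map_location='cpu')
model2.load_state_dict(state)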
Don't load what shouldn't be loaded; read in only the useful parts. Because the parameters are stored in a dictionary, you only need to work on the keys and filter out the entries you don't need:
From the ASL_reproduce-master repository:
if args.model_path:  # make sure to load pretrained ImageNet model
    state = torch.load(args.model_path, map_location='cpu')
    # keep only the keys that also exist in the current model, and skip the classification head
    filtered_dict = {k: v for k, v in state['model'].items() if
                     (k in model.state_dict() and 'head.fc' not in k)}
    model.load_state_dict(filtered_dict, strict=False)
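Since strict=False silently skips mismatched entries, it helps to check what was actually left out; load_state_dict returns the missing and unexpected keys (a quick check, reusing model and filtered_dict from the snippet above):

result = model.load_state_dict(filtered_dict, strict=False)
print('missing keys   :', result.missing_keys)     # parameters kept at their fresh initialization (e.g. the new head)
print('unexpected keys:', result.unexpected_keys)  # checkpoint entries the current model has no slot for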
For example, with torchvision's built-in resnet50, replace the final fully connected layer:
import torch
import torchvision

model = torchvision.models.resnet50(pretrained=True)
my_out_features = 80  # number of output labels for the new task
model.fc = torch.nn.Sequential(
    torch.nn.Linear(
        in_features=2048,
        out_features=my_out_features
    ),
    torch.nn.Sigmoid()
)
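Putting the two pieces together: once model.fc has been replaced, its saved weights can later be loaded into a plain resnet50 while skipping the mismatched classifier keys, by filtering the dictionary keys exactly as above (a sketch continuing from the model defined above; the file name my_checkpoint.pth and the 'fc' prefix filter are illustrative):

# save the fine-tuned parameters; the new head appears under keys like 'fc.0.weight'
torch.save(model.state_dict(), 'my_checkpoint.pth')

# rebuild a plain resnet50 and load everything except the classifier head
new_model = torchvision.models.resnet50()
state = torch.load('my_checkpoint.pth', map_location='cpu')
filtered = {k: v for k, v in state.items()
            if k in new_model.state_dict() and not k.startswith('fc')}
new_model.load_state_dict(filtered, strict=False)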