• pytorch迁移学习载入部分权重


            载入权重是迁移学习的重要部分,这个权重的来源可以是官方发布的预训练权重,也可以是你自己训练的权重并载入模型进行继续学习。使用官方预训练权重,这样的权重包含的信息量大且全面,可以适配一些小数据的任务,即小数据在使用迁移学习后仍然能够保持良好的性能,避免的小数据带来的数据不足,模型训练不充分的问题。载入自己的训练的权重在模型测试和继续训练时使用较多,模型测试载入权重就不说了,继续训练是指假设设置epoch为500,训练接受后,发现模型仍然没有收敛,那么你就可以载入epoch为500时的训练权重,再训练500的epoch,这样你对模型就总共训练了1000个epoch,而不需要在发现模型未收敛时,又重头去训练1000个epoch。

    壹.载入全部权重

    假设模型定义如下,以VGG为例:权重文件为.pth后缀文件:

    1. import torch
    2. import torch.nn as nn
    3. class VGG(nn.Module):
    4. def __init__(self, features,num_classes=1000):
    5. super(VGG, self).__init__()
    6. self.features = features
    7. self.classifier = nn.Sequential(
    8. nn.Linear(512 * 7 * 7, 4096),
    9. nn.ReLU(True),
    10. nn.Dropout(p=0.5),
    11. nn.Linear(4096, 4096),
    12. nn.ReLU(True),
    13. nn.Dropout(p=0.5),
    14. nn.Linear(4096, num_classes)
    15. )
    16. def forward(self, x):
    17. x5 = self.features(x)
    18. x5= torch.flatten(x5, start_dim=1)
    19. x5= self.classifier(x5)
    20. return x5
    21. def make_features(cfg: list):
    22. layers = []
    23. in_channels = 3
    24. for v in cfg:
    25. if v == "M":
    26. layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
    27. else:
    28. conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
    29. layers += [conv2d, nn.ReLU(True)]
    30. in_channels = v
    31. return nn.Sequential(*layers)
    32. cfgs = {
    33. 'vgg16': [64,64'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    34. }
    35. def vgg(model_name="vgg16", **kwargs):
    36. assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
    37. cfg = cfgs[model_name]
    38. model = VGG(make_features(cfg) ,**kwargs)
    39. return model
    40. if __name__=='__main__':
    41. device=torch.device('cuda:0')
    42. net=vgg()
    43. net.to(device)
    44. summary(net,(3,224,224))
    45. x=torch.rand(1,3,224,224).to(device)
    46. out=net(x)
    47. print(out.shape)

    载入模型权重:

    1. model_name = "vgg16"
    2. net = vgg(model_name=model_name, num_classes=102)
    3. weight_path='./vgg16_12_BNsig_1_best.pth'
    4. net.load_state_dict(torch.load(weight_path,map_location=device))

    这样模型就载入了全部的权重,文中的权重是我自己训练的。

    贰.载入部分权重

    在很多情况下我们根据实际情况修改了部分网络结构,导致官方的预训练权重或者自己以前训练的权重报错。

    假设在现有模型上增加一个模块:

    1. import torch
    2. import torch.nn as nn
    3. class VGG(nn.Module):
    4. def __init__(self, features,num_classes=1000):
    5. super(VGG, self).__init__()
    6. self.rnn3 = nn.Sequential(
    7. nn.Conv2d(64, 64, 3, 1, 1),
    8. nn.Tanh())
    9. self.features = features
    10. self.classifier = nn.Sequential(
    11. nn.Linear(512 * 7 * 7, 4096),
    12. nn.ReLU(True),
    13. nn.Dropout(p=0.5),
    14. nn.Linear(4096, 4096),
    15. nn.ReLU(True),
    16. nn.Dropout(p=0.5),
    17. nn.Linear(4096, num_classes)
    18. )
    19. def forward(self, x):
    20. x1=self.rnn3(x)
    21. x5 = self.features(x1)
    22. x5= torch.flatten(x5, start_dim=1)
    23. x5= self.classifier(x5)
    24. return x5

    再次载入模型时就会报错:

    因为在预训练权重文件中并没有rnn3的权重,所以报错为missing key。

     解决方法,从预训练权重中挑出现有模型的权重,并使用预训练权重初始化现有模型的权重,即完成现有模型的权重初始化。

    假设现有模型的权重key值有{conv1,conv2,conv3,conv44,conv5},预训练权重的key值有{conv1,conv2,conv3,conv4,conv5}

    那么我们新建一个权重字典,将key值在现有模型和预训练模型中都存在的保存下来,然后用新建的权重字典载入现有模型,即完成模型的初始化。

    1. model_name = "vgg16"
    2. net = vgg(model_name=model_name, num_classes=102)
    3. weight_path='./vgg16_12_BNsig_1_best.pth'
    4. # 抽出预训练模型中的K,V
    5. pretrain_model=torch.load(weight_path,map_location=device)
    6. # 抽出现有模型中的K,V
    7. model_dict=net.state_dict()
    8. # 新建权重字典,并更新
    9. state_dict={k:v for k,v in pretrain_model.items() if k in model_dict.keys()}
    10. # 更新现有模型的权重字典
    11. model_dict.update(state_dict)
    12. # 载入更新后的权重字典
    13. net.load_state_dict(model_dict)

    叁.载入部分权重并冻结载入权重的部分

    载入部分和2是一样,冻结权重即意味着权重在训练过程中不更新,那么将权重的requires_grad = False即可。

    沿用2的部分,即我们现在载入的权重中只有rnn3是预训练权重中没有,那么我们就冻结其余的权重,只训练rnn3即可。

    1. model_name = "vgg16"
    2. net = vgg(model_name=model_name, num_classes=102, init_weights=False)
    3. weight_path='./vgg16_12_BNsig_1_best.pth'
    4. # 抽出预训练模型中的K,V
    5. pretrain_model=torch.load(weight_path,map_location=device)
    6. # 抽出现有模型中的K,V
    7. model_dict=net.state_dict()
    8. print(model_dict.keys())
    9. # 新建权重字典,并更新
    10. state_dict={k:v for k,v in pretrain_model.items() if k in model_dict.keys()}
    11. print(state_dict.keys())
    12. # 更新现有模型的权重字典
    13. model_dict.update(state_dict)
    14. # 载入更新后的权重字典
    15. net.load_state_dict(model_dict)
    16. # 冻结权重,即设置该训练参数为不可训练即可
    17. for name,para in net.named_parameters():
    18. if name in state_dict:
    19. para.requires_grad=False
    20. # 更新可训练参数
    21. para=[para for para in net.parameters() if para.requires_grad]
    22. # 更新后的可训练参数就只有rnn,权重有两个,一个是weight,一个是bias
    23. print(para)

  • 相关阅读:
    实战干货!用 Python 爬取股票实时数据!
    [ubuntu][转载]ubuntu挂载nas并实现开机自动挂载
    2022 年坑过我的 JAVA 面试题
    锁的优化机制了解吗?
    一致性检验评价方法kappa
    QDataStream中 << 和 >> 输入输出重载的理解
    行业现状?互联网公司为什么宁愿花20k招人,也不愿涨薪留住老员工~
    深度学习推荐系统架构、Sparrow RecSys项目及深度学习基础知识
    【架构设计】CAP理论、BASE理论
    简述直线模组的发展前景
  • 原文地址:https://blog.csdn.net/qq_51570094/article/details/126706613