• 第九章 预训练模型与自己模型参数不匹配和模型微调的具体实现


    导入预训练模型在通常情况下都能加快模型收敛,提升模型性能。但根据实际任务需求,自己搭建的模型往往和通用的Backbone并不能做到网络层的完全一致,无非就是少一些层和多一些层两种情况。

    1. 自己模型层数较少

    net = ...   # net为自己的模型
    save_model = torch.load('path_of_pretrained_model') # 获取预训练模型字典(键值对)
    model_dict = net.state_dict() # 获取自己模型字典(键值对)
    # 新定义字典,用来获取自己模型中对应层的预训练模型中的参数
    state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()} 
    model_dict.update(state_dict) # 更新自己模型字典中键值对
    net.load_state_dict(model_dict) # 加载参数
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    其中update:对于state_dict和model_dict都有的键值对,前者对应的值会替换后者,若前者有后者没有的键值对,则会添加这些键值对到后者字典中。

    2. 自己模型层数较多

    model_dict = model.state_dict() # 自己模型字典
    model_dict.update(pretrained_model) # 直接将预训练模型参数更新进来
    model.load_state_dict(model_dict) # 加载
    
    • 1
    • 2
    • 3

    结果与预训练模型对应的层权重被加载了,其它层则为默认初始化。

    由于模型参数都为字典形式存在,可以用字典的增删方式进行更灵活的操作

    特征提取与微调

    迁移学习是一种有效的机器学习方法,尤其是在数据不足的情况下。它通过使用在大型数据集(如 ImageNet)上预训练的模型,将这些模型的知识应用于新的、相对较小的数据集上。下面,我将详细解释两种主要的迁移学习策略,并提供相应的代码示例。

    1. 特征提取(Feature Extraction)

    在这种策略中,你使用预训练模型的卷积层(即模型的前几层)来作为固定的特征提取器。然后,你只需添加一些新的可训练层(通常是全连接层),以便根据新的数据集进行预测。

    代码示例

    假设我们要在一个新的数据集上使用预训练的 ResNet 模型进行特征提取:

    import torch
    import torch.nn as nn
    from torchvision import models
    
    # 加载预训练的 ResNet 模型
    model = models.resnet18(pretrained=True)
    # 冻结所有卷积层的参数
    for param in model.parameters():
        param.requires_grad = False
    
    # 替换最后的全连接层
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2) # 假设我们的新数据集有2个类别
    # 现在,只有 model.fc 层的参数会在训练中更新
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    在这个示例中,我们首先加载了预训练的 ResNet 模型,并冻结了它的所有卷积层。然后,我们替换了最后的全连接层以适应新的分类任务(假设有两个类别)。在训练过程中,只有这个新添加的全连接层的参数会更新。

    2. 微调(Fine-Tuning)

    微调涉及到在预训练模型的基础上进一步训练整个模型(或模型的大部分层)。这通常是在特征提取的基础上进行的,即首先使用特征提取策略训练模型,然后解冻整个模型的多个层,并对整个模型进行进一步的训练。

    代码示例

    继续使用上面的 ResNet 示例,我们现在将进行微调:

    # 之前的代码省略...
    # 解冻所有层
    for param in model.parameters():
        param.requires_grad = True
    # 现在整个模型的参数都会在训练中更新
    
    • 1
    • 2
    • 3
    • 4
    • 5

    在微调阶段,所有层的参数都会被更新。这通常是在特征提取阶段之后进行,特别是当新数据集与原始数据集在特征上有较大差异时。

    添加、删除或修改字典中的项

    在 PyTorch 中,模型的参数是以字典形式存储的,这使得我们可以利用 Python 字典的特性来灵活地处理模型参数。例如,我们可以添加、删除或修改字典中的项,以此来自定义模型的参数。以下是一些常见操作的示例:

    修改

    示例:修改特定层的参数

    假设您想要修改预训练模型中某一层的参数,比如将第一层卷积层的权重全部设置为零。

    import torch
    import torchvision.models as models
    
    # 加载预训练模型
    model = models.resnet18(pretrained=True)
    # 将第一层卷积层的权重设置为零
    torch.nn.init.constant_(model.conv1.weight, 0)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    在这个示例中,我们使用 torch.nn.init.constant_ 方法直接修改了模型的第一层卷积层(conv1)的权重,将其全部设置为零。

    示例:修改层的属性

    除了修改层的参数外,您还可以修改层的其他属性。例如,您可能想要更改卷积层的步长(stride)或填充(padding)。

    # 改变第一层卷积层的步长和填充
    model.conv1.stride = (2, 2)
    model.conv1.padding = (1, 1)
    
    • 1
    • 2
    • 3

    在这个示例中,我们修改了模型的第一层卷积层的步长和填充属性。

    注意事项

    • 在修改模型的参数或属性时,请确保您的修改不会导致模型架构的不一致。例如,更改卷积层的核大小或步长可能会影响模型中后续层的输入尺寸。
    • 修改模型参数通常需要对模型的工作原理有深入的理解,以避免意外地破坏模型的性能。

    通过这些修改,您可以对模型进行微调,使其更好地适应特定的任务或数据集。这种能力在进行模型实验和优化时非常有价值。

    示例 1:删除特定层的参数

    假设你想要从预训练模型中删除某些层的参数。这可以通过删除字典中相应的键值对来实现。

    import torch
    import torchvision.models as models
    
    # 加载预训练模型
    model = models.resnet18(pretrained=True)
    
    # 获取模型的 state_dict
    model_dict = model.state_dict()
    
    # 假设我们想要删除第一层卷积层的参数
    del model_dict['conv1.weight']
    del model_dict['conv1.bias']
    
    # 更新模型的 state_dict
    model.load_state_dict(model_dict, strict=False)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    在这个示例中,我们首先加载了预训练的 ResNet18 模型,并获取其状态字典。然后,我们删除了第一层卷积层(conv1)的权重和偏置参数。最后,我们使用 load_state_dict 更新模型的参数,strict=False 表示我们允许不匹配的项存在。

    示例 2:添加新层的参数

    如果你想向模型中添加新的层,并为其初始化参数,可以直接向状态字典中添加新的键值对。

    import torch.nn as nn
    
    # 继续之前的模型
    # 添加一个新的全连接层
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
                   nn.Linear(num_ftrs, 500),
                   nn.ReLU(),
                   nn.Linear(500, 2)
               )
    
    # 初始化新层的参数
    nn.init.xavier_normal_(model.fc[0].weight)
    nn.init.constant_(model.fc[0].bias, 0)
    nn.init.xavier_normal_(model.fc[2].weight)
    nn.init.constant_(model.fc[2].bias, 0)
    
    # 获取新的模型 state_dict
    new_model_dict = model.state_dict()
    
    # 可以选择性地将新的 state_dict 保存下来
    # torch.save(new_model_dict, 'modified_model.pth')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
  • 相关阅读:
    python学习05协程_2
    Android 通话常见错误码汇总
    了解舵机以及MG996R的控制方法
    linux系统 删除文件命令
    132、LeetCode-72.编辑距离
    【QT HTTP】使用QtNetwork模块制作基于HTTP请求的C/S架构
    支持I2S数字音频接口;音频功放芯片NTP8835C
    TCP、UDP API调用(实时聊天)
    前端面试题记录
    【云原生之K8S】K8S管理工具kubectl 详解
  • 原文地址:https://blog.csdn.net/weixin_44302770/article/details/134516822