• 【迁移学习】


    1 迁移学习的思路

    迁移学习的思路是利用预训练模型的卷积部分(卷积基)提取数据集的图片特征,然后重新练最后的全连接部分(分类器),迁移学习的特征提取部分(卷积基)不能发生变化。

    2 迁移学习的步骤

    迁移学习的思路有3步:
    (1)冻结预训练模型的卷积基
    (2)根据问题重新设置分类器,如需要分2类,则out_features=2
    (3)用自己的数据训练设置好的分类器,注意:只优化分类器参数

    3 具体步骤

    torchvision提供了可以加载的预训练模型:

    alexnet
    convnex
    densenet
    efficientnet
    feature_extraction
    googlenet
    inception
    mnasnet
    mobilenet
    mobilenetv2
    mobilenetv3
    regnet
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    使用代码:

    import torchvision
    方法1:
    model = torchvision.models.vgg16(pretrained=True)	# pretrained=True表式仅仅加载网络结构,而不加载网络参数
    # 方法2
    model = models.vgg16(weights= models.VGG16_Weights.DEFAULT)
    print(model)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    输出如下:

    VGG(
      (features): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU(inplace=True)
        (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (18): ReLU(inplace=True)
        (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU(inplace=True)
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU(inplace=True)
    ...
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
      )
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    冻结卷积基的参数,避免模型参数被破坏,准确率下降

    for param in model.features.parameters():
        param.requires_grad = False
    
    model.classifier[-1].out_features = 4	# 4分类
    
    model = model.to(device)	# 模型传入型芯片,一般GPU上
    loss_fn = nn.CrossEntropyLoss()	#根据具体问题自定义
    optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.0001)	#注意这里只优化分类器参数
    
    ....接训练代码
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
  • 相关阅读:
    HTML期末作业,基于html实现中国脸谱传统文化网站设计(5个页面)
    AI视频风格转换:Stable Diffusion+EBSynth
    基于Springboot的特产销售平台设计与实现毕业设计源码091036
    MES管理系统中的质量管理活动是什么
    一文拿捏对象内存布局及JMM(JAVA内存模型)
    qt 绘图
    Gartner公布《2023中国ICT技术成熟度曲线》,得帆信息入选低代码代表厂商
    宏任务、微任务理解
    Matlab 实用代码集
    Zabbix
  • 原文地址:https://blog.csdn.net/m0_46256255/article/details/133184650