• PyTorch模型定义和相关使用


    PyTorch模型定义

    1. PyTorch模型定义的方式

    模型在深度学习中扮演着重要的角色,好的模型极大地促进了深度学习的发展进步,比如:

    • CNN:解决了图像、视频处理中的问题
    • RNN/LSTM:解决了序列数据处理的问题
    • GNN:在图模型上发挥着重要作用

    当了解使用了深度学习的项目时,一般首先需要了解该项目使用了哪些模型,因此本节将主要讲解模型定义的方式、看懂GitHub上的模型定义、根据实际需求灵活选取模型定义方式。

    Module类是torch.nn模块中提供的一个模型构造类,是所有神经网络模块的积累,可以继承它来定义我们想要的模型。

    1.1 知识回顾

    PyTorch模型定义应包括两个主要部分:

    • 各个部分的初始化,即__init__方法;
    • 数据流向定义,即forward方法

    基于nn.Module,可以通过Sequential,ModuleList和ModuleDict三种方式定义PyTorch模型

    1.2 Sequential

    对应nn.Sequential(),当模型的前向计算为简单串联各个的计算时,Sequential类可以通过更加简单的方式定义模型。它可以通过接收一个子模块的有序字典(OrderedDict)或者一系列子模块作为参数来逐一按添加的顺序计算。

    下面结合原生Module给出Sequential的定义:

    class MySequential(nn.Module):
        from collections import OrderedDict
        def __init__(self, *args):
            super(MySequential, self).__init__()
            if len(args) == 1 and isinstance(args[0], OrderedDict): # 如果传入的是一个OrderedDict
                for key, module in args[0].items():
                    self.add_module(key, module)  
                    # add_module方法会将module添加进self._modules(一个OrderedDict)
            else:  # 传入的是一些Module
                for idx, module in enumerate(args):
                    self.add_module(str(idx), module)
        def forward(self, input):
            # self._modules返回一个 OrderedDict,保证会按照成员添加时的顺序遍历成
            for module in self._modules.values():
                input = module(input)
            return input
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    下面结合实例给出Sequential定义模型的方法:

    • 直接排列
    import torch.nn as nn
    net = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 10), 
            )
    print(net)
    
    # output
    Sequential(
      (0): Linear(in_features=784, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=10, bias=True)
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 使用OrderedDict:
    import collections
    import torch.nn as nn
    net2 = nn.Sequential(collections.OrderedDict([
              ('fc1', nn.Linear(784, 256)),
              ('relu1', nn.ReLU()),
              ('fc2', nn.Linear(256, 10))
              ]))
    print(net2)
    
    # output
    Sequential(
      (fc1): Linear(in_features=784, out_features=256, bias=True)
      (relu1): ReLU()
      (fc2): Linear(in_features=256, out_features=10, bias=True)
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    可以看到,使用Sequential定义模型非常简单、易读,同时使用Sequential定义的模型不需要再写forward。

    但同时,使用Sequential会使得模型定义丧失灵活性,因为Sequential的forward方法已经定义好,没法修改模型的中间结果

    1.3 ModuleList

    对应nn.ModuleList(),接收一个子模块(或层,即继承自nn.Module类)的列表作为输入,然后可以类似List使用append或extend添加新的子模块,同时也可以类似于List方式去访问其中的子模块。

    注:ModuleList并没有定义一个网络,它只是将不同的子模块存储起来,不像Sequential那样定义了forward方法,因此需要自己根据实际情况去实现相关的forward方法。

    net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
    net.append(nn.Linear(256, 10))  # 类似List的append操作
    print(net[-1])  # 类似List的索引访问
    print(net)
    
    • 1
    • 2
    • 3
    • 4

    1.4 ModuleDict

    对应nn.ModuleDict(),其作用和ModuleList类型,只是ModuleDict能够更方便地为神经网络层添加名称(类似于python中的dict操作)

    net = nn.ModuleDict({
     	'linear': nn.Linear(784, 256),
        'act': nn.ReLU(),
    })
    net['output'] = nn.Linear(256, 10)  # add new module like python dict
    print(net['linear'])	# access dict value like python dict
    print(net)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    1.5 总结

    • Sequential:使用该方法不需要写__init__forward方法
    • ModuleList和ModuleDict:适用于某个相同的层多次出现,它们的使用分别和python的list和dict使用类似

    2. 利用模型快速搭建复杂网络

    在实际场景中,如果一个模型的深度非常大,即需要非常多层时,使用Sequential定义模型结构就需要添加多个层,使用起来不方便。

    而对于大部分的模型结构(如ResNet、DenseNet等),可以发现,虽然模型有很多层,但是其中的层有很多重复出现的结构。

    因此可以考虑将这些重复出现的结构抽象为一个“模块”(也叫模型块),每次向模型添加对应的模块来构造模型,从而方便模型的构建。

    因此,接下来将介绍如果构建模块,以及如何利用模块快速搭建复杂模型(以U-Net为例)。

    2.1 U-Net简介

    U-Net是分割(Segmentation)模型的杰作,在以医学影像为代表的诸多领域有广泛的应用。其模型结构如下,通过残差连接结构解决了模型学习中的退化问题,使得神经网络的深度能够不断扩展。

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-A1YPREm8-1663511567324)(images/5.2.1unet.png)]

    3. PyTorch修改模型

    除了自己从零开始构建PyTorch模型,有时候也会使用到现场的模型,但是模型的某部分结构不符合我们的要求,这个时候就需要对模型进行修改,常见的修改包括如下:

    • 修改模型若干层
    • 添加额外输入
    • 添加额外输出

    3.1 修改模型层

    import torchvision.models as models
    net = models.resnet50()
    print(net)
    
    # output
    # ResNet(
    #   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    #   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    #   (relu): ReLU(inplace=True)
    # ..............
    #   (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    #   (fc): Linear(in_features=2048, out_features=1000, bias=True)
    # )
    
    # modify the model output so it can be a classifier
    from collections import OrderedDict
    classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 128)), 
    						  ('relu1', nn.ReLU()),
    						  ('output', nn.Softmax(dim=1))
    						  ]))
    net.fc = classifier
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    上述操作相当于将原本模型的最后一层替换为一个名为classifier的结构。

    3.2 添加额外输入

    3.3 添加额外输出

    4. PyTorch模型保存与读取

    当模型训练好后,我们应该如何保存和读取训练好的模型呢?另外,很多场景下我们会使用多GPU训练,模型会分布于各个GPU上,模型的保存和读取又是怎样的呢?

    本节将介绍:

    • PyTorch模型的存储格式
    • PyTorch如何存储模型
    • 单卡与多卡训练下模型的保存与加载方法

    4.1 模型存储格式

    4.2 模型存储内容

    References

    1. 深入浅出PyTorch第五章PyTorch模型定义
  • 相关阅读:
    随便写一写
    Java项目(二)--Springboot + ElasticSearch 构建博客检索系统(4)- SpringBoot集成ES
    常用锁原理的介绍(上)
    Python虚拟环境的安装
    网页JS自动化脚本(十)新旧字符串关键词检测
    三端sonar记录
    将JMeter测试结果写入Excel【BeanShell取样器】
    探索请求头中的UUID的不同版本:UUID1、UUID3、UUID4和UUID5
    autoware之轮式里程计计算
    基础课5——语音合成技术
  • 原文地址:https://blog.csdn.net/weixin_47802917/article/details/126924503