• 现有模型的保存与加载(PyTorch版)


    我们以VGG16网络为例,来说明现有模型的保存与加载操作。

    保存与加载方式均有两种,接下来我们分别来学习这两种方式。注意:保存与加载不在同一个py文件中,我们设定保存操作在save.py文件中,而加载操作在load.py文件中。
    保存模型的两种方式如下代码所示,第一种为既保存模型结构,又保存模型参数;第二种只保存模型参数,并且以字典的形式保存。

    1. import torch
    2. import torchvision
    3. vgg16 = torchvision.models.vgg16(pretrained=False)
    4. # 保存方式1,模型结构+模型参数
    5. torch.save(vgg16, "vgg16_method1.pth") # 保存路径:vgg16_method1.pth
    6. # 保存方式2,模型参数
    7. torch.save(vgg16.state_dict(), "vgg16_method2.pth") # 保存路径:vgg16_method2.pth

    加载模型的两种方式如下代码所示。

    1. import torch
    2. # 方式1 --》保存方式1,加载模型
    3. model1 = torch.load("vgg16_method1.pth")
    4. print(model1)
    5. # 方式2 --》保存方式2,加载模型
    6. model2 = torch.load("vgg16_method2.pth")
    7. print(model2)

    打印结果为:

     model1结果为:

    1. VGG(
    2. (features): Sequential(
    3. (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    4. (1): ReLU(inplace=True)
    5. (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    6. (3): ReLU(inplace=True)
    7. (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    8. (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    9. (6): ReLU(inplace=True)
    10. (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    11. (8): ReLU(inplace=True)
    12. (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    13. (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    14. (11): ReLU(inplace=True)
    15. (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    16. (13): ReLU(inplace=True)
    17. (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    18. (15): ReLU(inplace=True)
    19. (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    20. (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    21. (18): ReLU(inplace=True)
    22. (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    23. (20): ReLU(inplace=True)
    24. (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    25. (22): ReLU(inplace=True)
    26. (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    27. (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    28. (25): ReLU(inplace=True)
    29. (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    30. (27): ReLU(inplace=True)
    31. (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    32. (29): ReLU(inplace=True)
    33. (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    34. )
    35. (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
    36. (classifier): Sequential(
    37. (0): Linear(in_features=25088, out_features=4096, bias=True)
    38. (1): ReLU(inplace=True)
    39. (2): Dropout(p=0.5, inplace=False)
    40. (3): Linear(in_features=4096, out_features=4096, bias=True)
    41. (4): ReLU(inplace=True)
    42. (5): Dropout(p=0.5, inplace=False)
    43. (6): Linear(in_features=4096, out_features=1000, bias=True)
    44. )
    45. )

    上述结果是VGG16网络模型结构以及网络模型参数。 

    model2结果为:(由于结果太多,故只给出部分结果)

    1. OrderedDict([('features.0.weight', tensor([[[[-0.1638, -0.0292, 0.0316],
    2. [-0.0149, 0.0681, 0.0458],
    3. [ 0.0633, -0.0374, -0.0047]],
    4. [[-0.0123, -0.0461, 0.0343],
    5. [ 0.0207, -0.0128, 0.0107],
    6. [-0.0181, 0.0154, 0.0320]],
    7. [[-0.0759, -0.1384, -0.0318],
    8. [ 0.0244, -0.0424, 0.0332],
    9. [-0.0244, 0.0524, 0.1292]]],
    10. ……………………………………

    上述结果是VGG16网络模型参数。 

    那我们要是想用通过保存方式2所保存的模型参数,该如何使用呢?请看下面代码。
    我们先搭建出网络模型结构,随后将保存好的网络模型参数加载到网络模型结构中去。

    1. vgg16 = torchvision.models.vgg16(pretrained=False) # 网络模型结构
    2. vgg16.load_state_dict(torch.load("vgg16_method2.pth")) # 加载保存的网络模型参数
    3. print(vgg16)

    保存方式1有一个小小的陷阱。

    我们通过自己搭建一个网络来说明这个陷阱。
    我们在save.py文件中搭建我们的网络结构,并将其保存。

    1. # 陷阱1
    2. class Tudui(nn.Module):
    3. def __init__(self):
    4. super(Tudui, self).__init__()
    5. self.conv1 = nn.Conv2d(3, 64, kernel_size=5)
    6. def forward(self, x):
    7. x = self.conv1(x)
    8. return x
    9. tudui = Tudui()
    10. torch.save(tudui, "tudui_method1.pth")

    接下来我们按照加载方式1的方法在load.py文件中加载这个模型。

    1. # 陷阱1
    2. model = torch.load("tudui_method1.pth")
    3. print(model)

    打印结果为:

    1. AttributeError: Can't get attribute 'Tudui' on __main__' from 'D:/graduate0/pytorch_practice/model_load.py'>

    我们发现报错了,错误的原因是不能得到Tudui这个属性。

    我们把网络结构添加在load.py文件中。注意:此时不需要创建网络模型,即不用运行tudui=Tudui()这句代码。

    1. from torch import nn
    2. class Tudui(nn.Module):
    3. def __init__(self):
    4. super(Tudui, self).__init__()
    5. self.conv1 = nn.Conv2d(3, 64, kernel_size=5)
    6. def forward(self, x):
    7. x = self.conv1(x)
    8. return x
    9. # tudui = Tudui()
    10. model = torch.load("tudui_method1.pth")
    11. print(model)

    打印结果为: 

    1. Tudui(
    2. (conv1): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1))
    3. )

  • 相关阅读:
    VSCode使用简介
    6聚合根与资源库 #
    【代码随想录】贪心算法刷题
    vue echart详细使用说明
    葡萄糖-聚乙二醇-转铁蛋白|Transferrin-PEG-Glucose|转铁蛋白-PEG-葡萄糖
    JAVA黑马程序员day12--集合进阶(下部--双列集合)
    计算机毕业设计(附源码)python医院门诊分诊系统
    统计子岛屿的数量
    Pycharm 常用快捷键
    推荐一款AI写作大师、问答、绘画工具-「智元兔 AI」
  • 原文地址:https://blog.csdn.net/m0_48241022/article/details/132642411