• Pytorch入门(4)—— Tensor和Module的保存与加载


    • 参考:动手学深度学习
    • 注意:由于本文是jupyter文档转换来的,代码不一定可以直接运行,有些注释是jupyter给出的交互结果,而非运行结果!!

    1. 读写 Tensor

    • pytorch 提供了 torch.save 函数和 torch.load 函数,分别用于存储和读取 Tensor,其中

      1. torch.save 使用 Python 的 pickle 持续化模块将对象进行序列化并保存到本地磁盘,torch.save 可以保存各种对象,包括模型、张量和字典等
      2. torch.load 使用 pickle unpickle 工具将 pickle 的对象文件反序列化为内存变量
    • 下面创建 Tensor 变量 x,并将保存为本地文件 x.pt,然后再读取回来

      import torch
      from torch import nn
      
      x = torch.ones(3)
      torch.save(x, 'x.pt')
      
      x2 = torch.load('x.pt')
      print(x2) # tensor([1., 1., 1.])
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
    • 类似地,存储Tensor列表和字典并读回内存

      x = torch.ones(2)
      y = torch.zeros(4)
      torch.save([x, y], 'xy_list.pt')
      xy_list = torch.load('xy_list.pt')
      print(xy_list)
      
      torch.save({'x': x, 'y': y}, 'xy_dict.pt')
      xy_dict = torch.load('xy_dict.pt')
      print(xy_dict)
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      [tensor([1., 1.]), tensor([0., 0., 0., 0.])]
      {'x': tensor([1., 1.]), 'y': tensor([0., 0., 0., 0.])}
      
      • 1
      • 2

    2. 读写 Module

    2.1 state_dict

    • 保存/加载 Module 的一个思路是保存和加载其所有参数。PyTorch 中 Module 的可学习参数包括权重 weight 和偏置 bias,它们可以通过 .parameters().named_parameter() 方法访问

    • 调用 Module 的 .state_dict() 方法,返回一个从参数名称(“layer名.weight” 或 “layer名.bias”)映射到到参数 Tesnor 的字典对象,其中包含了模型的所有可学习参数

      class MLP(nn.Module):
          def __init__(self):
              super(MLP, self).__init__()
              self.hidden = nn.Linear(3, 2)
              self.act = nn.ReLU()
              self.output = nn.Linear(2, 1)
      
          def forward(self, x):
              a = self.act(self.hidden(x))
              return self.output(a)
      
      net = MLP()
      net.state_dict()
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      OrderedDict([('hidden.weight',
                    tensor([[-0.0749,  0.2594, -0.5274],
                            [ 0.3983, -0.2925,  0.5102]])),
                   ('hidden.bias', tensor([-0.1095, -0.4895])),
                   ('output.weight', tensor([[ 0.2644, -0.3989]])),
                   ('output.bias', tensor([-0.2737]))])
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6

      可见,只有具有可学习参数的层(卷积层、线性层等)才有 state_dict 中的条目;另外,优化器(optim)也有一个 state_dict,其中包含关于优化器状态以及所使用的超参数的信息

      optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
      optimizer.state_dict()
      
      • 1
      • 2
      {'state': {},
       'param_groups': [{'lr': 0.001,
         'momentum': 0.9,
         'dampening': 0,
         'weight_decay': 0,
         'nesterov': False,
         'params': [0, 1, 2, 3]}]}
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7

    2.2 保存和加载模型

    • PyTorch 中保存和加载训练模型有两种常见的方法:
      1. 仅保存和加载模型参数(state_dict),这是推荐方式
      2. 保存和加载整个模型

    2.2.1 保存和加载 state_dict(推荐)

    • 通过保存和加载参数来实现模型的存取,相比直接保存整个 Module 对象轻量很多

      1. 保存时,使用 torch.save 保存模型的 state_dict 字典实例
      2. 加载时,先使用 torch.load 加载模型的 state_dict,然后实例化一个所需类型的 Module,最后调用 Moudule 的 .load_state_dict 方法加载参数
    • 给出一个读写多层感知机模型的示例

      # 原始模型
      print('原始模型')
      net = MLP()
      print(net,'\n')
      for name, param in net.named_parameters():
          print(name, param)
      
      # 保存 & 加载
      print('\n\n保存再加载模型')
      torch.save(net.state_dict(), 'mlp.pt') # 推荐的文件后缀名是pt或pth
      model = MLP()
      model.load_state_dict(torch.load('mlp.pt'))
      
      print(model,'\n')
      for name, param in model.named_parameters():
          print(name, param)
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      • 16
      原始模型
      MLP(
        (hidden): Linear(in_features=3, out_features=2, bias=True)
        (act): ReLU()
        (output): Linear(in_features=2, out_features=1, bias=True)
      ) 
      
      hidden.weight Parameter containing:
      tensor([[ 0.3508,  0.5170,  0.2746],
              [-0.0571, -0.1772,  0.2815]], requires_grad=True)
      hidden.bias Parameter containing:
      tensor([0.1041, 0.1207], requires_grad=True)
      output.weight Parameter containing:
      tensor([[-0.4299, -0.2678]], requires_grad=True)
      output.bias Parameter containing:
      tensor([-0.3851], requires_grad=True)
      
      
      保存再加载模型
      MLP(
        (hidden): Linear(in_features=3, out_features=2, bias=True)
        (act): ReLU()
        (output): Linear(in_features=2, out_features=1, bias=True)
      ) 
      
      hidden.weight Parameter containing:
      tensor([[ 0.3508,  0.5170,  0.2746],
              [-0.0571, -0.1772,  0.2815]], requires_grad=True)
      hidden.bias Parameter containing:
      tensor([0.1041, 0.1207], requires_grad=True)
      output.weight Parameter containing:
      tensor([[-0.4299, -0.2678]], requires_grad=True)
      output.bias Parameter containing:
      tensor([-0.3851], requires_grad=True)
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      • 16
      • 17
      • 18
      • 19
      • 20
      • 21
      • 22
      • 23
      • 24
      • 25
      • 26
      • 27
      • 28
      • 29
      • 30
      • 31
      • 32
      • 33
      • 34

    2.2.2 保存和加载整个模型

    • 第 1 节说明了 torch.save 可以保存各种对象,包括模型、张量和字典等 ,所以也可以直接对 Module 实例使用 torch.savetorch.load 进行保存加载

      # 原始模型
      print('原始模型')
      net = MLP()
      print(net,'\n')
      for name, param in net.named_parameters():
          print(name, param)
      
      # 保存 & 加载
      print('\n\n保存再加载模型')
      torch.save(net, 'mlp.pt') # 推荐的文件后缀名是pt或pth
      model = torch.load('mlp.pt')
      
      print(model,'\n')
      for name, param in model.named_parameters():
          print(name, param)
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      原始模型
      MLP(
        (hidden): Linear(in_features=3, out_features=2, bias=True)
        (act): ReLU()
        (output): Linear(in_features=2, out_features=1, bias=True)
      ) 
      
      hidden.weight Parameter containing:
      tensor([[ 0.1274, -0.4678, -0.5102],
              [ 0.3547,  0.3654, -0.2288]], requires_grad=True)
      hidden.bias Parameter containing:
      tensor([-0.3877,  0.0637], requires_grad=True)
      output.weight Parameter containing:
      tensor([[0.5678, 0.6307]], requires_grad=True)
      output.bias Parameter containing:
      tensor([0.0517], requires_grad=True)
      
      
      保存再加载模型
      MLP(
        (hidden): Linear(in_features=3, out_features=2, bias=True)
        (act): ReLU()
        (output): Linear(in_features=2, out_features=1, bias=True)
      ) 
      
      hidden.weight Parameter containing:
      tensor([[ 0.1274, -0.4678, -0.5102],
              [ 0.3547,  0.3654, -0.2288]], requires_grad=True)
      hidden.bias Parameter containing:
      tensor([-0.3877,  0.0637], requires_grad=True)
      output.weight Parameter containing:
      tensor([[0.5678, 0.6307]], requires_grad=True)
      output.bias Parameter containing:
      tensor([0.0517], requires_grad=True)
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      • 16
      • 17
      • 18
      • 19
      • 20
      • 21
      • 22
      • 23
      • 24
      • 25
      • 26
      • 27
      • 28
      • 29
      • 30
      • 31
      • 32
      • 33
      • 34
  • 相关阅读:
    设计模式-命令模式
    人工神经网络模型有哪些,神经网络分类四种模型
    【Linux】ls命令
    QT 布局垂直居中【比例放大缩小仍然居中】
    HTTP协议
    [ESP32][esp-idf] AP+STA实现无线桥接(中转wifi信号)
    CentOS系统安装vsftpd
    入侵检测---IDS
    MQ系列14:MQ如何做到消息延时处理
    Java中方法的注意事项
  • 原文地址:https://blog.csdn.net/wxc971231/article/details/126886075