• 《动手学深度学习 Pytorch版》 5.5 读写文件


    5.5.1 加载和保存

    import torch
    from torch import nn
    from torch.nn import functional as F
    
    x = torch.arange(4)
    torch.save(x, 'x-file')  # 使用 save 保存
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    x2 = torch.load('x-file')  # 使用 load 读回内存
    x2
    
    • 1
    • 2
    tensor([0, 1, 2, 3])
    
    • 1
    y = torch.zeros(4)
    torch.save([x, y],'x-files')  # 也可以存储张量列表
    x2, y2 = torch.load('x-files')
    (x2, y2)
    
    • 1
    • 2
    • 3
    • 4
    (tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
    
    • 1
    mydict = {'x': x, 'y': y}  # 存储从字符串映射到张量的字典
    torch.save(mydict, 'mydict')
    mydict2 = torch.load('mydict')
    mydict2
    
    • 1
    • 2
    • 3
    • 4
    {'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
    
    • 1

    5.5.2 加载和保存模型参数

    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.hidden = nn.Linear(20, 256)
            self.output = nn.Linear(256, 10)
    
        def forward(self, x):
            return self.output(F.relu(self.hidden(x)))
    
    net = MLP()
    X = torch.randn(size=(2, 20))
    Y = net(X)
    
    torch.save(net.state_dict(), 'mlp.params')  # 保存模型参数
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    clone = MLP()
    clone.load_state_dict(torch.load('mlp.params'))  # 加载文件中存储的参数
    
    Y_clone = clone(X)  # 参数一致则计算结果也应相同
    
    clone.eval(), Y_clone == Y
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    (MLP(
       (hidden): Linear(in_features=20, out_features=256, bias=True)
       (output): Linear(in_features=256, out_features=10, bias=True)
     ),
     tensor([[True, True, True, True, True, True, True, True, True, True],
             [True, True, True, True, True, True, True, True, True, True]]))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    练习

    (1)即使不需要将经过训练的模型部署到不同的设备上,保存的模型参数还有什么实际的好处?

    用作备份或备为下一步处理均可。


    (2)假设我们只想复用网络的一部分,已将其合并到不同的网络架构中。例如想在一个新的网络中使用之前网络的前两层,该怎么做?

    torch.save(net.hidden.state_dict(), 'mlp.hidden.params')  # 需要哪里存哪里
    clone = MLP()
    clone.hidden.load_state_dict(torch.load('mlp.hidden.params'))  # 需要哪里加载哪里
    
    clone.eval(), clone.hidden.weight == net.hidden.weight
    
    • 1
    • 2
    • 3
    • 4
    • 5
    (MLP(
       (hidden): Linear(in_features=20, out_features=256, bias=True)
       (output): Linear(in_features=256, out_features=10, bias=True)
     ),
     tensor([[True, True, True,  ..., True, True, True],
             [True, True, True,  ..., True, True, True],
             [True, True, True,  ..., True, True, True],
             ...,
             [True, True, True,  ..., True, True, True],
             [True, True, True,  ..., True, True, True],
             [True, True, True,  ..., True, True, True]]))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    (3)如何同时保存网络架构和参数?需要对架构加上什么限制?

    net = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))
    torch.save(net, 'net')  # pytorch 本身就支持保存模型
    net_new = torch.load('net')
    net_new
    
    • 1
    • 2
    • 3
    • 4
    Sequential(
      (0): Linear(in_features=20, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=10, bias=True)
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
  • 相关阅读:
    克拉默法则
    jdk-8u371-linux-x64.tar.gz jdk-8u371-windows-x64.exe 【jdk-8u371】 全平台下载
    python连接mysql数据库报错pymysql.err.OperationalError
    ES6新特性之箭头函数
    【操作系统笔记】高速缓存
    【大数据】Flink 内存管理(二):JobManager 内存分配(含实际计算案例)
    使用 Stable Diffusion Img2Img 生成、放大、模糊和增强
    如何恢复edge的自动翻译功能
    【云原生之kubernetes实战】在k8s环境下部署WBO在线协作白板
    权限提升数据库(基于MySQL的UDF,MOF,启动项提权)
  • 原文地址:https://blog.csdn.net/qq_43941037/article/details/132938791