• pytorch模型保存和加载


            定义2个测试脚本test.py和test2.py,用于测试保存和加载,models文件夹保存模型,整个测试的项目文件结构如下:

    1. E:.
    2. │ test.py
    3. │ test2.py
    4. └─ models
    5. dongtai.pt
    6. dongtai_state_dict.pt
    7. jingtai.pth

    test.py中定义了TheModelClass这个网络结构类,此外写了模型保存和加载的代码,test2.py是想测试在没有定义模型结构的脚本中,是否可以成功加载模型。

    test.py

    1. import torch
    2. import torch.nn as nn
    3. import torch.optim as optim
    4. import torch.nn.functional as F
    5. # 定义模型
    6. class TheModelClass(nn.Module):
    7. def __init__(self):
    8. super(TheModelClass, self).__init__()
    9. self.conv1 = nn.Conv2d(3, 6, 5)
    10. self.pool = nn.MaxPool2d(2, 2)
    11. self.conv2 = nn.Conv2d(6, 16, 5)
    12. self.fc1 = nn.Linear(16 * 4 * 4, 120)
    13. self.fc2 = nn.Linear(120, 84)
    14. self.fc3 = nn.Linear(84, 10)
    15. def forward(self, x):
    16. x = self.pool(F.relu(self.conv1(x)))
    17. x = self.pool(F.relu(self.conv2(x)))
    18. x = x.view(-1, 16 * 4 * 4)
    19. x = F.relu(self.fc1(x))
    20. x = F.relu(self.fc2(x))
    21. x = self.fc3(x)
    22. return x
    23. if __name__ == "__main__":
    24. # 初始化模型
    25. model = TheModelClass()
    26. # 模型保存,方法一动态图
    27. torch.save(model,"models/dongtai.pt")
    28. # 模型保存,方法二动态图
    29. torch.save(model.state_dict(),"models/dongtai_state_dict.pt")
    30. # 模型保存,方法三静态图
    31. x = torch.rand(1,3,30,30) #占位符
    32. trace_model = torch.jit.trace(model,x)
    33. torch.jit.save(trace_model,"models/jingtai.pth")
    34. # 模型加载,带模型结构
    35. model_resume = torch.load("models/dongtai.pt")
    36. # 模型加载,只有权重
    37. weights = torch.load("models/dongtai_state_dict.pt")
    38. model.load_state_dict(weights)
    39. # 直接从静态图中恢复,无需模型结构
    40. model = torch.jit.load("models/jingtai.pth")
    41. x = torch.rand(1,3,30,30)
    42. pred = model(x)
    43. print(pred)

    经过测试,pytorch可以通过三种方法实现模型的保存和加载:

    • 动态图保存模型结构和权重
    • 动态图保存权重
    • 静态图保存权重

    接下来一个个说明这三种方法需要注意的地方。

    一、动态图保存模型结构和权重

    1. # 保存
    2. model = TheModelClass()
    3. # 模型保存
    4. torch.save(model,"models/dongtai.pt")
    5. # 加载,带模型结构
    6. model_resume = torch.load("models/dongtai.pt")

    保存:首先实例化网络对象,然后通过torch.save的方式,将模型结构和权重都序列化保存下来,后缀为pt或者pth都可以,不管保存成哪种后缀,都可以解析。

    加载:首先必须能访问到网络结构的类TheModelClass,然后通过torch.load的方式就可以完整的将模型结构恢复,同时加载好权重。

    这里需要特别注意的点,加载模型的这个文件必须要能找到网络结构的类,不管是在哪里定义网络,都要能导入到当前读取模型的这个文件中做实例化,比如我在test2.py里面导入test.py中的网络类,就可以成功加载,否则会报找不到类的错误。

    test2.py

    1. import torch
    2. from test import TheModelClass # 不导入或同级下找不到会有问题
    3. model = TheModelClass()
    4. model_resume = torch.load("models/dongtai.pt")
    5. model.load_state_dict(model_resume)
    6. model.eval()
    7. print()

    二、动态图保存权重

    1. # 初始化模型
    2. model = TheModelClass()
    3. # 模型保存
    4. torch.save(model.state_dict(),"models/dongtai_state_dict.pth")
    5. # 模型加载,只有权重
    6. weights = torch.load("models/dongtai_state_dict.pth")
    7. model.load_state_dict(weights)

    保存:首先实例化网络对象,然后通过torch.save的方式,只将模型权重序列化保存下来,这种方法不用保存模型结构。

    加载:首先必须能访问到网络结构的类TheModelClass,并实例化,然后通过torch.load的方式就可以将模型权重反序列化取出,然后将其加载进模型对象中。

    注意:必须实例化网络对象,才能加载对应的权重。

    三、静态图保存权重

    1. model = TheModelClass()
    2. # 模型保存,方法三静态图
    3. x = torch.rand(1,3,30,30) #占位符
    4. trace_model = torch.jit.trace(model,x)
    5. torch.jit.save(trace_model,"models/jingtai.pt")
    6. # 直接从静态图中恢复,无需模型结构
    7. model_ji = torch.jit.load("models/jingtai.pt")

    保存:首先实例化网络对象,然后用一个随机的固定尺寸的输入,通过torch.jit.trace,将网络结构前向跑一遍,记录下网络中的节点运行路径,然后通过torch.jit.save将这个运行路径存下来,这种方法会自动记录模型中节点间的数据流动顺序,也就是间接的记录下的模型结构和每个节点的权重。不会单独再保存一个模型类。

    加载:直接用torch.jit.load的方法加载模型即可,因为该模型已经记录了网络中模型节点权重和数据流动的路径,因此只要将数据输入,即可“流过”整个模型,得到最终的输出,不用单独再构造模型类的实例。

    总结

    目前用的最多就是只保存权重的方法(方法二),最后一种用的最少,一般部署的时候也很少用,都是转成onnx再部署。

  • 相关阅读:
    模型调参优化
    Java多线程之线程池(合理分配资源)
    webui automatic1111上可以跑stable diffusion 3的方法
    张益唐的朗道-西格尔零点猜想的论文公布,专家认为该论文尚未完整解决零点猜想
    【MySQL】数据库基础
    容器编排工具的比较:Kubernetes、Docker Swarm、Nomad
    PHP电视剧推荐系统可以用wamp、phpstudy运行定制开发mysql数 据库BS模式
    深度学习之神经网络是如何自行学习的?
    linux中开始mysql中binlog日志
    Python配置镜像源
  • 原文地址:https://blog.csdn.net/sinat_33486980/article/details/127793348