PyTorch提供了两种主要的模型保存和加载机制,一种是基于Python的序列化,另一种是TorchScript。
torch.save(model.state_dict(), 'model_path.pth'),它保存了模型的权重和参数,但不保存模型的结构。model.load_state_dict(torch.load('model_path.pth'))来加载权重。保存模型:
torch.save(model.state_dict(), 'model_weights.pth')
加载模型:
model = ModelClass()
model.load_state_dict(torch.load('model_weights.pth'))
model.eval() # 设置为评估模式
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 10)
def forward(self, x):
return self.fc(x)
model = SimpleModel()
torch.save(model.state_dict(), 'simple_model_weights.pth')
# 在其他地方或时间
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load('simple_model_weights.pth'))
loaded_model.eval()
torch.jit.trace方法。这涉及到通过模型运行一个输入示例,从而跟踪模型的执行路径。torch.jit.script方法。这转化Python代码到TorchScript,允许更复杂的模型和控制流。torch.jit.save(traced_model, 'model_path.pt')。torch.jit.load('model_path.pt')。注意,加载不需要原始的模型类定义。Tracing方法:
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
torch.jit.save(traced_model, 'traced_model.pt')
Scripting方法:
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'scripted_model.pt')
加载模型:
loaded_model = torch.jit.load('model_path.pt')
import torch
import torch.nn as nn
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 10)
def forward(self, x):
return self.fc(x)
model = SimpleModel()
# Tracing
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
torch.jit.save(traced_model, 'traced_simple_model.pt')
# Scripting
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'scripted_simple_model.pt')
# 加载模型
loaded_model = torch.jit.load('traced_simple_model.pt')