模型保存有两种形式,一种是保存模型的 state_dict()
,只是保存模型的参数。那么加载时需要先创建一个模型的实例 model,之后通过 torch.load()
将保存的模型参数加载进来,得到 dict
,再通过 model.load_state_dict(dict)
将模型的参数更新。
另一种是将整个模型保存下来,之后加载的时候只需要通过 torch.load()
将模型加载,即可返回一个加载好的模型。
具体可参考:PyTorch 模型的保存与加载。
torch.save 函数有两种保存方式:一种是保存整个模型,此时模型的 type 应该为继承自 nn.Module 的类;另一种是仅保存模型的参数,此时模型的type应该为有序字典即类 OrderedDict。
我在通过两个神经元的极简模型,清晰透视 Pytorch 工作原理中构建了两个神经元的网络。以这个为例,我在模型训练完成后保存成 pt 文件,然后再加载测试。代码如下:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(1, 2)
self.fc2 = nn.Linear(2, 1)
def forward(self, x):
x = torch.sigmoid(self.fc1(x))
x = self.fc2(x)
return x
net = Net()
x = torch.linspace(0, 1, 10).reshape(10, 1)
y = x*x - 0.5*x + 1.5625
import torch.optim as optim
optimizer = optim.SGD(net.parameters(), lr = 0.001)
for n in range(0, 100000):
optimizer.zero_grad()
loss = sum(abs(net(x) - y))
loss.backward()
optimizer.step()
if n % 1000 == 0:
print(n, loss)
print('Finished Training!')
# 保存模型
torch.save(net, "/home/yeping/mynet.pt")
# 加载模型
net2 = torch.load( "/home/yeping/mynet.pt")
# 测试加载的模型
import matplotlib.pyplot as plt
plt.plot(x, y, "k*")
z=[]
x = torch.linspace(0,1,100).reshape(100,1)
for xx in x:
zz = net2(xx)
z.append(zz)
plt.plot(x, z, "b-")
最后显示图片:
实际上,仅写测试代码也是可以。这说明 pt 文件包含了模型的结构定义和权重信息。
import torch
import torch.nn as nn
# 加载模型
net2 = torch.load( "/home/yeping/mynet.pt")
# 测试加载的模型
import matplotlib.pyplot as plt
x = torch.linspace(0, 1, 10).reshape(10, 1)
y = x*x - 0.5*x + 1.5625
plt.plot(x, y, "k*")
z=[]
x = torch.linspace(0,1,100).reshape(100,1)
for xx in x:
zz = net2(xx)
z.append(zz)
plt.plot(x, z, "b-")
运行后,结果与上面一样。
下面的例子中仅仅保存模型参数:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(1, 2)
self.fc2 = nn.Linear(2, 1)
def forward(self, x):
x = torch.sigmoid(self.fc1(x))
x = self.fc2(x)
return x
net = Net()
x = torch.linspace(0, 1, 10).reshape(10, 1)
y = x*x - 0.5*x + 1.5625
import torch.optim as optim
optimizer = optim.SGD(net.parameters(), lr = 0.001)
for n in range(0, 100000):
optimizer.zero_grad()
loss = sum(abs(net(x) - y))
loss.backward()
optimizer.step()
if n % 1000 == 0:
print(n, loss)
print('Finished Training!')
# 保存模型
torch.save(net.state_dict(), "/home/yeping/mynet-wts.pt")
# 加载模型
net2 = Net()
net2.load_state_dict(torch.load("/home/yeping/mynet-wts.pt"))
# 测试加载的模型
import matplotlib.pyplot as plt
plt.plot(x, y, "k*")
z=[]
x = torch.linspace(0,1,100).reshape(100,1)
for xx in x:
zz = net2(xx)
z.append(zz)
plt.plot(x, z, "b-")
查了看了一下 net.state_dict()
的内容,整理如下:
OrderedDict
(
[
(
'fc1.weight',
tensor
(
[
[-3.5805],
[-1.4185]
]
)
),
(
'fc1.bias',
tensor
(
[
3.4771,
-0.8718
]
)
),
(
'fc2.weight',
tensor
(
[
[
-1.7204,
1.7042
]
]
)
),
(
'fc2.bias',
tensor(
[2.7291]
)
)
]
)