加载和保存张量
# 1个张量
import torch
from torch import nn
from torch.nn import functional as F
x=torch.arange(4)
torch.save(x,'x-file')
x2=torch.load('x-file')
x2
tensor([0, 1, 2, 3])
# 张量list
y=torch.zeros(4)
torch.save([x,y],'x-files')
x2,y2=torch.load('x-files')
(x2,y2)
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
# 张量字典
mydict={'x':x,'y':y}
torch.save(mydict,'mydict')
mydict2=torch.load('mydict')
mydict2
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
加载和保存模型参数(存储权重参数即可)
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.hidden=nn.Linear(20,256)
self.output=nn.Linear(256,10)
def forward(self,x):
return self.output(F.relu(self.hidden(X)))
net=MLP()
X=torch.randn(size=(2,20))
Y=net(X)
# 存储权重参数
torch.save(net.state_dict(),'mlp.params')
# 加载权重参数
clone=MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
# 使用
Y_clone=clone(X)
Y_clone==Y
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
query
将类别变量转伪变量(one-hot编码特征值),内存炸掉了
net(X)会自动调用forward()方法
__call__()中调用了forward()方法。__call__(), __init__()这种格式的方法都是magic method,他们会被python自动调用kaiming初始化和xavier初始化