本篇介绍模型模型的参数,模型推理和使用,保存加载。
在本单元中,我们将了解如何加载模型及其持久参数状态和推理模型预测。为了加载模型,我们将定义模型类,其中包含用于训练模型的神经网络的状态和参数。
%matplotlib inline import torch import onnxruntime from torch import nn import torch.onnx as onnx import torchvision.models as models from torchvision import datasets from torchvision.transforms import ToTensor
class NeuralNetwork(nn.Module): def __init__(self): super(NeuralNetwork, self).__init__() self.flatten = nn.Flatten() self.linear_relu_stack = nn.Sequential( nn.Linear(28*28, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10), nn.ReLU() ) def forward(self, x): x = self.flatten(x) l