点赞收藏关注!
如需转载,请注明出处!
torch 的模型搭建,做一下简要的介绍
神经网络由对数据进行操作的层/模块(layers/modules)组成。
torch.nn提供构建网络的所有blocks,在PyTorch中的每个modules都继承了nn.Module,可以构建各种复杂的网络结构。通过nn.Module定义神经网络,使用init初始化,对数据的所有操作都在forward()中实现
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
nn.ReLU()
)
#前向传播
def forward(self, x)
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
##使用示例
#检测是否有GPU可用,若有可以在GPU上训练模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))
model = NeuralNetwork().to(device)
print(model)
X = torch.rand(1, 28, 28, device=device)
logits = model(X)
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
为了方便理解,整理了代码中一些函数的意义
如有帮助点赞收藏关注!
如需转载,请注明出处!