import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
def _select_device() -> torch.device:
    """Return the MPS device when it is usable, otherwise fall back to the CPU.

    When MPS is unavailable, print why: either this PyTorch build was compiled
    without MPS, or the OS version / hardware cannot provide an MPS device.
    Returning the CPU device (instead of leaving ``device`` undefined or
    pointing at an unusable "mps" device) keeps later ``.to(device)`` calls
    working on every machine.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    return torch.device("cpu")


# Device used by the rest of the script: "mps" when available, else "cpu".
device = _select_device()
class Net(nn.Module):
    """Two-layer fully connected network.

    Flattens every dimension after the batch dimension, applies a linear
    layer followed by ReLU, then a second linear layer that produces raw
    output scores (no softmax is applied here).

    Args:
        in_features: Size of the flattened input. Defaults to 784 (= 28 * 28).
        hidden: Width of the hidden layer. Defaults to 100.
        num_classes: Number of output scores. Defaults to 10.
    """

    def __init__(self, in_features: int = 784, hidden: int = 100,
                 num_classes: int = 10):
        super().__init__()
        # By default: 28*28 = 784 input features, 100 hidden (output) features.
        # NOTE(review): `fcl` looks like a typo for `fc1`; the name is kept so
        # existing state_dict checkpoints ("fcl.weight", "fcl.bias") still load.
        self.fcl = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw scores of shape (batch, num_classes).

        Dim 0 of ``x`` is treated as the batch dimension; the product of the
        remaining dims must equal ``in_features``.
        """
        x = torch.flatten(x, start_dim=1)  # keep batch dim, flatten the rest
        x = torch.relu(self.fcl(x))
        return self.fc2(x)