• 机器学习基础-手写数字识别


    1. 手写数字识别,计算机视觉领域的Hello World
    2. 利用MNIST数据集,55000训练集,5000验证集。
    3. Pytorch实现神经网络手写数字识别
    4. 感知机与神经元、权重和偏置、神经网络、输入层、隐藏层、输出层
    5. mac gpu的使用
    6. 本节就是对Pytorch可以做的事情有个直观的理解,先理解表面,把大概知识打通,然后再研究细节的东西
    import torch
    import torch.nn as nn
    import torchvision
    import torchvision.transforms as transforms
    import torch.optim as optim
    
    • 1
    • 2
    • 3
    • 4
    • 5
    # Check that MPS is available
    if not torch.backends.mps.is_available():
        if not torch.backends.mps.is_built():
            print("MPS not available because the current PyTorch install was not "
                  "built with MPS enabled.")
        else:
            print("MPS not available because the current MacOS version is not 12.3+ "
                  "and/or you do not have an MPS-enabled device on this machine.")
    else:
        device = torch.device("mps")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # 28*28 = 784为输入,100为输出
            self.fcl = nn.Linear(784,100)
            self.fc2 = nn.Linear(100,10)
            
        def forward(self,x):
            x = torch.flatten(x,start_dim = 1)
            x = torch.relu(self.fcl(x))
            x = self.fc2(x)
            return x
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    # 当前模型对数据集学几次
    max_epochs = 5
    # 每次训练模型对多少张图片进行训练
    batch_size = 16
    
    # data
    # ToTensor 把当前数据类型转换为 Tensor
    # Compose是组合多个转换操作的类
    transform = transforms.Compose([transforms.ToTensor()])
    
    # 55000
    trainset = torchvision.datasets.MNIST(root='./data',train=True,download=True,transform=transform)
    train_loader = torch.utils.data.DataLoader(trainset,batch_size=batch_size,shuffle=True)
    testset = torchvision.datasets.MNIST(root='./data',train=False,download=True,transform=transform)
    test_loader = torch.utils.data.DataLoader(testset,batch_size=batch_size,shuffle=True)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    # net init
    net = Net()
    net.to(device)
    
    # nn.MSE
    loss = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),lr=0.0001)
    
    def train():
        acc_num=0
        for epoch in range(max_epochs):
            for i,(data,label) in enumerate(train_loader):
                data = data.to(device)
                label = label.to(device)
                optimizer.zero_grad()
                output = net(data)
                Loss = loss(output,label)
                Loss.backward()
                optimizer.step()
                
                pred_class = torch.max(output,dim=1)[1]
                acc_num += torch.eq(pred_class,label.to(device)).sum().item()
                train_acc = acc_num / len(trainset)
            net.eval()
            acc_num = 0.0
            best_acc = 0
            with torch.no_grad():
                for val_data in test_loader:
                    val_image,val_label = val_data
                    output = net(val_image.to(device))
                    predict_y = torch.max(output , dim=1)[1]
                    acc_num += torch.eq(predict_y,val_label.to(device)).sum().item()
                val_acc = acc_num/len(testset)
                print(train_acc,val_acc)
                if val_acc > best_acc:
                    torch.save(net.state_dict(),'./minst.pth')
                    best_acc = val_acc
                acc_num = 0
                train_acc = 0
                test_acc = 0
            print('done')
    
    train()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    0.1348 0.3007
    done
    0.4361 0.5548
    done
    0.5870666666666666 0.6335
    done
    0.6435333333333333 0.672
    done
    0.67915 0.7011
    done
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
  • 相关阅读:
    微信小程序rich-text 文本首行缩进和图片居中和富文本rich-text 解析多个空格不成功 &nbsp
    1.mysql--常用sql(2)
    c#语法详解
    EtherCAT从站EEPROM分类附加信息详解:TXPDO(输出过程数据对象)
    Docker入门Dockerfile详解及镜像创建
    GitHub上整理的一些实用的工具
    线性代数应用基础补充2
    Kamailio Debian安装
    荣耀手机如何开启地震预警功能
    Kafka三种认证模式,Kafka 安全认证及权限控制详细配置与搭建
  • 原文地址:https://blog.csdn.net/qq_61735602/article/details/133637393