• pytorch学习3(pytorch手写数字识别练习)


    网络模型

    设置三层网络,一般最后一层激活函数不选择relu
    在这里插入图片描述

    任务步骤

    手写数字识别任务共有四个步骤:
    1、数据加载--Load Data
    2、构建网络--Build Model
    3、训练--Train
    4、测试--Test
    
    • 1
    • 2
    • 3
    • 4
    • 5

    实战

    1、导入各种需要的包

    import torch
    from torch import nn
    from torch.nn import functional as F
    from torch import optim
    
    import torchvision
    
    from matplotlib import pyplot as plt
    
    from minist_utils import plot_image, plot_curve, one_hot ##自写文件
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    minist_utils:
    在这里插入图片描述
    在这里插入图片描述在这里插入图片描述

    2、加载数据

    batch_size = 512
    
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                           (0.1307,), (0.3081, ))
                                   ])),
        batch_size=batch_size, shuffle=True)
    
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('mnist_data', train=False, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                           (0.1307,), (0.3081, ))
                                   ])),
        batch_size=batch_size, shuffle=False
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    取一些样本看数据的shape以及图片内容

    x, y = next(iter(train_loader))
    print(x.shape, y.shape, x.min(), x.max())
    plot_image(x, y, 'image sample')
    
    • 1
    • 2
    • 3

    在这里插入图片描述在这里插入图片描述

    注:经过load加载处理后的数据集包含x(图像信息)和y(标签信息)
    next(iter())的用法是取一组样本,重复运行可以依次顺序取样,直到样本被取完
    可在csdn自行搜索学习了解
    
    • 1
    • 2
    • 3

    3、网络构建

    按之前设想的三层线性模型嵌套的思想搭建模型,为了模型简单,第三层不加激活函数。

    class Net(nn.Module):
    
        def __init__(self):
            super(Net, self).__init__()
    
            # xw+b
            self.fc1 = nn.Linear(28*28, 256) #输入特征数,输出特征数
            self.fc2 = nn.Linear(256, 64)  #256,64是根据经验判断
            self.fc3 = nn.Linear(64, 10)  #最开始的28*28和输出的10是一定的
    
        def forward(self, x):
            # x: [b, 1, 28, 28]
            # h1 = relu(xw1 + b1)
            x = F.relu(self.fc1(x)) #输入x后第一次线性模型得到H1作第二层输入
            # h2 = relu(h1w2 + b2)
            x = F.relu(self.fc2(x)) #输入H1得到H2作第三层输入
            # h3 = h2w3 + b3
            x = self.fc3(x)	#输入H3得到最终结果,维度为10
    
            return x
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    4、模型训练

    net = Net()
    
    # [w1, b1, w2, b2, w3, b3]
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    
    train_loss = []
    
    for epoch in range(3):
    
        for batch_idx, (x, y) in enumerate(train_loader):
    
            # x: [b, 1, 28, 28], y: [512]
            # [b, 1, 28, 28] => [b, feature] 全连接层只能接受这样的数据
            x = x.view(x.size(0), 28*28)
            # => [b, 10]
            out = net(x)
            # [b, 10]
            y_onehot = one_hot(y)
            # loss = mse(out, y_onehot)
            loss = F.mse_loss(out, y_onehot)
    
            optimizer.zero_grad()
            loss.backward() # 梯度计算过程
            # w` = w - lr * grad
            optimizer.step() # 优化更新w,b
    
            train_loss.append(loss.item())
    
            if batch_idx % 10 == 0:
                print(epoch, batch_idx, loss.item())
    
    plot_curve(train_loss)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32

    在这里插入图片描述

    5、测试

    1、计算准确率acc

    total_correct = 0
    for x, y in test_loader:
        x = x.view(x.size(0), 28*28)
        out = net(x)
        # out: [b, 10] => pred: [b]
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct
    
    total_num = len(test_loader.dataset)
    acc = total_correct / total_num
    print(("acc:", acc))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    在这里插入图片描述
    2、展示部分测试样本原图以及预测标签结果

    x, y =next(iter(test_loader))
    out = net(x.view(x.size(0), 28*28))
    pred = out.argmax(dim=1)
    plot_image(x, pred, 'test')
    
    • 1
    • 2
    • 3
    • 4

    在这里插入图片描述

  • 相关阅读:
    【路径规划】基于遗传算法求解立体仓库出入库路径优化问题含Matlab代码
    SpringBoot 日志
    [2023年度回顾总结]凡是过往,皆为序章
    软件设计师 程序设计语言
    大道至简的架构设计思想之:封装(C系架构设计法,sishuok)
    spring5.0 源码解析 createBeanInstance 09
    简单工厂模式
    kafka基本概念
    JS 模块化- 05 ES Module & 4 大规范总结
    【云原生】Kubernetes临时容器
  • 原文地址:https://blog.csdn.net/qq_52015311/article/details/133087431