• 【pytorch】LeNet-5 手写数字识别MNIST


    1、LeNet5 模型

    在这里插入图片描述
    在这里插入图片描述

    模型特点:前两个卷积模块各包含3个部分,依次为:卷积、非线性激活函数(Tanh)、平均池化(Average Pooling);第三个卷积层(C5)的 5×5 卷积核恰好覆盖 5×5 的特征图,输出为 (120, 1, 1),相当于一个输出 120 维的全连接层。

    class LeNet5(nn.Module):
        """LeNet-5 convolutional network, assembled from ``nn.Sequential`` stages.

        The shape comments assume a (batch, in_channel, 28, 28) input such as MNIST;
        padding=2 in the first convolution lets the 28x28 image play the role of the
        32x32 input of the original LeNet-5.

        Args:
            in_channel: number of channels of the input images (1 for grayscale MNIST).
            output: number of classes, i.e. the width of the returned score vector.
        """
        def __init__(self, in_channel, output):
            super(LeNet5, self).__init__()
            # C1 + S2: 5x5 convolution -> tanh -> 2x2 average pooling.
            self.layer1 = nn.Sequential(nn.Conv2d(in_channels=in_channel, out_channels=6, kernel_size=5, stride=1, padding=2),   # (6, 28, 28)
                                        nn.Tanh(),
                                        nn.AvgPool2d(kernel_size=2, stride=2, padding=0))   # (6, 14, 14)

            # C3 + S4: 5x5 convolution without padding -> tanh -> 2x2 average pooling.
            self.layer2 = nn.Sequential(nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),  # (16, 10, 10)
                                        nn.Tanh(),
                                        nn.AvgPool2d(kernel_size=2, stride=2, padding=0))   # (16, 5, 5)

            # C5: a 5x5 kernel over the 5x5 map yields a 1x1 output, so this acts as a
            # fully connected 400 -> 120 layer. Its tanh is applied in forward() so the
            # parameter names stay "layer3.weight" / "layer3.bias".
            self.layer3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)  # (120, 1, 1)

            # F6 + output: 120 -> 84 -> tanh -> `output` raw scores (no softmax applied).
            self.layer4 = nn.Sequential(nn.Linear(in_features=120, out_features=84),
                                        nn.Tanh(),
                                        nn.Linear(in_features=84, out_features=output))

        def forward(self, x):
            """Return unnormalized class scores of shape (batch, output) for the image batch ``x``."""
            x = self.layer1(x)
            x = self.layer2(x)
            # Without this non-linearity, layer3 and the first Linear of layer4 would
            # compose into a single linear map; LeNet-5 squashes C5's outputs as well.
            x = torch.tanh(self.layer3(x))
            x = torch.flatten(input=x, start_dim=1)
            x = self.layer4(x)
            return x
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    2、训练模型

    import numpy as np
    import torch
    import torch.nn as nn
    from torchvision.datasets import mnist
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    import torch.optim as optim
    import matplotlib.pyplot as plt
    
    # Run on the first CUDA GPU when one is available, otherwise on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")


    train_batch_size = 12
    test_batch_size = 48

    # ToTensor converts each image to a float tensor; Normalize(mean=0.5, std=0.5)
    # then computes (x - 0.5) / 0.5 per channel.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # Download (if missing) and load MNIST. The training split is created first with
    # download=True; the test split reuses the same "./mnist_data" directory.
    train_set = mnist.MNIST("./mnist_data", train=True, download=True, transform=transform)
    test_set = mnist.MNIST("./mnist_data", train=False, transform=transform)

    # Shuffle only the training data; evaluation order does not matter.
    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=test_batch_size, shuffle=False)

    # Optional: preview a few training images (uncomment to run).
    # example_data, example_label = next(iter(train_loader))
    # print(type(example_data))   # <class 'torch.Tensor'>
    # print(example_data.shape)   # torch.Size([12, 1, 28, 28]) with train_batch_size = 12
    #
    # for i in range(6):
    #     plt.subplot(2, 3, i + 1)
    #     plt.tight_layout()
    #     plt.imshow(example_data[i][0], cmap='gray')
    #     plt.title("Ground Truth: {}".format(example_label[i]))
    #     plt.xticks([])
    #     plt.yticks([])
    # plt.show()
    
    
    class LeNet5(nn.Module):
        """LeNet-5 convolutional network, assembled from ``nn.Sequential`` stages.

        The shape comments assume a (batch, in_channel, 28, 28) input such as MNIST;
        padding=2 in the first convolution lets the 28x28 image play the role of the
        32x32 input of the original LeNet-5.

        Args:
            in_channel: number of channels of the input images (1 for grayscale MNIST).
            output: number of classes, i.e. the width of the returned score vector.
        """
        def __init__(self, in_channel, output):
            super(LeNet5, self).__init__()
            # C1 + S2: 5x5 convolution -> tanh -> 2x2 average pooling.
            self.layer1 = nn.Sequential(nn.Conv2d(in_channels=in_channel, out_channels=6, kernel_size=5, stride=1, padding=2),   # (6, 28, 28)
                                        nn.Tanh(),
                                        nn.AvgPool2d(kernel_size=2, stride=2, padding=0))   # (6, 14, 14)

            # C3 + S4: 5x5 convolution without padding -> tanh -> 2x2 average pooling.
            self.layer2 = nn.Sequential(nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),  # (16, 10, 10)
                                        nn.Tanh(),
                                        nn.AvgPool2d(kernel_size=2, stride=2, padding=0))   # (16, 5, 5)

            # C5: a 5x5 kernel over the 5x5 map yields a 1x1 output, so this acts as a
            # fully connected 400 -> 120 layer. Its tanh is applied in forward() so the
            # parameter names stay "layer3.weight" / "layer3.bias".
            self.layer3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)  # (120, 1, 1)

            # F6 + output: 120 -> 84 -> tanh -> `output` raw scores (no softmax applied;
            # the nn.CrossEntropyLoss used for training works on raw scores).
            self.layer4 = nn.Sequential(nn.Linear(in_features=120, out_features=84),
                                        nn.Tanh(),
                                        nn.Linear(in_features=84, out_features=output))

        def forward(self, x):
            """Return unnormalized class scores of shape (batch, output) for the image batch ``x``."""
            x = self.layer1(x)
            x = self.layer2(x)
            # Without this non-linearity, layer3 and the first Linear of layer4 would
            # compose into a single linear map; LeNet-5 squashes C5's outputs as well.
            x = torch.tanh(self.layer3(x))
            x = torch.flatten(input=x, start_dim=1)
            x = self.layer4(x)
            return x
    
    def _train_one_epoch(model, loader, criterion, optimizer, device):
        """Run one pass of mini-batch optimization over ``loader`` with ``model`` in training mode."""
        model.train()
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            loss = criterion(model(imgs), labels)

            # back propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


    def _evaluate(model, loader, criterion, device):
        """Return ``(mean loss per sample, accuracy)`` of ``model`` over every sample in ``loader``.

        Both metrics are weighted by sample count, so a final batch smaller than the
        others (dataset size not a multiple of the batch size) does not skew them.
        Gradients are disabled because evaluation never calls ``backward``.

        Raises:
            ValueError: if ``loader`` yields no samples.
        """
        model.eval()
        total_loss = 0.0
        total_correct = 0
        total_seen = 0
        with torch.no_grad():
            for imgs, labels in loader:
                imgs, labels = imgs.to(device), labels.to(device)
                predict = model(imgs)
                batch = labels.size(0)
                # criterion returns the batch mean; scale it back to a batch sum.
                total_loss += criterion(predict, labels).item() * batch
                total_correct += (predict.argmax(dim=1) == labels).sum().item()
                total_seen += batch
        if total_seen == 0:
            raise ValueError("evaluation loader yielded no samples")
        return total_loss / total_seen, total_correct / total_seen


    model = LeNet5(1, 10)
    model.to(device)

    lr = 0.01
    num_epoches = 20
    momentum = 0.8

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    # Divide the learning rate by 10 every 5 epochs: epochs 0-4 use `lr`, 5-9 use lr/10, ...
    # Stepping after each epoch ensures the first epochs actually train with `lr`.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)


    eval_losses = []
    eval_acces = []

    for epoch in range(num_epoches):
        _train_one_epoch(model, train_loader, criterion, optimizer, device)
        scheduler.step()

        eval_loss, eval_acc = _evaluate(model, test_loader, criterion, device)
        eval_losses.append(eval_loss)
        eval_acces.append(eval_acc)

        print('epoch: {}'.format(epoch))
        print('loss: {}'.format(eval_loss))
        print('accurate rate: {}'.format(eval_acc))
        print('\n')

    plt.title('evaluation loss')
    plt.plot(np.arange(len(eval_losses)), eval_losses)
    plt.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124

    3、输出

    请添加图片描述
    epoch: 0
    loss: 0.20932436712157498
    accurate rate: 0.9417862838915463

    epoch: 1
    loss: 0.1124003769263946
    accurate rate: 0.9681020733652314

    epoch: 2
    loss: 0.0809573416740736
    accurate rate: 0.9753787878787872

    epoch: 3
    loss: 0.07089491755452061
    accurate rate: 0.9779704944178623

    epoch: 4
    loss: 0.05831286043338656
    accurate rate: 0.9821570972886757

    epoch: 5
    loss: 0.05560500273351785
    accurate rate: 0.9828548644338115

    epoch: 6
    loss: 0.0542455422597309
    accurate rate: 0.9835526315789472

    epoch: 7
    loss: 0.05367041283908732
    accurate rate: 0.9838516746411479

    epoch: 8
    loss: 0.05298826666370605
    accurate rate: 0.9838516746411481

    epoch: 9
    loss: 0.05252152112530963
    accurate rate: 0.9836523125996807

    epoch: 10
    loss: 0.05247020455629846
    accurate rate: 0.9836523125996808

    epoch: 11
    loss: 0.05242454297127621
    accurate rate: 0.9837519936204145

    epoch: 12
    loss: 0.05237526405083559
    accurate rate: 0.9838516746411481

    epoch: 13
    loss: 0.05233189105290171
    accurate rate: 0.9839513556618819

    epoch: 14
    loss: 0.05222674906053291
    accurate rate: 0.9837519936204145

    epoch: 15
    loss: 0.052228276117072044
    accurate rate: 0.9837519936204145

    epoch: 16
    loss: 0.05222897543727852
    accurate rate: 0.9837519936204145

    epoch: 17
    loss: 0.05222897782574216
    accurate rate: 0.9838516746411481

    epoch: 18
    loss: 0.05222847037079731
    accurate rate: 0.9838516746411481

    epoch: 19
    loss: 0.05222745426054866
    accurate rate: 0.9838516746411481

    请添加图片描述

  • 相关阅读:
    智慧巡查平台(Ionic/Vite/Vue3 移动端) 问题记录
    国内原汁原味的免费sd训练工具--哩布哩布AI
    0基础学习Elasticsearch-使用Java操作ES
    项目实战:Qt监测操作系统物理网卡通断v1.1.0(支持windows、linux、国产麒麟系统)
    k8s上线Java项目文件导出异常总结
    全球名校AI课程库(34)| 辛辛那提大学 · 微积分Ⅰ课程『MATH100 · Calculus I』
    工控安全PLC固件逆向三
    申请专利的好处!这份清单告诉你,为什么要申请专利?
    【Jlink & C#】通过C#实现Jlink RTT上位机的功能
    免编程经验,搭建宠物店小程序轻松实现
  • 原文地址:https://blog.csdn.net/weixin_37804469/article/details/126486375