• Pytorch-CNN-Mnist


    model.py

    import torch.nn as nn
    class CNN_cls(nn.Module):
        """Small convolutional classifier for single-channel 28x28 images (MNIST).

        Layout: three Conv2d+ReLU stages (the first two followed by 2x2 max
        pooling), then three fully connected layers producing 10 raw logits.
        The two pooling layers shrink 28x28 inputs to 7x7, which is what
        ``lin1`` (``128*7*7`` input features) expects.

        Args:
            in_dim: Unused by the network; accepted only so that existing
                callers passing it keep working.
        """

        def __init__(self,in_dim=28*28):
            super(CNN_cls,self).__init__()
            # NOTE: every convolution uses kernel_size=1 and stride=1, so the
            # conv layers only mix channels per pixel; spatial reduction comes
            # from the pooling layers alone.
            self.conv1 = nn.Conv2d(1,32,1,1)  # 1 input channel: grayscale images
            self.pool1 = nn.MaxPool2d(2,2)
            self.conv2 = nn.Conv2d(32,64,1,1)
            self.pool2 = nn.MaxPool2d(2,2)
            self.conv3 = nn.Conv2d(64,128,1,1)
            self.lin1 = nn.Linear(128*7*7,512)
            self.lin2 = nn.Linear(512,64)
            self.lin3 = nn.Linear(64,10)
            self.relu = nn.ReLU()

        def forward(self,x):
            """Return class logits for a batch of images.

            Args:
                x: Image tensor whose trailing dimensions are (1, H, W) with
                    H and W such that two 2x2 poolings yield 7x7 (e.g. 28x28).

            Returns:
                Tensor of shape (N, 10) holding unnormalised class scores.

            Raises:
                ValueError: If the feature map before the linear layers is not
                    128x7x7, i.e. the input spatial size is unsupported.
            """
            x = self.pool1(self.relu(self.conv1(x)))
            x = self.pool2(self.relu(self.conv2(x)))
            x = self.relu(self.conv3(x))

            # Without this check a larger image would be silently reshaped into
            # extra "samples" (e.g. one 56x56 image -> 4 rows of logits).
            if tuple(x.shape[-3:]) != (128, 7, 7):
                raise ValueError(
                    "CNN_cls expects inputs that reduce to a 128x7x7 feature map "
                    "(e.g. 1x28x28 images); got feature map of shape {}".format(
                        tuple(x.shape)))

            # reshape (unlike view) also works when the activation is not
            # contiguous, e.g. with channels_last memory format.
            x = x.reshape(-1,128*7*7)
            x = self.relu(self.lin1(x))
            x = self.relu(self.lin2(x))
            return self.lin3(x)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    main.py

    import torch
    import torch.nn as nn
    import torchvision
    from torch.utils.data import DataLoader
    import torch.optim as optim
    from model import CNN_cls
    
    
    # Training / evaluation script for CNN_cls on MNIST.
    # Fix the RNG seed so weight initialisation and batch shuffling are repeatable.
    seed = 42
    torch.manual_seed(seed)
    batch_size_train = 64
    batch_size_test  = 64
    epochs = 10
    learning_rate = 0.01
    momentum = 0.5
    net = CNN_cls()

    # Same preprocessing for both splits: convert to tensor, then normalise
    # with mean 0.5 and std 0.5.
    mnist_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])

    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./data/', train=True, download=True,
                                   transform=mnist_transform),
        batch_size=batch_size_train, shuffle=True)
    # Evaluation metrics do not depend on sample order, so no shuffling here.
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./data/', train=False, download=True,
                                   transform=mnist_transform),
        batch_size=batch_size_test, shuffle=False)

    optimizer = optim.SGD(net.parameters(), lr=learning_rate,momentum=momentum)
    criterion = nn.CrossEntropyLoss()

    print("****************Begin Training****************")
    net.train()
    for epoch in range(epochs):
        # Plain Python numbers: accumulating the loss tensor itself would keep
        # autograd history attached to the running total.
        run_loss = 0.0
        correct_num = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            out = net(data)
            loss = criterion(out,target)
            loss.backward()
            optimizer.step()

            run_loss += loss.item()
            pred = out.argmax(dim=1)
            correct_num += (pred==target).sum().item()
        # Loss is the mean of per-batch losses; accuracy divides by the true
        # sample count because the final batch may be smaller than batch_size.
        print('epoch',epoch,'loss {:.2f}'.format(run_loss/len(train_loader)),'accuracy {:.2f}'.format(correct_num/len(train_loader.dataset)))



    print("****************Begin Testing****************")
    net.eval()
    test_loss = 0.0
    test_correct_num = 0
    # No gradients are needed for evaluation; skipping them saves memory/time.
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            out = net(data)
            pred = out.argmax(dim=1)
            test_loss += criterion(out,target).item()
            test_correct_num += (pred==target).sum().item()
    print('loss {:.2f}'.format(test_loss/len(test_loader)),'accuracy {:.2f}'.format(test_correct_num/len(test_loader.dataset)))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65

    网络设置

    网络结构的具体定义见 model.py 中的 CNN_cls 类。

    注意事项及改进

    1. 注意第一个卷积层的输入通道数是 1,因为 MNIST 是灰度图像。
    2. 可以考虑加入 GPU 支持(将模型和数据移动到 CUDA 设备上)以加快训练速度。
    
    • 1
    • 2

    运行截图

    在这里插入图片描述

  • 相关阅读:
    四. node小工具(nodemon/supervisor)
    AIGC是不是有点虎头蛇尾
    springboot 等待批量异步任务完成
    【腾讯云原生降本增效大讲堂】京东云原生大规模实践之路
    Java【多线程】Callable 是什么, 如何使用并理解 Cllable, 和 Runnable 有什么区别?
    _linux 进程间通信(命名管道)
    ffmpeg sdk 视频合成
    web大学生个人网站作业模板——上海旅游景点介绍网页代码 家乡旅游网页制作模板 大学生静态HTML网页源码...
    muduo库的安装
    Flink1.15源码阅读——执行图executiongraph
  • 原文地址:https://blog.csdn.net/m0_59741202/article/details/132954316