• PyTorch-CNN-CIFAR10


    model.py

    import torch.nn as nn
    import torch.nn.functional as F
    import torch.nn.init as init
    class CNN_cls(nn.Module):
        """Small convolutional classifier for 32x32 images (e.g. CIFAR-10).

        Three conv + ReLU stages, the first two followed by 2x2 max pooling,
        turn an (N, in_dim, 32, 32) batch into (N, 128, 8, 8) features; three
        fully connected layers then map those to (N, num_classes) logits.

        Args:
            in_dim: number of input channels (3 for RGB images).
            num_classes: number of output logits. Defaults to 10.
            kernel_size: convolution kernel size. Defaults to 1, which
                reproduces the original architecture exactly. Use an odd
                value (e.g. 3): padding is kernel_size // 2, which only
                preserves the spatial size for odd kernels, and lin1 expects
                exactly 128 * 8 * 8 features.
        """

        def __init__(self, in_dim, num_classes=10, kernel_size=1):
            super().__init__()
            # NOTE(review): 1x1 kernels only mix channels pixel by pixel, so
            # no spatial features are learned before pooling; confirm the 1x1
            # default is intentional (kernel_size=3 is the common choice).
            padding = kernel_size // 2
            self.conv1 = nn.Conv2d(in_dim, 32, kernel_size, 1, padding)
            self.pool1 = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(32, 64, kernel_size, 1, padding)
            self.pool2 = nn.MaxPool2d(2, 2)
            self.conv3 = nn.Conv2d(64, 128, kernel_size, 1, padding)
            self.lin1 = nn.Linear(128 * 8 * 8, 512)
            self.lin2 = nn.Linear(512, 64)
            self.lin3 = nn.Linear(64, num_classes)
            self.relu = nn.ReLU()

        def forward(self, x):
            """Return (N, num_classes) logits for an (N, in_dim, 32, 32) batch."""
            x = self.pool1(self.relu(self.conv1(x)))
            x = self.pool2(self.relu(self.conv2(x)))
            x = self.relu(self.conv3(x))
            # Flatten per sample, keeping the batch dimension intact.
            # view(-1, 128*8*8) would silently fold the extra pixels of a
            # larger image into the batch dimension; flatten(1) instead makes
            # lin1 raise a shape error for wrongly sized input.
            x = x.flatten(1)
            x = self.relu(self.lin1(x))
            x = self.relu(self.lin2(x))
            return self.lin3(x)
    

    main.py

    import torch
    import torch.nn as nn
    import torchvision
    from torch.utils.data import DataLoader
    import torch.optim as optim
    from model import CNN_cls
    
    
    # Train CNN_cls on CIFAR-10 with SGD, then report loss/accuracy on the
    # test split.
    seed = 42
    torch.manual_seed(seed)
    batch_size_train = 64
    batch_size_test = 64
    epochs = 10
    learning_rate = 0.01
    momentum = 0.5
    net = CNN_cls(3)

    # ToTensor() yields pixels in [0, 1]; Normalize with mean=std=0.5 maps
    # them to [-1, 1]. One transform is shared by both splits.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])

    train_loader = DataLoader(
        torchvision.datasets.CIFAR10('./data/', train=True, download=True,
                                     transform=transform),
        batch_size=batch_size_train, shuffle=True)
    # Evaluation order does not affect the metrics, so don't shuffle.
    test_loader = DataLoader(
        torchvision.datasets.CIFAR10('./data/', train=False, download=True,
                                     transform=transform),
        batch_size=batch_size_test, shuffle=False)

    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
    criterion = nn.CrossEntropyLoss()

    print("****************Begin Training****************")
    for epoch in range(epochs):
        net.train()
        run_loss = 0.0
        correct_num = 0
        for data, target in train_loader:
            out = net(data)
            optimizer.zero_grad()
            loss = criterion(out, target)
            loss.backward()
            optimizer.step()
            # Accumulate plain Python numbers: adding the loss tensor itself
            # would chain every batch's autograd history into run_loss.
            run_loss += loss.item()
            correct_num += (out.argmax(dim=1) == target).sum().item()
        # Divide by the real sample count: the last batch may be partial, so
        # len(loader) * batch_size can overcount the samples seen.
        print('epoch', epoch,
              'loss {:.2f}'.format(run_loss / len(train_loader)),
              'accuracy {:.2f}'.format(correct_num / len(train_loader.dataset)))

    print("****************Begin Testing****************")
    net.eval()
    test_loss = 0.0
    test_correct_num = 0
    # Evaluation needs no gradients; no_grad() skips building autograd graphs.
    with torch.no_grad():
        for data, target in test_loader:
            out = net(data)
            test_loss += criterion(out, target).item()
            test_correct_num += (out.argmax(dim=1) == target).sum().item()
    print('loss {:.2f}'.format(test_loss / len(test_loader)),
          'accuracy {:.2f}'.format(test_correct_num / len(test_loader.dataset)))
    

    运行图

    （运行结果截图，原图未保留）

  • 相关阅读:
    解决跨域问题
    金九银十面试季在即,Android程序猿如何斩获offer?
    线性回归介绍以及实现
    【Java面试】谈谈你对HashMap的理解(Map接口)
    @Autowired和@Resource的区别
    【踩坑】.NET异步方法不标记async,Task<int> 返回值 return default问题
    包含日志文件
    AI智能安防监控视频播放卡顿的原因排查与分析
    SpringBoot中使用Thymeleaf
    linux部署jar包脚本和注册开机启动
  • 原文地址:https://blog.csdn.net/m0_59741202/article/details/132999538