• Pytorch-MLP-CIFAR10


    model.py

    import torch.nn as nn
    import torch.nn.functional as F
    import torch.nn.init as init
    
    class MLP_cls(nn.Module):
        """Three-layer fully connected classifier over flattened inputs.

        The default ``in_dim`` of 3*32*32 corresponds to one 3x32x32 image;
        ``forward`` flattens every sample to ``in_dim`` features before the
        linear layers 128 -> 64 -> ``num_classes``.

        Args:
            in_dim: Number of input features per sample after flattening.
            num_classes: Number of output logits (default 10).
        """

        def __init__(self, in_dim=3*32*32, num_classes=10):
            super(MLP_cls, self).__init__()
            # Remembered so forward() flattens to the configured size instead
            # of a hard-coded 3*32*32.
            self.in_dim = in_dim
            self.lin1 = nn.Linear(in_dim, 128)
            self.lin2 = nn.Linear(128, 64)
            self.lin3 = nn.Linear(64, num_classes)
            self.relu = nn.ReLU()
            # Xavier-uniform init for the weight matrices (same order as the
            # layers are applied); biases keep nn.Linear's default init.
            for layer in (self.lin1, self.lin2, self.lin3):
                init.xavier_uniform_(layer.weight)

        def forward(self, x):
            """Return raw (unnormalised) class logits of shape (N, num_classes).

            ``x`` may be any tensor whose element count is a multiple of
            ``in_dim`` (e.g. an (N, 3, 32, 32) image batch); it is reshaped to
            (N, in_dim).
            """
            # reshape (unlike view) also accepts non-contiguous inputs.
            x = x.reshape(-1, self.in_dim)
            x = self.relu(self.lin1(x))
            x = self.relu(self.lin2(x))
            # No activation on the output layer: nn.CrossEntropyLoss expects
            # raw logits, and a ReLU here would clamp every negative logit to 0.
            return self.lin3(x)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24

    main.py

    import torch
    import torch.nn as nn
    import torchvision
    from torch.utils.data import DataLoader
    import torch.optim as optim
    from model import MLP_cls,CNN_cls
    
    
    # NOTE(review): the import line above pulls CNN_cls from model, but the
    # model.py shown in this post defines only MLP_cls -- confirm CNN_cls
    # exists, or drop it from that import.
    seed = 42
    torch.manual_seed(seed)
    batch_size_train = 64
    batch_size_test  = 64
    epochs = 10
    learning_rate = 0.01
    momentum = 0.5
    net = MLP_cls()

    # Shared preprocessing for both splits: ToTensor() scales pixels to [0, 1],
    # then Normalize maps each of the three channels with (x - 0.5) / 0.5.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10('./data/', train=True, download=True,
                                     transform=transform),
        batch_size=batch_size_train, shuffle=True)
    # Evaluation order does not affect the metrics, so keep it fixed.
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10('./data/', train=False, download=True,
                                     transform=transform),
        batch_size=batch_size_test, shuffle=False)

    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
    criterion = nn.CrossEntropyLoss()

    print("****************Begin Training****************")
    net.train()
    # True sample count: the final batch may be smaller than batch_size_train,
    # so len(loader) * batch_size would overcount.
    num_train = len(train_loader.dataset)
    for epoch in range(epochs):
        run_loss = 0.0
        correct_num = 0
        for data, target in train_loader:
            out = net(data)
            _, pred = torch.max(out, dim=1)
            optimizer.zero_grad()
            loss = criterion(out, target)
            loss.backward()
            optimizer.step()
            # .item() converts to a Python number so the running totals do not
            # keep each batch's autograd graph alive.
            run_loss += loss.item()
            correct_num += (pred == target).sum().item()
        print('epoch', epoch,
              'loss {:.2f}'.format(run_loss / len(train_loader)),
              'accuracy {:.2f}'.format(correct_num / num_train))



    print("****************Begin Testing****************")
    net.eval()
    test_loss = 0.0
    test_correct_num = 0
    # Evaluation needs no gradients; no_grad avoids building autograd graphs.
    with torch.no_grad():
        for data, target in test_loader:
            out = net(data)
            _, pred = torch.max(out, dim=1)
            test_loss += criterion(out, target).item()
            test_correct_num += (pred == target).sum().item()
    print('loss {:.2f}'.format(test_loss / len(test_loader)),
          'accuracy {:.2f}'.format(test_correct_num / len(test_loader.dataset)))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65

    参数设置

    './data/' #数据保存路径
    seed = 42 #随机种子
    batch_size_train = 64
    batch_size_test  = 64
    epochs = 10
    
    optim --> SGD
    learning_rate = 0.01
    momentum = 0.5
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    注意事项

    CIFAR10 是彩色图像，单张图像的尺寸为 3*32*32（通道×高×宽），因此在 forward 中需要用 view 把每张图像展平成长度为 3*32*32 的向量，再送入全连接层。

    运行图

    （原文此处为运行结果截图，图片未能保留。）

  • 相关阅读:
    今夕摄影影楼管理系统
    JAVA基础12:字符串(上)
    Python读取postgresql数据库
    Python——异常
    Java面向对象设计 - Java泛型约束
    《MongoDB入门教程》第19篇 文档更新之$rename操作符
    JAVA坝上长尾鸡养殖管理系统计算机毕业设计Mybatis+系统+数据库+调试部署
    Nginx重写功能
    CAT-Seg: Cost Aggregation for Open-Vocabulary Semantic Segmentation
    ceph 常用命令
  • 原文地址:https://blog.csdn.net/m0_59741202/article/details/132998788