• PyTorch入门之【AlexNet】


    参考文献:https://www.bilibili.com/video/BV1DP411C7Bw/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620
    AlexNet 是一个经典的卷积神经网络模型,用于图像分类任务。

    大纲

    (原文此处为图片:项目文件结构示意图)
    各个文件的作用:

    • data:存放数据集的目录
    • dataloader.py:加载数据集并创建训练集/测试集的 DataLoader
    • model.py:定义 AlexNet 模型
    • train.py:训练模型
    • test.py:测试模型

    dataloader

    import torch
    import torchvision
    import torchvision.transforms as transforms
    
    import matplotlib.pyplot as plt
    import numpy as np
    
    
    # Preprocessing shared by both splits: upscale each image to 224 pixels,
    # convert it to a float tensor in [0, 1], then shift/scale every channel
    # with mean 0.5 and std 0.5 so values end up in [-1, 1].
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    # Number of images per mini-batch for both loaders.
    batch_size = 16

    # Training split: downloaded into ./data on first use, reshuffled each epoch.
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True)

    # Held-out split: same preprocessing, fixed order.
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False)

    # Human-readable class names, indexed by the integer label.
    classes = (
        'plane', 'car', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    )
    
    
    if __name__ == '__main__':
        # Visual sanity check: fetch one shuffled training batch, print its class
        # names and display the images as a single grid.
        images, labels = next(iter(train_loader))

        # Iterate over the labels actually returned instead of range(batch_size),
        # so a batch shorter than batch_size cannot raise an IndexError.
        print(' '.join('%5s' % classes[label] for label in labels.tolist()))

        # Undo the (0.5, 0.5) normalisation so pixel values are back in [0, 1].
        img_grid = torchvision.utils.make_grid(images) / 2 + 0.5
        # Move the channel axis last, which is the layout imshow expects.
        plt.imshow(np.transpose(img_grid.numpy(), (1, 2, 0)))
        plt.show()
    

    model

    import torch.nn as nn
    import torch
    
    class AlexNet(nn.Module):
        """AlexNet-style image classifier with batch normalisation.

        Five convolutional stages (each conv is followed by BatchNorm2d and
        ReLU; stages 1, 2 and 5 end with a 3x3/stride-2 max pool) feed three
        fully connected layers. The conv output is adaptively average-pooled
        to 6x6 before flattening, so the first linear layer always receives
        256 * 6 * 6 = 9216 features. For 224x224 inputs the conv output is
        already 6x6 and the pooling is an identity; other sizes are accepted
        as long as the image is large enough to survive the three max pools
        (roughly 63x63 or larger).

        The adaptive pool has no parameters, so state_dict keys are the same
        as without it and existing checkpoints keep loading.

        Args:
            num_classes: Number of output logits.
        """

        def __init__(self, num_classes: int = 10):
            super().__init__()
            self.conv_1 = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
                nn.BatchNorm2d(96),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2))
            self.conv_2 = nn.Sequential(
                nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2))
            self.conv_3 = nn.Sequential(
                nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(384),
                nn.ReLU())
            self.conv_4 = nn.Sequential(
                nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(384),
                nn.ReLU())
            self.conv_5 = nn.Sequential(
                nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3, stride=2))
            # Pins the spatial size to 6x6 so the 9216-wide linear layer below
            # no longer depends on the exact input resolution.
            self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
            self.fc_1 = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(256 * 6 * 6, 4096),
                nn.ReLU())
            self.fc_2 = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(4096, 4096),
                nn.ReLU())
            self.fc_3 = nn.Sequential(
                nn.Linear(4096, num_classes))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Return class logits of shape (batch, num_classes).

            Args:
                x: Batch of 3-channel images, shape (batch, 3, H, W).
            """
            out = self.conv_1(x)
            out = self.conv_2(out)
            out = self.conv_3(out)
            out = self.conv_4(out)
            out = self.conv_5(out)
            out = self.avgpool(out)
            # Flatten (batch, 256, 6, 6) to (batch, 9216) for the linear layers.
            out = out.reshape(out.size(0), -1)
            out = self.fc_1(out)
            out = self.fc_2(out)
            out = self.fc_3(out)
            return out
    
    if __name__ == '__main__':
        # Smoke test: print the layer layout, then run one random 224x224 RGB
        # image through the network and report the shape of the result.
        net = AlexNet()
        print(net)
        dummy_input = torch.randn(1, 3, 224, 224)
        print(net(dummy_input).size())
    

    train

    import torch
    import torch.nn as nn
    
    from dataloader import train_loader, test_loader
    from model import AlexNet
    
    
    # Hyperparameters and runtime configuration.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_classes = 10
    num_epochs = 20
    learning_rate = 1e-3


    # Build the network and move its parameters to the selected device.
    model = AlexNet(num_classes).to(device)


    # Cross-entropy loss on the model outputs, optimised with Adam.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


    # Number of mini-batches per epoch; only used in the progress messages.
    total_len = len(train_loader)

    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            # Inputs must live on the same device as the model.
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass.
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Clear stale gradients, backpropagate, then update the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Report the current mini-batch loss every 100 steps.
            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1, num_epochs, i + 1, total_len, loss.item()
                ))

        # Per-epoch evaluation on test_loader. eval() switches dropout and
        # batch norm to inference behaviour; no_grad() skips autograd tracking.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                # The index of the largest output is the predicted class.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        # Restore training behaviour before the next epoch.
        model.train()
        # Report the number of images actually evaluated, not a fixed 10000.
        print('Accuracy of the network on the {} validation images: {} %'.format(total, 100 * correct / total))

    # Save only the learned parameters; test.py rebuilds the model and loads them.
    torch.save(model.state_dict(), 'alexnet.pth')
    

    test

    import torch
    
    from dataloader import test_loader, classes
    from model import AlexNet
    
    
    # Rebuild the network and load the parameters saved as 'alexnet.pth'.
    # torch.device matches the convention in train.py; map_location lets a
    # checkpoint saved on GPU load on a CPU-only machine.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AlexNet().to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load('alexnet.pth', map_location=device))

    # Evaluate on test_loader. eval() switches dropout and batch norm to
    # inference behaviour; no_grad() skips autograd tracking.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # The index of the largest output is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Report the number of images actually evaluated, not a fixed 10000.
    print('Accuracy of the network on the {} validation images: {} %'.format(total, 100 * correct / total))
    
  • 相关阅读:
    SpringBoot自动装配原理
    技术分享 | 云原生多模型 NoSQL 概述
    elementPlus Pagination 分页怎样变中文
    广告词如何使用更正规?行者AI告诉你
    Xamarin.Andorid实现界面弹框
    奇异谱分析(SSA)matlab
    【图论】【并集查找】【C++算法】928. 尽量减少恶意软件的传播 II
    6. 清理过程
    iNFTnews|国内数藏平台大撤退,寒冬之下海外市场是否有出路?
    python多进程之间共享内存
  • 原文地址:https://blog.csdn.net/qq_46527915/article/details/133623578