• PyTorch入门之【CNN】


    参考:https://www.bilibili.com/video/BV1114y1d79e/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620
    书接上回的 MLP，与其相同的部分本章不再详细解释。

    目录

    train

    import torch
    from torchvision.transforms import ToTensor
    from torchvision import datasets
    import torch.nn as nn
    
    # Load the MNIST training split (downloaded into the data folder on the
    # first run); ToTensor converts every image to a tensor on access.
    training_data = datasets.MNIST(
        root='../02_dataset/data',
        train=True,
        download=True,
        transform=ToTensor(),
    )
    
    # Serve mini-batches of 64 samples, reshuffled on every pass over the data.
    train_data_loader = torch.utils.data.DataLoader(
        dataset=training_data,
        batch_size=64,
        shuffle=True,
    )
    
    # define a CNN model
    class CNN(nn.Module):
        """Small convolutional classifier for single-channel digit images.

        Layout: conv3x3(32)-BN-ReLU -> conv3x3(64)-BN-ReLU -> maxpool(2)
        -> flatten -> linear(128)-BN-ReLU -> linear(num_classes).

        The defaults reproduce the original MNIST model exactly (1 input
        channel, 28x28 images, 9216 flattened features, 10 classes). Submodule
        names are unchanged, so saved ``cnn.pth`` state dicts still load.

        Args:
            in_channels: number of channels of the input images.
            num_classes: number of output logits.
            image_size: side length of the (square) input images; used to
                size the first fully-connected layer.

        Raises:
            ValueError: if ``image_size`` is too small to survive the two
                unpadded 3x3 convolutions and the 2x2 max-pool.
        """

        def __init__(self, in_channels: int = 1, num_classes: int = 10,
                     image_size: int = 28):
            super().__init__()
            # Two unpadded 3x3 convs each shrink the side by 2, then the 2x2
            # max-pool halves it (rounding down): 28 -> 26 -> 24 -> 12, giving
            # 64 * 12 * 12 = 9216 features for the default size.
            pooled_side = (image_size - 4) // 2
            if pooled_side < 1:
                raise ValueError(f'image_size must be at least 6, got {image_size}')

            self.conv_1 = nn.Sequential(
                nn.Conv2d(in_channels, 32, kernel_size=3, stride=1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
            )
            self.conv_2 = nn.Sequential(
                nn.Conv2d(32, 64, kernel_size=3, stride=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
            )
            self.maxpool = nn.MaxPool2d(2)
            self.flatten = nn.Flatten()
            self.fc_1 = nn.Sequential(
                nn.Linear(64 * pooled_side * pooled_side, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
            )
            self.fc_2 = nn.Linear(128, num_classes)

        def forward(self, x):
            """Return raw class logits for a batch of images.

            ``x`` is expected to be (batch, in_channels, image_size, image_size).
            The result is (batch, num_classes), unnormalized (feed it to
            CrossEntropyLoss or argmax). In training mode BatchNorm1d needs
            more than one sample per batch.
            """
            x = self.conv_1(x)
            x = self.conv_2(x)
            x = self.maxpool(x)
            x = self.flatten(x)
            x = self.fc_1(x)
            logits = self.fc_2(x)
            return logits
    
    # create a CNN model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    cnn = CNN().to(device)
    optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    
    # train the model
    num_epochs = 20
    log_interval = 400  # print progress every `log_interval` batches
    # The dataset length never changes, so compute it once rather than per batch.
    size = len(train_data_loader.dataset)
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}\n-------------------------------')
        # Explicitly enable training mode so BatchNorm uses batch statistics and
        # updates its running estimates (guards against an earlier .eval()).
        cnn.train()
        for idx, (img, label) in enumerate(train_data_loader):
            img, label = img.to(device), label.to(device)
    
            # compute prediction error
            pred = cnn(img)
            loss = loss_fn(pred, label)
    
            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            if idx % log_interval == 0:
                # Samples processed so far this epoch, including the current
                # batch (the old `idx*len(img)` reported 0 after the first batch).
                current = idx * train_data_loader.batch_size + len(img)
                print(f'loss: {loss.item():>7f} [{current:>5d}/{size:>5d}]')
    
    # save the model
    torch.save(cnn.state_dict(), 'cnn.pth')
    print('Saved PyTorch Model State to cnn.pth')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78

    test

    import torch
    from torchvision import datasets
    from torchvision import transforms
    from torchvision.transforms import ToTensor
    from torchvision.datasets import ImageFolder
    import torch.nn as nn
    
    # load test data
    test_data = datasets.MNIST(
        root='../02_dataset/data',
        train=False,
        download=True,
        transform=ToTensor()
    )
    # Sample order does not affect accuracy, so evaluation loaders need no shuffling.
    test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False)
    
    # Preprocessing for the hand-made digits. There is deliberately no random
    # augmentation: the former RandomRotation(10) made the reported accuracy
    # change from run to run. Resize guarantees the 28x28 input the model's
    # first Linear layer (9216 = 64*12*12 features) requires.
    # NOTE(review): ImageFolder assigns labels from the sorted sub-folder names;
    # this assumes folders named 0..9 so indices equal the digit. Also, MNIST
    # digits are light-on-dark; dark-on-light images may need inverting — confirm.
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((28, 28)),
        transforms.ToTensor()
    ])
    my_mnist = ImageFolder(root='../02_dataset/my-mnist', transform=transform)
    my_mnist_loader = torch.utils.data.DataLoader(my_mnist, batch_size=64, shuffle=False)
    
    # define a CNN model
    class CNN(nn.Module):
        """Small convolutional classifier for single-channel digit images.

        Layout: conv3x3(32)-BN-ReLU -> conv3x3(64)-BN-ReLU -> maxpool(2)
        -> flatten -> linear(128)-BN-ReLU -> linear(num_classes).

        The defaults reproduce the original MNIST model exactly (1 input
        channel, 28x28 images, 9216 flattened features, 10 classes). Submodule
        names are unchanged, so saved ``cnn.pth`` state dicts still load.

        Args:
            in_channels: number of channels of the input images.
            num_classes: number of output logits.
            image_size: side length of the (square) input images; used to
                size the first fully-connected layer.

        Raises:
            ValueError: if ``image_size`` is too small to survive the two
                unpadded 3x3 convolutions and the 2x2 max-pool.
        """

        def __init__(self, in_channels: int = 1, num_classes: int = 10,
                     image_size: int = 28):
            super().__init__()
            # Two unpadded 3x3 convs each shrink the side by 2, then the 2x2
            # max-pool halves it (rounding down): 28 -> 26 -> 24 -> 12, giving
            # 64 * 12 * 12 = 9216 features for the default size.
            pooled_side = (image_size - 4) // 2
            if pooled_side < 1:
                raise ValueError(f'image_size must be at least 6, got {image_size}')

            self.conv_1 = nn.Sequential(
                nn.Conv2d(in_channels, 32, kernel_size=3, stride=1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
            )
            self.conv_2 = nn.Sequential(
                nn.Conv2d(32, 64, kernel_size=3, stride=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
            )
            self.maxpool = nn.MaxPool2d(2)
            self.flatten = nn.Flatten()
            self.fc_1 = nn.Sequential(
                nn.Linear(64 * pooled_side * pooled_side, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
            )
            self.fc_2 = nn.Linear(128, num_classes)

        def forward(self, x):
            """Return raw class logits for a batch of images.

            ``x`` is expected to be (batch, in_channels, image_size, image_size).
            The result is (batch, num_classes), unnormalized (feed it to
            CrossEntropyLoss or argmax). In training mode BatchNorm1d needs
            more than one sample per batch.
            """
            x = self.conv_1(x)
            x = self.conv_2(x)
            x = self.maxpool(x)
            x = self.flatten(x)
            x = self.fc_1(x)
            logits = self.fc_2(x)
            return logits
    
    # load the pretrained model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    cnn = CNN()
    # torch.load unpickles the file, so only load checkpoints you trust
    # (recent PyTorch versions accept weights_only=True to restrict this).
    cnn.load_state_dict(torch.load('cnn.pth', map_location=device))
    # Move to the target device, then switch BatchNorm layers to inference
    # mode so they use their stored running statistics.
    cnn = cnn.to(device).eval()
    
    
    def evaluate(model, data_loader, device):
        """Return the top-1 accuracy of ``model`` on ``data_loader`` as a fraction.

        Args:
            model: classifier returning (batch, num_classes) logits.
            data_loader: yields (images, labels) batches; its ``dataset``
                length is used as the denominator.
            device: device the batches are moved to before inference.

        Raises:
            ValueError: if the loader's dataset is empty.
        """
        size = len(data_loader.dataset)
        if size == 0:
            raise ValueError('cannot evaluate on an empty dataset')
        correct = 0
        with torch.no_grad():
            for img, label in data_loader:
                img, label = img.to(device), label.to(device)
                pred = model(img)
                correct += (pred.argmax(1) == label).sum().item()
        return correct / size
    
    
    # test the pretrained model on MNIST test data
    correct = evaluate(cnn, test_data_loader, device)
    print(f'Accuracy on MNIST: {(100*correct):>0.1f}%')
    
    # test the pretrained model on my MNIST test data
    correct = evaluate(cnn, my_mnist_loader, device)
    print(f'Accuracy on my MNIST: {(100*correct):>0.1f}%')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
  • 相关阅读:
    vue中将three.js导入的3D模型中原本带有的动画进行播放
    .net8 blazor auto模式很爽(五)读取sqlite并显示(2)
    mybatis 14: 多对一关联查询
    这个算法不一般,控制拥塞有一手!
    场外期权交易流程以及参与方式是什么?
    go nil介绍
    SpringBoot SpringBoot 原理篇 2 自定义starter 2.7 开启yml 提示功能
    基于JAVA网上花店计算机毕业设计源码+数据库+lw文档+系统+部署
    Python 装饰器
    【网络通信】初探网络层次结构(OSI七层网络模型)
  • 原文地址:https://blog.csdn.net/qq_46527915/article/details/133621118