• PyTorch搭建LSTM实现服装分类(FashionMNIST)


    FashionMNIST 数据集官网:https://github.com/zalandoresearch/fashion-mnist.

    这里不再介绍该数据集,如需了解请前往官网。

    思路: 数据集中的每张图片都是尺寸为 ( 28 , 28 ) (28,28) (28,28) 的灰度图。我们可以将其看作 28 × 28 28\times28 28×28 的数字矩阵,将该矩阵按行进行逐行分块可得一个长度为 28 28 28 的序列,且序列中的每个 “词元” 对应的特征维数也是 28 28 28


    运行环境:

    • 系统:Ubuntu 20.04;
    • GPU:RTX 3090;
    • Pytorch:1.11;
    • Python:3.8

    import numpy as np
    import matplotlib.pyplot as plt
    
    import torch
    import torchvision
    import torch.nn as nn
    from torch.utils.data import DataLoader
    
    # Data Preprocessing
    train_data = torchvision.datasets.FashionMNIST(root='data',
                                                   train=True,
                                                   transform=torchvision.transforms.ToTensor(),
                                                   download=True)
    test_data = torchvision.datasets.FashionMNIST(root='data',
                                                  train=False,
                                                  transform=torchvision.transforms.ToTensor(),
                                                  download=True)
    train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_data, batch_size=64, num_workers=4)
    
    
    # Model building
    class LSTM(nn.Module):
    
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(28, 64, num_layers=2)
            self.linear = nn.Linear(64, 10)
    
        def forward(self, x):
            output, (h_n, c_n) = self.lstm(x, None)
            return self.linear(h_n[0])
    
    
    def setup_seed(seed):
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    
    
    # Setup
    setup_seed(42)
    
    NUM_EPOCHS = 20
    LR = 4e-3
    train_loss, test_loss, test_acc = [], [], []
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    lstm = LSTM()
    lstm.to(device)
    
    critertion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(lstm.parameters(), lr=LR)
    
    # Training and testing
    for epoch in range(NUM_EPOCHS):
        print(f'[Epoch {epoch + 1}]', end=' ')
        avg_train_loss, avg_test_loss, correct = 0, 0, 0
    
        # train
        lstm.train()
        for batch_idx, (X, y) in enumerate(train_loader):
            # (64, 1, 28, 28) -> (28, 64, 28)
            X = X.squeeze().movedim(0, 1)
            X, y = X.to(device), y.to(device)
    
            # forward
            output = lstm(X)
            loss = critertion(output, y)
            avg_train_loss += loss
    
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
        avg_train_loss /= (batch_idx + 1)
        train_loss.append(avg_train_loss.item())
    
        # test
        lstm.eval()
        with torch.no_grad():
            for batch_idx, (X, y) in enumerate(test_loader):
                X = X.squeeze().movedim(0, 1)
                X, y = X.to(device), y.to(device)
    
                pred = lstm(X)
                loss = critertion(pred, y)
                avg_test_loss += loss
                correct += (pred.argmax(1) == y).sum().item()
    
        avg_test_loss /= (batch_idx + 1)
        test_loss.append(avg_test_loss.item())
        correct /= len(test_loader.dataset)
        test_acc.append(correct)
    
        print(
            f"train loss: {train_loss[-1]:.4f} | test loss: {test_loss[-1]:.4f} | test acc: {correct:.4f}"
        )
    
    # Plot
    x = np.arange(1, 21)
    plt.plot(x, train_loss, label="train loss")
    plt.plot(x, test_loss, label="test loss")
    plt.plot(x, test_acc, label="test acc")
    plt.xlabel("epoch")
    plt.legend(loc="best", fontsize=12)
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109

    输出结果:

    [Epoch 1] train loss: 0.6602 | test loss: 0.5017 | test acc: 0.8147
    [Epoch 2] train loss: 0.4089 | test loss: 0.3979 | test acc: 0.8566
    [Epoch 3] train loss: 0.3577 | test loss: 0.3675 | test acc: 0.8669
    [Epoch 4] train loss: 0.3268 | test loss: 0.3509 | test acc: 0.8751
    [Epoch 5] train loss: 0.3098 | test loss: 0.3395 | test acc: 0.8752
    [Epoch 6] train loss: 0.2962 | test loss: 0.3135 | test acc: 0.8854
    [Epoch 7] train loss: 0.2823 | test loss: 0.3377 | test acc: 0.8776
    [Epoch 8] train loss: 0.2720 | test loss: 0.3196 | test acc: 0.8835
    [Epoch 9] train loss: 0.2623 | test loss: 0.3120 | test acc: 0.8849
    [Epoch 10] train loss: 0.2547 | test loss: 0.2981 | test acc: 0.8931
    [Epoch 11] train loss: 0.2438 | test loss: 0.3140 | test acc: 0.8882
    [Epoch 12] train loss: 0.2372 | test loss: 0.3043 | test acc: 0.8909
    [Epoch 13] train loss: 0.2307 | test loss: 0.2977 | test acc: 0.8918
    [Epoch 14] train loss: 0.2219 | test loss: 0.2888 | test acc: 0.8970
    [Epoch 15] train loss: 0.2187 | test loss: 0.2946 | test acc: 0.8959
    [Epoch 16] train loss: 0.2132 | test loss: 0.2894 | test acc: 0.8985
    [Epoch 17] train loss: 0.2061 | test loss: 0.2835 | test acc: 0.9014
    [Epoch 18] train loss: 0.2028 | test loss: 0.2954 | test acc: 0.8971
    [Epoch 19] train loss: 0.1966 | test loss: 0.2952 | test acc: 0.8986
    [Epoch 20] train loss: 0.1922 | test loss: 0.2910 | test acc: 0.9011
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    相应的曲线:

    在这里插入图片描述


    一些心得

    • 切勿直接使用 X = X.reshape(28, -1, 28),否则 X 对应的将不是原来的图片(读者可自行尝试使用 torchvision.transforms.ToPILImage 去输出 X 对应的图片观察效果)。
    • 学习率相同的情况下,SGD 的效果没有 Adam 好。
  • 相关阅读:
    TCP/IP协议详解
    禅道:提bug、管理case 7.0
    增材制造中如何使用HOOPS进行3D模型轻量化处理?
    纯干货|AI辅助写论文的正确打开方式!
    Codeforces Round #797 (Div. 3)个人题解
    应广单片机实现红蓝双色爆闪灯
    docker镜像的创建
    云原生 | Kubernetes - Helm - Prometheus 实现资源限制
    打破数据分析壁垒:SPSS复习必备(四)
    Linux shell编程学习笔记20:case ... esac、continue 和break语句
  • 原文地址:https://blog.csdn.net/raelum/article/details/125432230