• [PyTorch] CNN in Practice: Flower Species Recognition


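    Dataset

    This project uses a public dataset from Kaggle, available at the following link:
    https://www.kaggle.com/datasets/alxmamaev/flowers-recognition
    It consists of photos of flowers in 5 classes, a little over four thousand images in total.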

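    Data processing

    The dataset is not large, so it can be read into memory (or GPU memory) once up front instead of being read from disk every time it is needed, which noticeably speeds up training.
    Because the number of images is small, image augmentation is also needed.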

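    Reading the dataset

    The Kaggle data is already sorted into one folder per class, so each image takes its label from the folder it is in.

    import pathlib
    import cv2 as cv
    from torch.utils.data import Dataset

    class Flower_Dataset(Dataset):
        def __init__(self, path, is_train, augs):
            # Each class lives in its own sub-folder of `path`.
            data_root = pathlib.Path(path)
            all_image_paths = list(data_root.glob('*/*'))
            self.all_image_paths = [str(p) for p in all_image_paths]
            label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
            label_to_index = dict((label, index) for index, label in enumerate(label_names))
            # Load every image into memory once (OpenCV returns BGR arrays).
            self.all_image = [cv.imread(p) for p in self.all_image_paths]
            self.all_image_labels = [label_to_index[p.parent.name] for p in all_image_paths]
            self.is_train, self.augs = is_train, augs
            # __len__ and __getitem__ are not shown in this excerpt.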
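
    The folder-to-label mapping used above can be tried in isolation. The following is a minimal sketch, assuming a throwaway directory tree with made-up class folders and empty placeholder files; `build_label_index` and `list_labeled_files` are hypothetical helper names, not part of the original code.

```python
import pathlib
import tempfile


def build_label_index(root):
    """Map each class sub-folder name to an integer label, in sorted order."""
    root = pathlib.Path(root)
    label_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    return {name: index for index, name in enumerate(label_names)}


def list_labeled_files(root):
    """Return (file path, label) pairs, taking the label from the parent folder."""
    root = pathlib.Path(root)
    label_to_index = build_label_index(root)
    files = sorted(p for p in root.glob('*/*') if p.is_file())
    return [(str(p), label_to_index[p.parent.name]) for p in files]


# Build a tiny folder tree that mimics the one-folder-per-class layout.
with tempfile.TemporaryDirectory() as tmp:
    for folder, file_name in [("rose", "a.jpg"), ("daisy", "b.jpg"), ("tulip", "c.jpg")]:
        class_dir = pathlib.Path(tmp) / folder
        class_dir.mkdir()
        (class_dir / file_name).write_bytes(b"")  # empty placeholder, not a real image
    demo_index = build_label_index(tmp)
    demo_pairs = list_labeled_files(tmp)

print(demo_index)  # {'daisy': 0, 'rose': 1, 'tulip': 2}
```

    Sorting the folder names keeps the label numbering stable between runs, which matters when a saved model is evaluated later.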
    

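    Image augmentation

    A horizontally flipped photo of a flower is still a flower, so horizontal flipping is a suitable augmentation.
    In addition, adjustments to brightness, contrast, and similar properties can be applied.

    import torchvision

    # Random color changes combined with a random horizontal flip.
    color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
    augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(), color_aug])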
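
    To make the `ColorJitter` arguments above more concrete, here is a small sketch in plain Python of the ranges they correspond to. It only mirrors the sampling rule described in the torchvision documentation and does not touch any image; `scale_range`, `hue_range` and `sample_jitter` are hypothetical names used for illustration.

```python
import random


def scale_range(value):
    """Range of the multiplicative factor for brightness, contrast or saturation."""
    return (max(0.0, 1.0 - value), 1.0 + value)


def hue_range(value):
    """Range of the hue shift; the argument must lie in [0, 0.5]."""
    if not 0.0 <= value <= 0.5:
        raise ValueError("hue must be in [0, 0.5]")
    return (-value, value)


def sample_jitter(brightness, contrast, saturation, hue, rng):
    """Draw one random set of factors, uniformly within each range."""
    return {
        "brightness": rng.uniform(*scale_range(brightness)),
        "contrast": rng.uniform(*scale_range(contrast)),
        "saturation": rng.uniform(*scale_range(saturation)),
        "hue": rng.uniform(*hue_range(hue)),
    }


rng = random.Random(0)  # fixed seed so the sketch is repeatable
factors = sample_jitter(0.5, 0.5, 0.5, 0.5, rng)

print(scale_range(0.5))  # (0.5, 1.5)
print(hue_range(0.5))    # (-0.5, 0.5)
```

    Note that `hue=0.5` is the largest value allowed, so a sampled shift can change a flower's color substantially; a smaller value may be worth trying if color turns out to be an important cue.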
    

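    Iterators

    Each iteration draws one batch of samples from the dataset.
    For training, the samples are normally visited in shuffled order.

    from torch.utils.data import DataLoader

    # train_set, test_set and batch_size are defined elsewhere.
    train_iter = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
    test_iter = DataLoader(test_set, batch_size=batch_size, num_workers=4)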
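
    What shuffling and batching amount to can be sketched in a few lines of plain Python. This is only an illustration of the idea, not how `DataLoader` is implemented; `iterate_batches` is a hypothetical helper and the sample count of 10 is made up.

```python
import random


def iterate_batches(num_samples, batch_size, shuffle, seed=None):
    """Yield lists of sample indices, one list per mini-batch."""
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        # The last batch is smaller when num_samples is not a multiple of batch_size.
        yield indices[start:start + batch_size]


train_batches = list(iterate_batches(10, batch_size=4, shuffle=True, seed=0))
test_batches = list(iterate_batches(10, batch_size=4, shuffle=False))

print([len(batch) for batch in train_batches])  # [4, 4, 2]
print(test_batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

    Every sample is still visited exactly once per epoch; shuffling only changes which samples end up in the same batch.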
    

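    CNN model

    The classic ResNet architecture is used. Because the dataset is small, an overly complex network is not a good fit, so ResNet-18 is chosen here. The model summary lists 68 module entries for it (the 18 in the name counts only the 17 main-path convolutional layers and the final fully connected layer), so it is not very deep. Its structure is as follows: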

    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
                Conv2d-1         [-1, 64, 112, 112]           9,408
           BatchNorm2d-2         [-1, 64, 112, 112]             128
                  ReLU-3         [-1, 64, 112, 112]               0
             MaxPool2d-4           [-1, 64, 56, 56]               0
                Conv2d-5           [-1, 64, 56, 56]          36,864
           BatchNorm2d-6           [-1, 64, 56, 56]             128
                  ReLU-7           [-1, 64, 56, 56]               0
                Conv2d-8           [-1, 64, 56, 56]          36,864
           BatchNorm2d-9           [-1, 64, 56, 56]             128
                 ReLU-10           [-1, 64, 56, 56]               0
           BasicBlock-11           [-1, 64, 56, 56]               0
               Conv2d-12           [-1, 64, 56, 56]          36,864
          BatchNorm2d-13           [-1, 64, 56, 56]             128
                 ReLU-14           [-1, 64, 56, 56]               0
               Conv2d-15           [-1, 64, 56, 56]          36,864
          BatchNorm2d-16           [-1, 64, 56, 56]             128
                 ReLU-17           [-1, 64, 56, 56]               0
           BasicBlock-18           [-1, 64, 56, 56]               0
               Conv2d-19          [-1, 128, 28, 28]          73,728
          BatchNorm2d-20          [-1, 128, 28, 28]             256
                 ReLU-21          [-1, 128, 28, 28]               0
               Conv2d-22          [-1, 128, 28, 28]         147,456
          BatchNorm2d-23          [-1, 128, 28, 28]             256
               Conv2d-24          [-1, 128, 28, 28]           8,192
          BatchNorm2d-25          [-1, 128, 28, 28]             256
                 ReLU-26          [-1, 128, 28, 28]               0
           BasicBlock-27          [-1, 128, 28, 28]               0
               Conv2d-28          [-1, 128, 28, 28]         147,456
          BatchNorm2d-29          [-1, 128, 28, 28]             256
                 ReLU-30          [-1, 128, 28, 28]               0
               Conv2d-31          [-1, 128, 28, 28]         147,456
          BatchNorm2d-32          [-1, 128, 28, 28]             256
                 ReLU-33          [-1, 128, 28, 28]               0
           BasicBlock-34          [-1, 128, 28, 28]               0
               Conv2d-35          [-1, 256, 14, 14]         294,912
          BatchNorm2d-36          [-1, 256, 14, 14]             512
                 ReLU-37          [-1, 256, 14, 14]               0
               Conv2d-38          [-1, 256, 14, 14]         589,824
          BatchNorm2d-39          [-1, 256, 14, 14]             512
               Conv2d-40          [-1, 256, 14, 14]          32,768
          BatchNorm2d-41          [-1, 256, 14, 14]             512
                 ReLU-42          [-1, 256, 14, 14]               0
           BasicBlock-43          [-1, 256, 14, 14]               0
               Conv2d-44          [-1, 256, 14, 14]         589,824
          BatchNorm2d-45          [-1, 256, 14, 14]             512
                 ReLU-46          [-1, 256, 14, 14]               0
               Conv2d-47          [-1, 256, 14, 14]         589,824
          BatchNorm2d-48          [-1, 256, 14, 14]             512
                 ReLU-49          [-1, 256, 14, 14]               0
           BasicBlock-50          [-1, 256, 14, 14]               0
               Conv2d-51            [-1, 512, 7, 7]       1,179,648
          BatchNorm2d-52            [-1, 512, 7, 7]           1,024
                 ReLU-53            [-1, 512, 7, 7]               0
               Conv2d-54            [-1, 512, 7, 7]       2,359,296
          BatchNorm2d-55            [-1, 512, 7, 7]           1,024
               Conv2d-56            [-1, 512, 7, 7]         131,072
          BatchNorm2d-57            [-1, 512, 7, 7]           1,024
                 ReLU-58            [-1, 512, 7, 7]               0
           BasicBlock-59            [-1, 512, 7, 7]               0
               Conv2d-60            [-1, 512, 7, 7]       2,359,296
          BatchNorm2d-61            [-1, 512, 7, 7]           1,024
                 ReLU-62            [-1, 512, 7, 7]               0
               Conv2d-63            [-1, 512, 7, 7]       2,359,296
          BatchNorm2d-64            [-1, 512, 7, 7]           1,024
                 ReLU-65            [-1, 512, 7, 7]               0
           BasicBlock-66            [-1, 512, 7, 7]               0
    AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
               Linear-68                    [-1, 5]           2,565
    ================================================================
    Total params: 11,179,077
    Trainable params: 11,179,077
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.57
    Forward/backward pass size (MB): 62.79
    Params size (MB): 42.64
    Estimated Total Size (MB): 106.00
    ----------------------------------------------------------------
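
    The numbers in the summary table above can be checked by hand. The sketch below recomputes a few parameter counts, the total, and the first two output sizes from the layer shapes shown in the table. The helper names are hypothetical, and the kernel sizes, strides and paddings are the standard ResNet-18 values, which the table itself does not print.

```python
def conv2d_params(in_channels, out_channels, kernel_size):
    """Weights of a bias-free convolution."""
    return in_channels * out_channels * kernel_size * kernel_size


def batchnorm2d_params(channels):
    """One learnable scale and one learnable shift per channel."""
    return 2 * channels


def linear_params(in_features, out_features):
    """Weight matrix plus bias vector."""
    return in_features * out_features + out_features


def output_size(size, kernel_size, stride, padding):
    """Spatial size after a convolution or pooling layer."""
    return (size + 2 * padding - kernel_size) // stride + 1


def stage_params(in_channels, out_channels, downsample):
    """Two BasicBlocks: four 3x3 convolutions, each followed by BatchNorm."""
    total = conv2d_params(in_channels, out_channels, 3) + batchnorm2d_params(out_channels)
    total += 3 * (conv2d_params(out_channels, out_channels, 3) + batchnorm2d_params(out_channels))
    if downsample:
        # 1x1 convolution plus BatchNorm on the shortcut path.
        total += conv2d_params(in_channels, out_channels, 1) + batchnorm2d_params(out_channels)
    return total


total = conv2d_params(3, 64, 7) + batchnorm2d_params(64)
total += stage_params(64, 64, downsample=False)
for cin, cout in [(64, 128), (128, 256), (256, 512)]:
    total += stage_params(cin, cout, downsample=True)
total += linear_params(512, 5)

after_conv = output_size(224, kernel_size=7, stride=2, padding=3)
after_pool = output_size(after_conv, kernel_size=3, stride=2, padding=1)

print(conv2d_params(3, 64, 7))      # 9408
print(linear_params(512, 5))        # 2565
print(total)                        # 11179077
print(after_conv, after_pool)       # 112 56
print(round(total * 4 / 2**20, 2))  # 42.64 (MB at 4 bytes per parameter)
```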
    

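    Fine-tuning

    The images in this dataset are fairly similar to those in ImageNet, so fine-tuning a pretrained model is a good fit.
    The only structural change is the last layer, whose output size is set to 5.
    The learning rates also need adjusting: the pretrained layers keep a small base learning rate, while the newly initialized output layer uses a much larger one (80 times the base rate here).

    import torch
    import torchvision
    from torch import nn
    from torchsummary import summary  # assumed source of summary(); its output format matches the table above

    # Load ImageNet-pretrained weights (newer torchvision versions prefer the `weights=` argument).
    net = torchvision.models.resnet18(pretrained=True)

    # Replace the 1000-class output layer with a freshly initialized 5-class layer.
    net.fc = nn.Linear(net.fc.in_features, 5)
    nn.init.xavier_uniform_(net.fc.weight)
    summary(net, input_size=(3, 224, 224), device="cpu")

    lr = 0.0005
    loss = nn.CrossEntropyLoss(reduction="mean")

    # Pretrained layers use the base learning rate; the new output layer uses 80x that.
    params_1x = [param for name, param in net.named_parameters()
                 if name not in ["fc.weight", "fc.bias"]]
    trainer = torch.optim.SGD([{'params': params_1x},
                               {'params': net.fc.parameters(), 'lr': lr * 80}],
                              lr=lr, weight_decay=0.001)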
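
    The grouping logic behind the optimizer call above can be looked at without PyTorch. The following sketch uses strings as stand-ins for parameter tensors and a short, hand-picked list of parameter names; `build_param_groups` is a hypothetical helper, and the result only has the same shape as the list passed to the optimizer.

```python
import math


def build_param_groups(named_params, head_names, base_lr, head_lr_multiplier):
    """Split parameters into a backbone group and an output-layer group."""
    backbone = [param for name, param in named_params if name not in head_names]
    head = [param for name, param in named_params if name in head_names]
    return [
        {"params": backbone},  # no "lr" key: the optimizer's default lr applies
        {"params": head, "lr": base_lr * head_lr_multiplier},
    ]


# Stand-in values; in real code these would be tensors from net.named_parameters().
named_params = [
    ("conv1.weight", "w0"),
    ("layer1.0.conv1.weight", "w1"),
    ("fc.weight", "w2"),
    ("fc.bias", "b2"),
]

base_lr = 0.0005
groups = build_param_groups(named_params, {"fc.weight", "fc.bias"}, base_lr, 80)
effective_lrs = [group.get("lr", base_lr) for group in groups]

print(groups[0]["params"])  # ['w0', 'w1']
print(groups[1]["params"])  # ['w2', 'b2']
```

    With a base rate of 0.0005, the output layer ends up training at roughly 0.04, so it can move quickly from its random initialization while the pretrained layers change only slightly.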
    

    Training

    This part is much like training any other neural network, so it is not covered in detail.

    from tqdm import tqdm
    import numpy as np

    # Training
    # net, loss, trainer, train_iter, epochs and testing() are defined elsewhere.
    Accuracies = []    # test accuracy per epoch
    Losses = []        # test loss per epoch
    T_Accuracies = []  # smoothed training accuracy per epoch
    T_Losses = []      # smoothed training loss per epoch
    for epoch in range(epochs):
        net.train()
        loop = tqdm(enumerate(train_iter), total=len(train_iter))  # progress bar
        loop.set_description(f'Epoch [{epoch + 1}/{epochs}]')  # progress bar prefix
        T_Accuracies.append(0)
        T_Losses.append(0)
        for index, (X, Y) in loop:
            scores = net(X)
            l = loss(scores, Y)
            trainer.zero_grad()
            l.backward()

            _, predictions = scores.max(1)
            num_correct = (predictions == Y).sum()
            running_train_acc = float(num_correct) / float(X.shape[0])
            if index == 0:
                T_Accuracies[-1] = running_train_acc
                T_Losses[-1] = l.item()
            else:
                # Exponential moving average over the batches of this epoch.
                T_Accuracies[-1] = T_Accuracies[-1] * 0.9 + 0.1 * running_train_acc
                T_Losses[-1] = T_Losses[-1] * 0.9 + 0.1 * l.item()

            # Progress bar suffix
            loop.set_postfix(loss='{:.3f}'.format(T_Losses[-1]), accuracy='{:.3f}'.format(T_Accuracies[-1]))

            trainer.step()
        a, b = testing()
        Accuracies.append(a)
        Losses.append(b)
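
    The training accuracy and loss reported by the loop above are exponential moving averages over the batches of an epoch, not plain means. A minimal standalone sketch of that smoothing, with made-up per-batch accuracies and a hypothetical helper name `smooth`:

```python
import math


def smooth(values, decay=0.9):
    """Exponential moving average: the first value is taken as is."""
    smoothed = []
    current = None
    for index, value in enumerate(values):
        if index == 0:
            current = value
        else:
            current = current * decay + (1.0 - decay) * value
        smoothed.append(current)
    return smoothed


batch_accuracies = [0.5, 1.0, 1.0]  # made-up per-batch values
smoothed = smooth(batch_accuracies)
plain_mean = sum(batch_accuracies) / len(batch_accuracies)

print(['{:.3f}'.format(v) for v in smoothed])  # ['0.500', '0.550', '0.595']
print('{:.3f}'.format(plain_mean))             # 0.833
```

    Because older batches fade out at a rate of 0.9 per step, the number stored at the end of an epoch mostly reflects the last few dozen batches, which is worth keeping in mind when comparing it with the test accuracy.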
    

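    Results

    The results on the training set and the test set are plotted in the following figure:
    [Figure: image not preserved in this copy]
    Accuracy is high on both the training set and the test set, which indicates that fine-tuning is effective.
    Test accuracy already exceeds 90% by the fifth epoch, so a good result is reached after very little training.
    Training accuracy is somewhat lower than test accuracy, because image augmentation is applied to the training set but not to the test set.

    Full code

    Download link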

  • Original article: https://blog.csdn.net/lijf2001/article/details/125981978