• 【使用ImageFolder加载数据】


    1 使用ImageFolder的前提条件

    诸如图片的两分类问题,训练和测试的图片是分别存放好的,如下目录树

    +---test
    |   +---airplane
    			airplane_561.jpg
    			...
    			airplane_700.jpg
    |   \---lake
    			lake_561.jpg
    			...
    			lake_700.jpg
    \---train
        +---airplane
    			airplane_001.jpg
    			...
    			airplane_560.jpg
        \---lake
    			lake_001.jpg
    			...
    			lake_560.jpg
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    分别存放好后,使用如下语句读取加载:

    import torchvision
    
    train_dir = r'2_class/train'
    test_dir = r'2_class/test'
    
    from torchvision import transforms
    
    transform = transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                           std=[0.5, 0.5, 0.5])
    ])
    
    train_ds = torchvision.datasets.ImageFolder(
                   train_dir,
                   transform=transform
    )
    
    
    test_ds = torchvision.datasets.ImageFolder(
                   test_dir,
                   transform=transform
    )
    
    print(train_ds.classes)
    print(train_ds.class_to_idx)
    print(len(train_ds), len(test_ds))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27

    输出如下:

    ['airplane', 'lake']
    {'airplane': 0, 'lake': 1}
    1120 280
    
    • 1
    • 2
    • 3

    如果是其它分类问题,也可以按照这种方法加载数据

    2 批量加载数据

    BATCHSIZE = 16
    train_dl = torch.utils.data.DataLoader(
                                           train_ds,
                                           batch_size=BATCHSIZE,
                                           shuffle=True
    )
    test_dl = torch.utils.data.DataLoader(
                                           test_ds,
                                           batch_size=BATCHSIZE,
    )
    
    imgs, labels = next(iter(train_dl))
    print(imgs.shape)   #一批次形状
    print(imgs[0].shape)#一张图形状
    
    im = imgs[0].permute(1, 2, 0)   #设置通道数为最后一维
    print(im.shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    输出如下:

    torch.Size([16, 3, 256, 256])
    torch.Size([3, 256, 256])
    torch.Size([256, 256, 3])
    
    • 1
    • 2
    • 3

    3 反转类别序号和关键字,绘制样例图

    id_to_class = dict((v, k) for k, v in train_ds.class_to_idx.items())
    print(id_to_class)
    
    • 1
    • 2

    输出如下:

    {0: 'airplane', 1: 'lake'}
    
    • 1

    绘制样例图:

    plt.figure(figsize=(12, 8))
    for i, (img, label) in enumerate(zip(imgs[:6], labels[:6])):
        img = (img.permute(1, 2, 0).numpy() + 1)/2
        plt.subplot(2, 3, i+1)
        plt.title(id_to_class.get(label.item()))
        plt.xticks([])
        plt.yticks([])
        plt.imshow(img)
        plt.savefig('pics/4-2.jpg', dpi=400)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    在这里插入图片描述

  • 相关阅读:
    Ubuntu20.04下载opencv3.4--未完善
    当他们在私域里,掌握了分寸感
    「北大社送书」学习Flutter编程 — 《从零基础到精通Flutter开发》
    zk分布式Job 实现的业务逻辑
    PHP 程序员为什么依然是外包公司的香饽饽?
    特性Attribute
    雾天行人车辆检测
    1358:中缀表达式值(expr)
    传输层 SACK与选择性重传算法
    OpenGLES:绘制一个混色旋转的3D球体
  • 原文地址:https://blog.csdn.net/m0_46256255/article/details/133178412