• 关于pytorch的数据处理-数据加载Dataset


    目录

    1. 数据加载

    2. Dataset 

    __init__

    __getitem__

    __len__

    测试一下

    完整代码

    3. Dataset - ImageFolder


    1. 数据加载

    最近在使用 Unet 做图像分割,设计到 处理数据有关的工作,查了点资料,做一些简单的总结

    在pytorch 中,数据的加载可以通过自定义的数据集对象实现,这里是Dataset 类,实现自定义的数据集需要继承Dataset,并且实现两个方法

    • __getitem__: 返回一个样本
    • __len__: 返回样本的数量

    其实,之前一直都有用过Dataset类,但是都是直接调库的,所以导致现在对Dataset有点熟悉又有点陌生的感觉

    之前下载CIFAR10 数据集的时候,用的都是:

    •  这里的torchvision 提供数据集
    •  torchvision 里面的dataset 就包含了各种的数据集

    2. Dataset 

    接下来,通过猫和狗的图像介绍Dataset ,介绍如何处理数据

    首先先创建一个文件夹,里面随便上网上下载几张猫和狗的图片,放在同一个文件夹下

    这里的猫狗文件名被改了,后面数字是随机输的,目的是通过 ' . ' 前面的dog和cat生成label


    然后提前导入下面的库文件

     


    __init__

    接下来定义初始化方法

     init 里面是初始化方法,例如传入图片的路径,或者要不要选择预处理等等

    这里并不实际加载图片,只是指定路径,真正的读取图片在getitem方法里面

    os.listdir : 会将data下面所以的文件读取,放在imgs里面,打印结果是上面的注释

    然后self.imgs 会将imgs里面的路径和root路径 拼接在一块,输出结果如下:

    ['./data/cat.15454.jpg', './data/cat.445.jpg', './data/cat.46456.jpg', './data/cat.656165.jpg', './data/dog.123.jpg', './data/dog.15564.jpg', './data/dog.4545.jpg', './data/dog.456465.jpg']

    imgs 里面是具体文件的路径,root里面是文件夹的路径

    __getitem__

    上面说过,getitem 是返回一个样本,所以说这里是将结果返回的。那么返回之前做的处理数据的操作,也在__getitem__里面。

     

    这里的img_path 通过self.imgs[index] 会将self.imgs里面的内容一个个读取出来

    而self.imgs 里面是下图,每个数据的路径

     所以self.imgs[index]会遍历self.imgs 里面的路径,返回给img_path

    打印结果: 

     然后,就可以根据每个路径的id去做label了。将img_path 路径按照 '/ '分割,-1代表取最后一个字符串,如果里面有dog就为1,cat就为0,例如:下面的例子打印的就是Yes

     

    最后,看看有没有预处理transforms ,然后返回data和label就行了

    __len__

    返回样本的个数 = 图片路径的个数 

    测试一下

    最后的结果如下:

     

    所以通过上面的代码,就可以实现一个自定义自己数据集的办法,并且可以获取

    完整代码

    1. import torch
    2. import torchvision.datasets
    3. from torch.utils.data import Dataset # 继承Dataset类
    4. import os
    5. from PIL import Image
    6. import numpy as np
    7. from torchvision import transforms
    8. # 预处理
    9. data_transform = transforms.Compose([
    10. transforms.Resize((224,224)), # 缩放图像
    11. transforms.ToTensor(), # 转为Tenso
    12. transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) # 标准化
    13. ])
    14. class DogCat(Dataset): # 数据处理
    15. def __init__(self,root,transforms = None): # 初始化,指定路径,是否预处理等等
    16. #['cat.15454.jpg', 'cat.445.jpg', 'cat.46456.jpg', 'cat.656165.jpg', 'dog.123.jpg', 'dog.15564.jpg', 'dog.4545.jpg', 'dog.456465.jpg']
    17. imgs = os.listdir(root)
    18. self.imgs = [os.path.join(root,img) for img in imgs] # 取出root下所有的文件
    19. self.transforms = data_transform # 图像预处理
    20. def __getitem__(self, index): # 读取图片
    21. img_path = self.imgs[index]
    22. label = 1 if 'dog' in img_path.split('/')[-1] else 0 # dog -> 1,cat -> 0
    23. data = Image.open(img_path)
    24. if self.transforms: # 图像预处理
    25. data = self.transforms(data)
    26. return data,label
    27. def __len__(self):
    28. return len(self.imgs)
    29. dataset = DogCat('./data/',transforms=True)
    30. for img,label in dataset:
    31. print('img:',img.size(),'label:',label)
    32. '''
    33. img: torch.Size([3, 224, 224]) label: 0
    34. img: torch.Size([3, 224, 224]) label: 0
    35. img: torch.Size([3, 224, 224]) label: 0
    36. img: torch.Size([3, 224, 224]) label: 0
    37. img: torch.Size([3, 224, 224]) label: 1
    38. img: torch.Size([3, 224, 224]) label: 1
    39. img: torch.Size([3, 224, 224]) label: 1
    40. img: torch.Size([3, 224, 224]) label: 1
    41. '''

    3. Dataset - ImageFolder

    ImageFolder 可以更好的将上述的猫狗打好标签

    ImageFolder 假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件名为类名

    例如:将上述的图片放在不同的文件夹下

    文件名的大小写要一致,如首字母大写,都要大写

     

     这样ImageFolder 读取的label就是按照文件名顺序排序成为字典的,也就是{类名:序号}。就是类名+对应的label

    可以通过 .class_to_idx 查看

     

    打印结果为:

    ['Cat', 'Dog']



    {'Cat': 0, 'Dog': 1}



    Dataset ImageFolder
        Number of datapoints: 8
        Root location: ./DogCat/



    [('./DogCat/Cat\\cat.15454.jpg', 0), ('./DogCat/Cat\\cat.445.jpg', 0), ('./DogCat/Cat\\cat.46456.jpg', 0), ('./DogCat/Cat\\cat.656165.jpg', 0), ('./DogCat/Dog\\dog.123.jpg', 1), ('./DogCat/Dog\\dog.15564.jpg', 1), ('./DogCat/Dog\\dog.4545.jpg', 1), ('./DogCat/Dog\\dog.456465.jpg', 1)]
     

    这个就是为什么 pytorch 搭建AlexNet 对花进行分类 这里面对花分类,文件夹的顺序就是这个类别的顺序

     

    最后就是:

     

  • 相关阅读:
    太空射击第09课:精灵动画
    Java面试题 JVM 篇 Redis篇 Spring篇
    刷题记录:牛客NC51222Strategic game
    猿创征文|GaussDB(for openGauss):基于 GaussDB 迁移、智能管理构建应用解决方案
    火山引擎 DataTester 应用故事:一个A/B测试,将产品DAU提升了数十万
    Lange电桥的设计
    NASM汇编教程翻译01 第一讲 Hello, World!
    词向量word2vec(图学习参考资料)
    C03-【计算机二级】Excel操作题(2)全国人口普查数据的统计分析
    如何使用Photino创建Blazor项目进行跨平台
  • 原文地址:https://blog.csdn.net/qq_44886601/article/details/127869770