• pytorch初学笔记(一):如何加载数据和Dataset实战


    目录

     一、Dataset初识以及项目前期准备工作

    二、MyData类

    2.1 在python中定义类和方法

    2.2 定义MyClass类

    Dataset

    2.3 获取图片

    2.4 使用控制台调试对应信息

    1. 获取ants集中第一章图片的绝对路径

    2. 读取对应路径的图片

    3. 显示图片:show方法

    4. 获取图片信息列表

    三、完善MyData类

    3.1  初始化方法中需要的参数和方法

    3.2 初始化init方法的书写

    3.3 getitem方法的书写

    3.4 生成实例 

    3.4 两个数据集的生成与相加操作

    1. 生成蚂蚁和蜜蜂数据集 

    2. 数据集相加

    四、完整代码

    五、使用修改后数据集的代码练习


    python文件、python控制台和jupter notebook的区别

    遇到的问题:

    1. jupyter notebook中配置pytorch

    (71条消息) jupyter notebook中使用pytorch_一子慢的博客-CSDN博客_jupyternotebook使用pytorch

    2. pycharm中matplotlib使用失败

    (71条消息) Pycharm导入matplotlib失败的解决办法_c472769019的博客-CSDN博客_matplotlib导入失败

     一、Dataset初识以及项目前期准备工作

     在notebook中使用help方法查看dataset类的功能以及操作:

    • 想要使用dataset都需要继承Dataset这个父类
    • 需要重写__getitem__方法和__len__方法
    • __getitem__():由给定的key获取数据集中每一个图片的操作函数
    • __len__():获取数据集中图片大小的函数

    前置操作

    1. 把数据集移动到项目所在的目录文件夹下

    2. 右击想要查看路径的文件夹/图片: 

     

     可以复制需要的绝对路径/相对路径

    二、MyData类

    2.1 在python中定义类和方法

    • 在python中定义类的要求:class关键字定义类,后面跟着类的全名,括号(object)表示该类是从哪个类中继承下来的,如果没有合适的继承类,则使用object类,这是所有类都会继承的类。 
    • 在类里定义方法的要求:在类中定义方法时,第一个参数必须是self。
    • 在类中定义方法的要求:self变量无需传递,其他参数正常传入。

    例: 

    2.2 定义MyClass类

    • 从torch工具箱中导入Dataset模块
    from torch.utils.data import Dataset

    Dataset

    Dataset是一个抽象类,为了能够方便的读取,需要将要使用的数据包装为Dataset类。 自定义的Dataset需要继承它并且实现两个成员方法:

    1. __getitem__() 该方法定义用索引(0 到 len(self))获取一条数据或一个样本,可以使用对象【item】进行访问

    2. __len__() 该方法返回数据集的总长度

     

    首先,重写init方法和getitem方法,后期重写len()方法

    2.3 获取图片

    导入图片需要获取对应的图片image和对应的标签label,也需要获取图片所在的位置img_path

     读取图片需要导入的模块

    1. # 读取图片
    2. from PIL import Image

    2.4 使用控制台调试对应信息

    控制台作用:可以显示定义的变量和相关属性

    1. 获取ants集中第一章图片的绝对路径

    存入img_path变量中,复制后的路径需要再加一个双斜线进行转义。

    2. 读取对应路径的图片

    使用Image中的open方法

    可以看到右边出现了img变量的相关属性 

    如size值即为图片的大小,在控制台中可以对应输出

     

    3. 显示图片:show方法

     调用该方法后可以对应弹出显示图片的窗口 

    4. 获取图片信息列表

    • 引入os库 :import os
    • 获取文件夹相对路径:  dir_path="dataset/train/ants"
    • 获取图片列表:os.listdir函数,可以获取对应文件夹下的所有图片名称的列表

    如图为img_path_list对象,可以看到集合了ants文件夹下所有图片的名称,共124张图片,因此列表大小为124 

     如果访问img_path_list列表的元素,如第一个元素,下标为0,则可以输出第一章图片的名称

    三、完善MyData类

    3.1  初始化方法中需要的参数和方法

    • root_dir:根文件路径,root_dir="dataset/train"
    • label_dir:图片的标签,由于标签名就是文件夹名,因此起名为label_dir,label_dir="ants"
    • os.path.join(x,y)方法 :可以把x和y对应的字符串拼接起来,就可以通过地址拼接访问到想要访问的图片,效果如图所示。
    • os.listdir(path)方法:把对应path下的图片生成图片名称列表

    3.2 初始化init方法的书写

     获取到文件根目录和标签目录后,使用join方法进行地址的拼接,获取到对应图片文件夹的地址,然后使用listdir方法获取到该地址的图片列表

    1. # 重写函数的初始化方法
    2. def __init__(self,root_dir,label_dir):
    3. # 初始化
    4. self.root_dir=root_dir
    5. self.label_dir=label_dir
    6. # 获取图片文件夹的路径
    7. self.path=os.path.join(self.root_dir,self.label_dir)
    8. # 获取对应图片路径的图片名称列表
    9. self.img_list=os.listdir(self.path)

    3.3 getitem方法的书写

    作用:获取到图像列表中单个图片的对象以及其标签

    idx:对应图片的索引值

    使用拼接法:文件夹路径+图片名 可以获取到具体某一张图片的地址 

    open方法生成对应图片对象

    python基础:如果有多个返回值,默认以元组形式打包,因此geitem方法返回的是(img,label)的元组 

    1. # 重写类的getitem方法
    2. def __getitem__(self, idx):
    3. # 获取单个图片名称
    4. img_name=self.img_list[idx]
    5. # 获取单个图片路径,使用拼接法
    6. img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
    7. # 生成对应图片对象
    8. img = Image.open(img_item_path)
    9. # 对应标签
    10. label = self.label_dir
    11. # 返回图像和标签,以元组格式返回
    12. return img,label

    3.4 生成实例 

    1. root_dir="dataset/train"
    2. label_dir="ants"
    3. #实例化MyData类
    4. ants_datasets=MyData(root_dir,label_dir)

    在控制台中进行测试,可以看到生成的ants_datasets对象中有了我们在上面初始化方法中进行定义的所有属性,如list,path等等

     ants_datasets数据集的第一项即为第一张图片对象以及其label标签

    img,label=ants_datasets[1],使用img和label接受元组中的img和label,可以看到变量中img和lable有了对应的具体值

    3.4 两个数据集的生成与相加操作

    1. 生成蚂蚁和蜜蜂数据集 

    1. root_dir="dataset/train"
    2. ants_label_dir="ants"
    3. bees_label_dir="bees"
    4. # 生成MyData类的实例对象
    5. ants_datasets=MyData(root_dir,ants_label_dir)
    6. bees_datasets=MyData(root_dir,bees_label_dir)

    2. 数据集相加

    可以看到相加后train_datasets的长度是两个数据集的和 

     

    四、完整代码

    1. from torch.utils.data import Dataset
    2. # 读取图片
    3. from PIL import Image
    4. # 关于系统的库
    5. import os
    6. class MyData(Dataset):
    7. # 重写函数的初始化方法
    8. def __init__(self,root_dir,label_dir):
    9. # 初始化
    10. self.root_dir=root_dir
    11. self.label_dir=label_dir
    12. # 获取图片文件夹的路径
    13. self.path=os.path.join(self.root_dir,self.label_dir)
    14. # 获取对应图片路径的图片名称列表
    15. self.img_list=os.listdir(self.path)
    16. # 重写类的getitem方法
    17. def __getitem__(self, idx):
    18. # 获取单个图片名称
    19. img_name=self.img_list[idx]
    20. # 获取单个图片路径,使用拼接法
    21. img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
    22. # 生成对应图片对象
    23. img = Image.open(img_item_path)
    24. # 对应标签
    25. label = self.label_dir
    26. # 返回图像和标签,元组
    27. return img,label
    28. def __len__(self):
    29. return len(self.img_list)
    30. root_dir="dataset/train"
    31. ants_label_dir="ants"
    32. bees_label_dir="bees"
    33. # 生成MyData类的实例对象
    34. ants_datasets=MyData(root_dir,ants_label_dir)
    35. bees_datasets=MyData(root_dir,bees_label_dir)
    36. # 两个数据集相加
    37. train_datasets=ants_datasets+bees_datasets

    五、使用修改后数据集的代码练习

    修改后数据集结构如下图所示,图像和标签各有一个文件夹进行存储

     标签文件夹下是各个图像的标签,为txt文件,文件名与图像名相同,并且文件内容仅有一行,即为标签内容ants

     因此获取标签时需要使用file读取文件形式

    1. from torch.utils.data import Dataset
    2. from PIL import Image
    3. import os
    4. class MyDataset(Dataset):
    5. def __init__(self,root_dir,img_dir,label_dir):
    6. # 根文件路径
    7. self.root_dir=root_dir
    8. # 图片文件路径
    9. self.img_dir=img_dir
    10. #标签文件夹路径
    11. self.label_dir=label_dir
    12. # 获取图片文件夹路径并生成图片名称的列表
    13. self.img_path=os.path.join(self.root_dir,self.img_dir)
    14. self.img_list=os.listdir(self.img_path)
    15. #获取标签文件夹路径并生成标签名称的列表
    16. self.label_path=os.path.join(self.root_dir,self.label_dir)
    17. self.label_list=os.listdir(self.label_path)
    18. def __getitem__(self, item):
    19. img_name=self.img_list[item]
    20. img_item_path=os.path.join(self.img_path,img_name)
    21. # 读取对应路径的图片内容,生成图片对象,存储在img中
    22. img=Image.open(img_item_path)
    23. label_name=self.label_list[item]
    24. label_item_path=os.path.join(self.label_path,label_name)
    25. # 打开对应路径的txt文件,读取对应内容,存储在label中
    26. file1 = open(label_item_path,"r")
    27. label= file1.readline()
    28. return img,label
    29. def __len__(self):
    30. return len(self.img_list)
    31. root_dir="datasets2/train"
    32. ants_img_dir="ants_image"
    33. ants_label_dir="ants_label"
    34. bees_img_dir="bees_image"
    35. bees_label_dir="bees_label"
    36. ants_datasets=MyDataset(root_dir,ants_img_dir,ants_label_dir)
    37. bees_datasets=MyDataset(root_dir,bees_img_dir,bees_label_dir)

  • 相关阅读:
    【2022-8-25奇安信算法笔试】偏机器相关
    npm设置淘宝镜像地址
    django+drf+vue 简单系统搭建 (1) - django创建项目
    winfrom .net 6使用EF Core,使用的是Code First代码先行
    CSP-J 2023 第二轮认证入门级(不含答案)
    Yocto Project 编译imx-第1节(下载和编译)
    spark sql保存hive表时的压缩设置
    linux 安装mysql
    pytest自动化框架运行全局配置文件pytest.ini
    三维重建---第一章 摄像机几何
  • 原文地址:https://blog.csdn.net/weixin_45662399/article/details/127386185