• pytorch中的Dataloader和dataset详细的collate_fn参数测试


    DataLoder的参数

    参考:https://blog.csdn.net/weixin_43794311/article/details/118091799

    使用简单介绍

    一、首先需要导入库,下面两种方式都行

    from torch.utils.data import *
    from torch.utils.data import DataLoader,Dataset
    
    • 1
    • 2

    二、先建立自己的Dataset子类

    class my_Dataset(Dataset):
    	def __init__(self, 想要传入的参数):
    		#参数一般是路径
    		#对属性的赋值,一般得到所有的数据路径
    
        def __len__(self):
            return len(self.img_paths)#返回数据加载的数量
    
        def __getitem__(self, index):  # 对每个图片进行处理
        	#对每个加载的内容进行处理,最后返回需要使用的内容
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    三、定义DataLoader中参数collate_fn
    这一步可以省略,但只能按照默认的格式输出,假设定义的DataSet的return中的返回两个对象

    def collate_fn(batch):
    	renew_out=[]
    	for item in batch:#对一个batchsize的数据进行循环遍历后,控制输出
    		el,el1 = item #将返回的两个对象进行重新处理格式
    		renew_out+=[el,el1]
    
    • 1
    • 2
    • 3
    • 4
    • 5

    四、使用DataLoader加载
    1、先实例化一个自己定义的Dataset对象,定义需要的数据
    2、使用DataLoader加载生成需要的数据,其中设置了加载的线程,是否打乱,每个批数量等

    data_set_object = my_Dataset(需要的参数)  # 先实例化一个
    data_loader = DataLoader(data_set_object,batch_size,num_work,collate_fn,shuffle)
    
    • 1
    • 2

    使用collate_fn和未使用自定义的不同

    一、未使用collate_fn时

    from imutils import paths
    from torch.utils.data import *
    import matplotlib.pyplot as plt
    import cv2
    import numpy as np
    
    def collate_fn(batch):
    	filenames=[]
    	heights = []
    	back_out = []
    	for filename,height in batch:
    		#print('file_name:',filename)
    		#print('height',height)
    		back_out+=[filename,height]
    		filenames.append(filename)
    	return back_out,filenames
    	
    class My_loader(Dataset):
        def __init__(self, img_dir):
    
            self.img_dir = img_dir
            self.img_paths = []
            self.img_paths += [el for el in paths.list_images(img_dir)]
    
        def __len__(self):
            return len(self.img_paths)
    
        def __getitem__(self, index):  # 对每个图片进行处理
            filename = self.img_paths[index]
            # Image = cv2.imread(filename) # 原始读取
            plt_img = plt.imread(filename)
            # 为了正确显示plt和cv2图片矩阵格式修改
            Image= cv2.cvtColor(plt_img,cv2.COLOR_BGR2RGB) if plt_img.ndim>2 else plt_img
    
            height, width, _ = Image.shape  # h w c
            return filename,height  #自己定义的DataSet的返回
    
    # # 若参数路径错误,可能出现
    # ValueError: num_samples should be a positive integer value, but got num_samples=0
    # 
    if __name__ == '__main__':
    	train_dataset = My_loader('my_test_imgs')
    	print(len(train_dataset))  
    	my_dataloder = DataLoader(train_dataset, batch_size = 2, shuffle=True, num_workers=2,)
    	print(len(my_dataloder))
    	for batch in my_dataloder:
    		img_dir,height = batch #这里的DataSet对象的返回
    		print(batch)
    		print("!"*40)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49

    在这里插入图片描述

    二、使用collate_fn函数后

    from imutils import paths
    from torch.utils.data import *
    import matplotlib.pyplot as plt
    import cv2
    
    import numpy as np
    
    def collate_fn(batch):
    	filenames=[]
    	heights = []
    	back_out = []
    	for filename,height in batch:
    		# print('file_name:',filename)
    		# print('height',height)
    		back_out+=[filename,height]  # 相加依然放在一个列表中,和后面的collate_fn比较
    		filenames.append(filename)
    	return back_out
    
    
    class My_loader(Dataset):
        def __init__(self, img_dir):
    
            self.img_dir = img_dir
            self.img_paths = []
            self.img_paths += [el for el in paths.list_images(img_dir)]
    
        def __len__(self):
            return len(self.img_paths)
    
        def __getitem__(self, index):  # 对每个图片进行处理
            filename = self.img_paths[index]
            # Image = cv2.imread(filename) # 原始读取
            plt_img = plt.imread(filename)
            # 为了正确显示plt和cv2图片矩阵格式修改
            Image= cv2.cvtColor(plt_img,cv2.COLOR_BGR2RGB) if plt_img.ndim>2 else plt_img
    
            height, width, _ = Image.shape  # h w c
            return filename,height
    
    # # 若参数路径错误,可能出现
    # ValueError: num_samples should be a positive integer value, but got num_samples=0
    # 
    if __name__ == '__main__':
    	train_dataset = My_loader('my_test_imgs')
    	print(len(train_dataset))  
    	my_dataloder = DataLoader(train_dataset, batch_size = 2, shuffle=True, num_workers=2,collate_fn=collate_fn)
    	print(len(my_dataloder))
    	for batch in my_dataloder:
    		# img_dir,height = batch
    		print(batch)
    		print("!"*40)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51

    在这里插入图片描述
    如果修改collate_fn,结果就会改变,修改back_out.append([filename,height])

    def collate_fn(batch):
    	filenames=[]
    	heights = []
    	back_out = []
    	for filename,height in batch:
    		# print('file_name:',filename)
    		# print('height',height)
    		back_out.append([filename,height]) # 将一个列表作为最小的单位扩展放入空表中
    	return back_out
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    在这里插入图片描述

    DataLoader中运行情况

    1.先执行main中的print
    2.轮流执行DataSet和collate_fn函数中的print,执行DataSet的次数是batchsize的次数
    
    • 1
    • 2
    from imutils import paths
    from torch.utils.data import *
    import matplotlib.pyplot as plt
    import cv2
    
    import numpy as np
    
    def collate_fn(batch):
    	filenames=[]
    	heights = []
    	back_out = []
    	for filename,height in batch:
    		# print('file_name:',filename)
    		# print('height',height)
    		back_out.append([filename,height])
    	print('这是collatet_fn中的运行')
    	return back_out
    
    class My_loader(Dataset):
        def __init__(self, img_dir):
    
            self.img_dir = img_dir
            self.img_paths = []
            self.img_paths += [el for el in paths.list_images(img_dir)]
        def __len__(self):
            return len(self.img_paths)
    
        def __getitem__(self, index):  # 对每个图片进行处理
            filename = self.img_paths[index]
            # Image = cv2.imread(filename) # 原始读取
            plt_img = plt.imread(filename)
            # 为了正确显示plt和cv2图片矩阵格式修改
            Image= cv2.cvtColor(plt_img,cv2.COLOR_BGR2RGB) if plt_img.ndim>2 else plt_img
    
            height, width, _ = Image.shape  # h w c
            print("这是DataSet中的内容")
            return filename,height
    
    # # 若参数路径错误,可能出现
    # ValueError: num_samples should be a positive integer value, but got num_samples=0
    # 
    if __name__ == '__main__':
    	train_dataset = My_loader('my_test_imgs')
    	print(len(train_dataset))  
    	my_dataloder = DataLoader(train_dataset, batch_size = 2, shuffle=True, num_workers=2,collate_fn=collate_fn)
    	print(len(my_dataloder))
    	for batch in my_dataloder:
    		# img_dir,height = batch
    		print(batch)
    		print("这是main中")
    		print("!"*40)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51

    结果分析:先执行了main中,然后自定义的DataSet中__getitem__()和调整输出函数collate_fn中的内容交替执行。
    在这里插入图片描述

    注意:自定义的DataSet中的__init__()

    这个函数只在实例化的时候执行一次

    在对DataLoader的对象进行循环访问时出现问题

    发现问题的过程,只是将程序复制一份后出现下面问题;又对旧文件及环境进行了测试,依然正常;
    最后解决:发现唯一不同的是pycharm中的解释器环境版本不同,修改了新的环境和旧的一样后正常显示
    分析可能的原因:问题中是因为循环Dataloader中导致的线程出现问题,旧和新的torch是同一个版本,但导致不同结果,可能是其他库版本的问题

    RuntimeError: Caught RuntimeError in DataLoader worker process 0
    RuntimeError: Could not infer dtype of numpy.float32
    
    • 1
    • 2
  • 相关阅读:
    51.MongoDB聚合操作与索引使用详解
    C语言二维数组编程练习集
    2023年Java核心技术大会(Core Java Week 2023)-核心PPT资料下载
    新浪财经行情中心的对象 Market_Center
    vue3表单参数校验+正则表达式
    JavaScript实现字体大小调整
    Elasticsearch7.15.2 安装ik中文分词器后启动ES服务报错的解决办法
    9、鸿蒙应用桌面图标外观和国际化
    @PostConstruct详解
    Machine Learning学习(一)Overview of machine learning机器学习概述
  • 原文地址:https://blog.csdn.net/weixin_43794311/article/details/122079125