• 关于pytorch里DataLoader的理解


    目录

    一、python迭代器生成器基础讲解

    1.1可迭代对象Iterable

    1.2迭代器Iterator

    1.3for in 的本质流程

    1.4 getitem

    1.5 yield 生成器

    二、DataLoader的基础实现

    三、整体框架的讲解


    一、python迭代器生成器基础讲解

    1.1可迭代对象Iterable

    表示该对象可迭代,并不一定是一个数据类型,如字典,字符串,列表等,它也可以是一个实现了__iter__方法的类。

    1. from collections.abc import Iterable, Iterator
    2. class A(object):
    3. def __init__(self):
    4. self.a = [1, 2, 3]
    5. def __iter__(self):
    6. # 此处返回啥无所谓
    7. return self.a
    8. cls_a = A()
    9. # True
    10. print(isinstance(cls_a, Iterable))

    如果对象是Iterable,依然无法用for循环遍历,因为Iterable仅仅是提供了一种抽象规范接口。

    1.2迭代器Iterator

    如果一个对象是迭代器,那么它肯定是可迭代的,但是如果一个对象是可迭代的,它不一定是迭代器。实现了 __next__ 和 __iter__ 方法的类才能称为迭代器,就可以被 for 遍历了。

    1. class A(object):
    2. def __init__(self):
    3. self.index = -1
    4. self.a = [1, 2, 3]
    5. # 必须要返回一个实现了 __next__ 方法的对象,否则后面无法 for 遍历
    6. # 因为本类自身实现了 __next__,所以通常都是返回 self 对象即可
    7. def __iter__(self):
    8. return self
    9. def __next__(self):
    10. self.index += 1
    11. if self.index < len(self.a):
    12. return self.a[self.index]
    13. else:
    14. # 抛异常,for 内部会自动捕获,表示迭代完成
    15. raise StopIteration("遍历完了")
    16. cls_a = A()
    17. print(isinstance(cls_a, Iterable)) # True
    18. print(isinstance(cls_a, Iterator)) # True
    19. print(isinstance(iter(cls_a), Iterator)) # True
    20. for a in cls_a:
    21. print(a)
    22. # 打印 1 2 3

    1.3for in 的本质流程

    for.....in...被python编译器编译后,如下

    1. # 实际调用了 __iter__ 方法返回自身,包括了 __next__ 方法的对象
    2. cls_a = iter(cls_a)
    3. while True:
    4. try:
    5. # 然后调用对象的 __next__ 方法,不断返回元素
    6. value = next(cls_a)
    7. print(value)
    8. # 如果迭代完成,则捕获异常即可
    9. except StopIteration:
    10. break

    可见,任何一个对象要能被for遍历,必须实现__iter__和__next__两个方法。

    list是可迭代对象,但是没next方法,为什么可以实现for循环遍历。list内部的iter方法的内部实现了next方法。

    所以得到:一个对象要能够被 for .. in .. 迭代,那么不管你是直接实现 __iter__ 和 __next__ 方法(对象必然是 Iterator),还是只实现 __iter__(不是 Iterator),但是内部间接返回了具备 __next__ 对象的类,都是可行的

    1.4 getitem

    上面说过for in本质就是调用__iter__和__next__方法,实际上还有一种更简单的方法,__getitem__方法就可以让对象实现迭代功能。实际上任何一个类,只要实现了__getitem__方法,那么当调用iter(类实例)时候会自动具备__iter__和__next__方法。__getitem__ 实际上是属于 iternext方法的高级封装,也就是我们常说的语法糖,只不过这个转化是通过编译器完成,内部自动转化,非常方便。

    1. class A(object):
    2. def __init__(self):
    3. self.a = [1, 2, 3]
    4. def __getitem__(self, item):
    5. return self.a[item]
    6. cls_a = A()
    7. print(isinstance(cls_a, Iterable)) # False
    8. print(isinstance(cls_a, Iterator)) # False
    9. print(dir(cls_a)) # 仅仅具备 __getitem__ 方法
    10. cls_a = iter(cls_a)
    11. print(dir(cls_a)) # 具备 __iter__ 和 __next__ 方法
    12. print(isinstance(cls_a, Iterable)) # True
    13. print(isinstance(cls_a, Iterator)) # True
    14. # 等价于 for .. in ..
    15. while True:
    16. try:
    17. # 然后调用对象的 __next__ 方法,不断返回元素
    18. value = next(cls_a)
    19. print(value)
    20. # 如果迭代完成,则捕获异常即可
    21. except StopIteration:
    22. break
    23. # 输出: 1 2 3

    如果你想该对象具备 list 等对象一样的长度属性,则只需要实现 __len__ 方法即可。

    此时我们已经知道了第一种高级语法糖实现迭代器功能,下面分析另一个更简单的可以直接作用于函数的语法糖。

    1.5 yield 生成器

    生成器是一个在行为上和迭代器非常类似的对象,两者功能差不多,但生成器更优雅,只需要用关键字yield来返回。作用于函数上叫生成器函数,调用函数返回一个生成器。

    1. def func():
    2. for a in [1, 2, 3]:
    3. yield a
    4. cls_g = func()
    5. print(isinstance(cls_g, Iterator)) # True
    6. print(dir(cls_g)) # 自动具备 __iter__ 和 __next__ 方法
    7. for a in cls_g:
    8. print(a)
    9. # 输出: 1 2 3
    10. # 一种更简单的写法是用 ()
    11. cls_g = (i for i in [1,2,3])

    使用 yield 函数与使用 return 函数,在执行时差别在于:包含 yield 的方法一般用于迭代,每次执行时遇到 yield 就返回 yield 后的结果,但内部会保留上次执行的状态,下次继续迭代时,会继续执行 yield 之后的代码,直到再次遇到 yield 后返回。生成器是懒加载模式,特别适合解决内存占用大的集合问题。

    总结:在迭代对象基础上,如果实现了 __next__ 方法则是迭代器对象,该对象在调用 next() 的时             候返回下一个值,如果容器中没有更多元素了,则抛出 StopIteration 异常。

               对于采用语法糖 __getitem__ 实现的迭代器对象,其本身实例既不是可迭代对象,更不是               迭代器,但是其可以被 for in 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成             迭代器。

              生成器是一种特殊迭代器,但是不需要像迭代器一样实现__iter____next__方法,只需要            使用关键字 yield 就可以,生成器的构造可以通过生成器表达式 (),或者对函数返回值加入            yield 关键字实现。

              对于在类的 __iter__ 方法中采用语法糖 yield 实现的迭代器对象,其本身实例是可迭代对              象,但不是迭代器,但是其可以被 for .. in .. 迭代,原因是对该对象采用 iter(类实例) 操作后            就会自动变成迭代器。

    二、DataLoader的基础实现

    首先介绍5个基本的对象:

    Dataset提供整个数据集的随机访问功能,每次访问都返回单个对象,例如一个对象和一个target。

    Sampler提供整个数据集随机访问的索引列表,每次调用都返回所有列表中的单个索引。常用的子类是SequentialSampler 用于提供顺序输出的索引 和 RandomSampler 用于提供随机输出的索引

    BatchSampler内部调用Sampler实列,输出指定batch_size个索引,然后将索引作用于Dataset上从而输出batch_size个数据对象,例如batch_size个数据和索引。

    Collate_fn用于将batch_size个数据对象在batch维度进行聚合,生成(batch,.....)格式的数据输出。如果待聚合对象是numpy,则自动转化为tensor,此时就可以输入到网络中了。

    迭代一次伪代码如下(非迭代器版本)

    1. class DataLoader(object):
    2. def __init__(self):
    3. #假设数据长度为100,batch_size是4
    4. self.dataset=[[img0,target0],[img1,target1],.....[img99,target99]]
    5. self.sampler=[0,1,2,.....,99]
    6. self.batch_size=4
    7. self.index=0
    8. def collate_fn(self,data):
    9. #在batch维度聚合数据
    10. batch_img=torch.Stack(data[0],0)
    11. batch_target=torch.stack(data[1],0)
    12. return batch_img,batch_target
    13. def __next__(self):
    14. i=0
    15. batch_index=[]
    16. while i
    17. #内部会调用sampler对象获取单个索引
    18. batch_index.append(self.sampler[self.index])
    19. self.index+=1
    20. i+=1
    21. #得到batch_size个索引之后,调用dataset对象
    22. data=[self.dataset[idx] for idx in batch_index]
    23. #调用collate_fn 在batch维度进行拼接输出
    24. batch_data=self.collate_fn(data)
    25. return batch_data
    26. def __iter__(self):
    27. return self
    28. # torch.stack()是指将列表里面的张量进行扩维拼接
    29. # data=[[torch.Tensor([1]),torch.Tensor([1])],[torch.Tensor([1]),torch.Tensor([1])]]
    30. # print(torch.stack(data[0],0),torch.stack(data[1],0))
    31. # data=[torch.Tensor([1,2,3]),torch.Tensor([4,5,6])]
    32. # print(torch.stack(data))

    以上就是最抽象的 DataLoader 运行流程以及和 Dataset、Sampler、BatchSampler、collate_fn 的关系。

    首先需要强调的是 Dataset、Sampler、BatchSampler 和 DataLoader 都直接或间接实现了迭代器。

    Dataset通过__getitem__方法使其可迭代

    Sample对象是一个可迭代的基类对象,其常用子类 SequentialSampler 在 __iter__ 内部返回迭代器,RandomSampler 在 __iter__ 内部通过 yield 关键字返回迭代器

    Batchsampler也是在__iter__内部通过yield关键字返回迭代器

    DataLoader通过__iter__和__next__直接实现迭代器

    除了DataLoader本身是迭代器外,其余对象本身都不是迭代器,但可以for in迭代

    由于 DataLoader 类写的非常通用,故 Dataset、Sampler、BatchSampler 都可以外部传入,除了 Dataset 必须输入外,其余两个类都有默认实现,最典型的 Sampler 就是 SequentialSampler 和 RandomSampler。

    需要注意的是 Sampler 对象其实在大部分时候都不需要传入 Dataset 实例对象,因为其功能仅仅是返回索引而已,并没有直接接触数据。

    三、整体框架的讲解

    核心运行逻辑:

    1. def __next__(self):
    2. #返回batch个索引
    3. index=next(self.batch_sampler)
    4. #利用索引去取数据
    5. data=[self.dataset[idx] for idx in index]
    6. #batch维度聚合
    7. data=self.collate_fn(data)
    8. return data

    整体流程:

    1.self.batch_sampler=iter(batch_sampler)。在DataLoader的类初始化,需要得到BatchSampler的迭代器对象。

    2.index=next(self.batch_sampler)。对于每次迭代,DataLoader对象首先会调用BatchSampler的迭代器进行下一次迭代,具体是调用BatchSampler对象的__iter__方法

    3.而BatchSampler对象的__iter__方法实际上是需要依靠Sampler对象进行迭代输出索引,Sampler对象也是一个迭代器,当迭代batch_size次后就可以得到batch_size个数据索引。

    4.data=[self.dataset[idx] for idx in index]。有了batch个索引就可以通过不断调用dataset的__getitem__方法返回数据对象,此时data就包含了batch个对象。

    5.data=self.collate_fn(data)。将batch个对象输入给聚合函数,在第0个维度也就是batch维度进行聚合,得到类似(batch,....)的对象。

    6.重复上面的操作,就可以不断输出一个一个的batch数据

    1. class Dataset(object):
    2. #只要实现了__getitem__方法就可以变成迭代器
    3. def __getitem__(self,index):
    4. raise NotImplementedError
    5. def __len__(self):
    6. raise NotImplementedError
    1. class Sampler(object):
    2. def __init__(self,data_source):
    3. pass
    4. def __iter__(self):
    5. raise NotImplementedError
    6. def __len__(self):
    7. raise NotImplementedError
    8. #一般出现raise NotImplementedError这个错误,就是子类没有重写父类中的成员函数,然后子类对象调用此函数会报这个错误
    9. class SequentialSampler(sampler):
    10. def __init__(self,data_source):
    11. super(SequentialSampler,self).__init__(data_source)
    12. self.data_source=data_source
    13. def __iter__(self):
    14. #返回迭代器,不然无法for in
    15. return iter(range(len(self.data_source))
    16. def __len__(self):
    17. return len(self.data_source)
    18. class BatchSampler(Sampler):
    19. def __init__(self,sampler,batch_size,drop_last):
    20. self.sampler=sampler
    21. self.batch_size=batch_size
    22. self.dorp_last=drop_last
    23. def __iter__(self):
    24. batch=[]
    25. for idx in self.sampler:
    26. batch.append(idx)
    27. #如果得到了batch个索引,则可以通过yield关键字生成生成器返回,得到迭代器对象
    28. if len(batch)==self.batch_size:
    29. yield batch
    30. batch=[]
    31. if len(batch)>0 and not self.drop_last:
    32. yield batch
    33. def __len__(self):
    34. if self.drop_last:
    35. #如果最后的索引数不等于一个batch,抛弃
    36. return len(self.sampler)//self.batch_size
    37. else:
    38. return (len(self.sampler)+self.batch_size-1)//self.batch_size
    1. class DataLoader(object):
    2. def __init__(self,dataset,batch_size=1,shuffle=False,sample=None,batch_sampler=None,
    3. collate_fn=None,drop_last=False):
    4. self.dataset=dataset
    5. #因为这两个功能是冲突的
    6. if sampler is not None and shuffle:
    7. raise ValueError('sampler option is ..')
    8. if batch_sampler is not None:
    9. # 一旦设置了 batch_sampler,那么 batch_size、shuffle、sampler
    10. # 和 drop_last 四个参数就不能传入
    11. # 因为这4个参数功能和 batch_sampler 功能冲突了
    12. if batch_size != 1 or shuffle or sampler is not None or drop_last:
    13. raise ValueError('batch_sampler option is mutually exclusive '
    14. 'with batch_size, shuffle, sampler, and '
    15. 'drop_last')
    16. batch_size = None
    17. drop_last = False
    18. if sampler is None:
    19. if shuffle:
    20. sampler = RandomSampler(dataset)
    21. else:
    22. sampler = SequentialSampler(dataset)
    23. # 也就是说 batch_sampler 必须要存在,你如果没有设置,那么采用默认类
    24. if batch_sampler is None:
    25. batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    26. self.batch_size = batch_size
    27. self.drop_last = drop_last
    28. self.sampler = sampler
    29. self.batch_sampler = iter(batch_sampler)
    30. if collate_fn is None:
    31. collate_fn = default_collate
    32. self.collate_fn = collate_fn
    33. #核心代码
    34. def __next__(self):
    35. index=next(self.batch_sampler)
    36. data=[self.dataset[idx] for idx in index]
    37. data=self.collate_fn(data)
    38. return data
    39. #返回自身,因为自身实现了next
    40. def __iter__(self):
    41. return self
    1. def default_collate(batch):
    2. elem=batch[0]
    3. elem_type=type(elem)
    4. if isinstance(elem,torch.Tensor):
    5. return torch.stack(batch,0)
    6. elif elem_type.__module__=='numpy':
    7. return default_collate([torch.as_tensor(b) for b in batch])
    8. else:
    9. raise NotImplementedError

    完整调用例子

    1. class Simplev1Dataset(Dataset):
    2. def __init__(self):
    3. #伪造数据
    4. self.imgs=np.arange(0,16).reshape(8,2)
    5. def __getitem__(self,index):
    6. return self.imgs[index]
    7. def __len__(self):
    8. return self.imgs.shape[0]
    9. from simplev1_dataset import Simplev1Dataset
    10. simple_dataset=Simplev1Dataset()
    11. dataloader=DataLoader(simple_dataset,batch_size=2,collate_fn=default_collate)
    12. for data in dataloader:
    13. print(data)

    四、Reference

    https://zhuanlan.zhihu.com/p/340465632

  • 相关阅读:
    物联网AI MicroPython传感器学习 之 WS2812 RGB点阵灯环
    深度学习之激活函数——Leaky ReLU
    基于若依系统开发一套自己的移动端的框架
    springboot整合
    Android : 页面之间的数据传递 intent+bundle
    java计算机毕业设计企业信息安全评价系统源码+系统+mysql数据库+lw文档
    MySQL数据库同时查询更新同一张表的方法
    线程和线程池
    【poi导出excel模板——通过建造者模式+策略模式+函数式接口实现】
    QML控件类型:TabBar
  • 原文地址:https://blog.csdn.net/slamer111/article/details/127927782