目录
表示该对象可迭代,并不一定是一个数据类型,如字典,字符串,列表等,它也可以是一个实现了__iter__方法的类。
- from collections.abc import Iterable, Iterator
-
- class A(object):
- def __init__(self):
- self.a = [1, 2, 3]
-
- def __iter__(self):
- # 此处返回啥无所谓
- return self.a
-
- cls_a = A()
- # True
- print(isinstance(cls_a, Iterable))
如果对象是Iterable,依然无法用for循环遍历,因为Iterable仅仅是提供了一种抽象规范接口。
如果一个对象是迭代器,那么它肯定是可迭代的,但是如果一个对象是可迭代的,它不一定是迭代器。实现了 __next__ 和 __iter__ 方法的类才能称为迭代器,就可以被 for 遍历了。
- class A(object):
- def __init__(self):
- self.index = -1
- self.a = [1, 2, 3]
-
- # 必须要返回一个实现了 __next__ 方法的对象,否则后面无法 for 遍历
- # 因为本类自身实现了 __next__,所以通常都是返回 self 对象即可
- def __iter__(self):
- return self
-
- def __next__(self):
- self.index += 1
- if self.index < len(self.a):
- return self.a[self.index]
- else:
- # 抛异常,for 内部会自动捕获,表示迭代完成
- raise StopIteration("遍历完了")
-
- cls_a = A()
- print(isinstance(cls_a, Iterable)) # True
- print(isinstance(cls_a, Iterator)) # True
- print(isinstance(iter(cls_a), Iterator)) # True
-
- for a in cls_a:
- print(a)
- # 打印 1 2 3
for.....in...被python编译器编译后,如下
- # 实际调用了 __iter__ 方法返回自身,包括了 __next__ 方法的对象
- cls_a = iter(cls_a)
- while True:
- try:
- # 然后调用对象的 __next__ 方法,不断返回元素
- value = next(cls_a)
- print(value)
- # 如果迭代完成,则捕获异常即可
- except StopIteration:
- break
可见,任何一个对象要能被for遍历,必须实现__iter__和__next__两个方法。
list是可迭代对象,但是没next方法,为什么可以实现for循环遍历。list内部的iter方法的内部实现了next方法。
所以得到:一个对象要能够被 for .. in .. 迭代,那么不管你是直接实现 __iter__ 和 __next__ 方法(对象必然是 Iterator),还是只实现 __iter__(不是 Iterator),但是内部间接返回了具备 __next__ 对象的类,都是可行的。
上面说过for in本质就是调用__iter__和__next__方法,实际上还有一种更简单的方法,__getitem__方法就可以让对象实现迭代功能。实际上任何一个类,只要实现了__getitem__方法,那么当调用iter(类实例)时候会自动具备__iter__和__next__方法。__getitem__ 实际上是属于 iter和next方法的高级封装,也就是我们常说的语法糖,只不过这个转化是通过编译器完成,内部自动转化,非常方便。
- class A(object):
- def __init__(self):
- self.a = [1, 2, 3]
-
- def __getitem__(self, item):
- return self.a[item]
-
- cls_a = A()
- print(isinstance(cls_a, Iterable)) # False
- print(isinstance(cls_a, Iterator)) # False
- print(dir(cls_a)) # 仅仅具备 __getitem__ 方法
-
- cls_a = iter(cls_a)
- print(dir(cls_a)) # 具备 __iter__ 和 __next__ 方法
-
- print(isinstance(cls_a, Iterable)) # True
- print(isinstance(cls_a, Iterator)) # True
-
- # 等价于 for .. in ..
- while True:
- try:
- # 然后调用对象的 __next__ 方法,不断返回元素
- value = next(cls_a)
- print(value)
- # 如果迭代完成,则捕获异常即可
- except StopIteration:
- break
-
- # 输出: 1 2 3
如果你想该对象具备 list 等对象一样的长度属性,则只需要实现 __len__ 方法即可。
此时我们已经知道了第一种高级语法糖实现迭代器功能,下面分析另一个更简单的可以直接作用于函数的语法糖。
生成器是一个在行为上和迭代器非常类似的对象,两者功能差不多,但生成器更优雅,只需要用关键字yield来返回。作用于函数上叫生成器函数,调用函数返回一个生成器。
- def func():
- for a in [1, 2, 3]:
- yield a
-
- cls_g = func()
- print(isinstance(cls_g, Iterator)) # True
- print(dir(cls_g)) # 自动具备 __iter__ 和 __next__ 方法
-
- for a in cls_g:
- print(a)
-
- # 输出: 1 2 3
-
- # 一种更简单的写法是用 ()
- cls_g = (i for i in [1,2,3])
使用 yield 函数与使用 return 函数,在执行时差别在于:包含 yield 的方法一般用于迭代,每次执行时遇到 yield 就返回 yield 后的结果,但内部会保留上次执行的状态,下次继续迭代时,会继续执行 yield 之后的代码,直到再次遇到 yield 后返回。生成器是懒加载模式,特别适合解决内存占用大的集合问题。
总结:在迭代对象基础上,如果实现了 __next__ 方法则是迭代器对象,该对象在调用 next() 的时 候返回下一个值,如果容器中没有更多元素了,则抛出 StopIteration 异常。
对于采用语法糖 __getitem__ 实现的迭代器对象,其本身实例既不是可迭代对象,更不是 迭代器,但是其可以被 for in 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成 迭代器。
生成器是一种特殊迭代器,但是不需要像迭代器一样实现__iter__和__next__方法,只需要 使用关键字 yield 就可以,生成器的构造可以通过生成器表达式 (),或者对函数返回值加入 yield 关键字实现。
对于在类的 __iter__ 方法中采用语法糖 yield 实现的迭代器对象,其本身实例是可迭代对 象,但不是迭代器,但是其可以被 for .. in .. 迭代,原因是对该对象采用 iter(类实例) 操作后 就会自动变成迭代器。
首先介绍5个基本的对象:
Dataset提供整个数据集的随机访问功能,每次访问都返回单个对象,例如一个对象和一个target。
Sampler提供整个数据集随机访问的索引列表,每次调用都返回所有列表中的单个索引。常用的子类是SequentialSampler 用于提供顺序输出的索引 和 RandomSampler 用于提供随机输出的索引
BatchSampler内部调用Sampler实列,输出指定batch_size个索引,然后将索引作用于Dataset上从而输出batch_size个数据对象,例如batch_size个数据和索引。
Collate_fn用于将batch_size个数据对象在batch维度进行聚合,生成(batch,.....)格式的数据输出。如果待聚合对象是numpy,则自动转化为tensor,此时就可以输入到网络中了。
迭代一次伪代码如下(非迭代器版本)
- class DataLoader(object):
- def __init__(self):
- #假设数据长度为100,batch_size是4
- self.dataset=[[img0,target0],[img1,target1],.....[img99,target99]]
- self.sampler=[0,1,2,.....,99]
- self.batch_size=4
- self.index=0
-
- def collate_fn(self,data):
- #在batch维度聚合数据
- batch_img=torch.Stack(data[0],0)
- batch_target=torch.stack(data[1],0)
- return batch_img,batch_target
-
- def __next__(self):
- i=0
- batch_index=[]
- while i
- #内部会调用sampler对象获取单个索引
- batch_index.append(self.sampler[self.index])
- self.index+=1
- i+=1
- #得到batch_size个索引之后,调用dataset对象
- data=[self.dataset[idx] for idx in batch_index]
- #调用collate_fn 在batch维度进行拼接输出
- batch_data=self.collate_fn(data)
- return batch_data
-
- def __iter__(self):
- return self
- # torch.stack()是指将列表里面的张量进行扩维拼接
- # data=[[torch.Tensor([1]),torch.Tensor([1])],[torch.Tensor([1]),torch.Tensor([1])]]
- # print(torch.stack(data[0],0),torch.stack(data[1],0))
- # data=[torch.Tensor([1,2,3]),torch.Tensor([4,5,6])]
- # print(torch.stack(data))
以上就是最抽象的 DataLoader 运行流程以及和 Dataset、Sampler、BatchSampler、collate_fn 的关系。
首先需要强调的是 Dataset、Sampler、BatchSampler 和 DataLoader 都直接或间接实现了迭代器。
Dataset通过__getitem__方法使其可迭代
Sample对象是一个可迭代的基类对象,其常用子类 SequentialSampler 在 __iter__ 内部返回迭代器,RandomSampler 在 __iter__ 内部通过 yield 关键字返回迭代器
Batchsampler也是在__iter__内部通过yield关键字返回迭代器
DataLoader通过__iter__和__next__直接实现迭代器
除了DataLoader本身是迭代器外,其余对象本身都不是迭代器,但可以for in迭代
由于 DataLoader 类写的非常通用,故 Dataset、Sampler、BatchSampler 都可以外部传入,除了 Dataset 必须输入外,其余两个类都有默认实现,最典型的 Sampler 就是 SequentialSampler 和 RandomSampler。
需要注意的是 Sampler 对象其实在大部分时候都不需要传入 Dataset 实例对象,因为其功能仅仅是返回索引而已,并没有直接接触数据。
三、整体框架的讲解
核心运行逻辑:
- def __next__(self):
- #返回batch个索引
- index=next(self.batch_sampler)
- #利用索引去取数据
- data=[self.dataset[idx] for idx in index]
- #batch维度聚合
- data=self.collate_fn(data)
- return data
整体流程:
1.self.batch_sampler=iter(batch_sampler)。在DataLoader的类初始化,需要得到BatchSampler的迭代器对象。
2.index=next(self.batch_sampler)。对于每次迭代,DataLoader对象首先会调用BatchSampler的迭代器进行下一次迭代,具体是调用BatchSampler对象的__iter__方法
3.而BatchSampler对象的__iter__方法实际上是需要依靠Sampler对象进行迭代输出索引,Sampler对象也是一个迭代器,当迭代batch_size次后就可以得到batch_size个数据索引。
4.data=[self.dataset[idx] for idx in index]。有了batch个索引就可以通过不断调用dataset的__getitem__方法返回数据对象,此时data就包含了batch个对象。
5.data=self.collate_fn(data)。将batch个对象输入给聚合函数,在第0个维度也就是batch维度进行聚合,得到类似(batch,....)的对象。
6.重复上面的操作,就可以不断输出一个一个的batch数据
- class Dataset(object):
- #只要实现了__getitem__方法就可以变成迭代器
- def __getitem__(self,index):
- raise NotImplementedError
- def __len__(self):
- raise NotImplementedError
- class Sampler(object):
- def __init__(self,data_source):
- pass
- def __iter__(self):
- raise NotImplementedError
- def __len__(self):
- raise NotImplementedError
- #一般出现raise NotImplementedError这个错误,就是子类没有重写父类中的成员函数,然后子类对象调用此函数会报这个错误
-
- class SequentialSampler(sampler):
- def __init__(self,data_source):
- super(SequentialSampler,self).__init__(data_source)
- self.data_source=data_source
- def __iter__(self):
- #返回迭代器,不然无法for in
- return iter(range(len(self.data_source))
- def __len__(self):
- return len(self.data_source)
-
- class BatchSampler(Sampler):
- def __init__(self,sampler,batch_size,drop_last):
- self.sampler=sampler
- self.batch_size=batch_size
- self.dorp_last=drop_last
-
- def __iter__(self):
- batch=[]
- for idx in self.sampler:
- batch.append(idx)
- #如果得到了batch个索引,则可以通过yield关键字生成生成器返回,得到迭代器对象
- if len(batch)==self.batch_size:
- yield batch
- batch=[]
- if len(batch)>0 and not self.drop_last:
- yield batch
- def __len__(self):
- if self.drop_last:
- #如果最后的索引数不等于一个batch,抛弃
- return len(self.sampler)//self.batch_size
- else:
- return (len(self.sampler)+self.batch_size-1)//self.batch_size
- class DataLoader(object):
- def __init__(self,dataset,batch_size=1,shuffle=False,sample=None,batch_sampler=None,
- collate_fn=None,drop_last=False):
- self.dataset=dataset
- #因为这两个功能是冲突的
- if sampler is not None and shuffle:
- raise ValueError('sampler option is ..')
- if batch_sampler is not None:
- # 一旦设置了 batch_sampler,那么 batch_size、shuffle、sampler
- # 和 drop_last 四个参数就不能传入
- # 因为这4个参数功能和 batch_sampler 功能冲突了
- if batch_size != 1 or shuffle or sampler is not None or drop_last:
- raise ValueError('batch_sampler option is mutually exclusive '
- 'with batch_size, shuffle, sampler, and '
- 'drop_last')
- batch_size = None
- drop_last = False
- if sampler is None:
- if shuffle:
- sampler = RandomSampler(dataset)
- else:
- sampler = SequentialSampler(dataset)
- # 也就是说 batch_sampler 必须要存在,你如果没有设置,那么采用默认类
- if batch_sampler is None:
- batch_sampler = BatchSampler(sampler, batch_size, drop_last)
-
- self.batch_size = batch_size
- self.drop_last = drop_last
- self.sampler = sampler
- self.batch_sampler = iter(batch_sampler)
-
- if collate_fn is None:
- collate_fn = default_collate
- self.collate_fn = collate_fn
-
- #核心代码
- def __next__(self):
- index=next(self.batch_sampler)
- data=[self.dataset[idx] for idx in index]
- data=self.collate_fn(data)
- return data
- #返回自身,因为自身实现了next
- def __iter__(self):
- return self
- def default_collate(batch):
- elem=batch[0]
- elem_type=type(elem)
- if isinstance(elem,torch.Tensor):
- return torch.stack(batch,0)
- elif elem_type.__module__=='numpy':
- return default_collate([torch.as_tensor(b) for b in batch])
- else:
- raise NotImplementedError
完整调用例子
- class Simplev1Dataset(Dataset):
- def __init__(self):
- #伪造数据
- self.imgs=np.arange(0,16).reshape(8,2)
-
- def __getitem__(self,index):
- return self.imgs[index]
-
- def __len__(self):
- return self.imgs.shape[0]
-
- from simplev1_dataset import Simplev1Dataset
- simple_dataset=Simplev1Dataset()
- dataloader=DataLoader(simple_dataset,batch_size=2,collate_fn=default_collate)
- for data in dataloader:
- print(data)
四、Reference