• PyTorch 源码解读之 torch.utils.data:解析数据处理全流程(非常好,一篇足够)



    本篇博文主要用来记录参考链接中的所学重要知识,梳理清楚。

    1 Dataset

    Dataset 负责对 raw data source 封装,将其封装成 Python 可识别的数据结构,其必须提供提取数据个体的接口。

    • Map-style sataset
      torch.utils.data.Dataset
      它是一种通过实现 getitem() 和 len() 来获取数据的 Dataset,它表示从(可能是非整数)索引/关键字到数据样本的映射。访问时,这样的数据集用 dataset[idx] 访问 idx 对应的数据。
    • Iterable-style dataset
    • 其他Dataset
      torch.utils.data.TensorDataset: 用于获取封装成 tensor 的数据集,每一个样本都通过索引张量来获得。

    2 Sampler

    torch.utils.data.Sampler 负责提供一种遍历数据集所有元素索引的方式。
    特别地,len() 方法不是必要的,但是当 DataLoader 需要计算 len() 的时候必须定义,这点在其源码中也有注释加以体现。

    同样,PyTorch 也在此基础上提供了其他类型的 Sampler 子类

    torch.utils.data.SequentialSampler : 顺序采样样本,始终按照同一个顺序
    torch.utils.data.RandomSampler: 可指定有无放回地,进行随机采样样本元素
    torch.utils.data.BatchSampler: 在一个batch中封装一个其他的采样器, 返回一个 batch 大小的 index 索引

    3 DataLoader

    torch.utils.data.DataLoader 是 PyTorch 数据加载的核心,负责加载数据,同时支持 Map-style 和 Iterable-style Dataset,支持单进程/多进程,还可以设置 loading order, batch size, pin memory 等加载参数。其接口定义如下:

    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
               batch_sampler=None, num_workers=0, collate_fn=None,
               pin_memory=False, drop_last=False, timeout=0,
               worker_init_fn=None, *, prefetch_factor=2,
               persistent_workers=False)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    对于每个参数的含义,以下给出一个表格进行对应介绍:

    attributemeaningdefault valuetype
    dataset加载数据的数据集Dataset
    batch_size每个 batch 加载多少个样本1int
    shuffle设置为 True 时,调用 RandomSampler 进行随机索引Falsebool
    sampler定义从数据集中提取样本的策略。如果指定了, shuffle 参数必须为 False,(否则会和 RandomSampler 互斥)NoneSampler, Iterable
    batch_sampler和 sampler 类似,但是一般传入 BatchSampler,每次返回一个 batch 大小的索引。其和 batch_size, shuffle 等参数是互斥的NoneSampler, Iterable
    num_workers要用于数据加载的子进程数,0 表示将在主进程中加载数据0int
    collate_fn在将 Map-style dataset 取出的数据整合成 batch 时使用,合并样本列表以形成一个 batchNonecallable
    pin_memory如果为 True,则 DataLoader 在将张量返回之前将其复制到 CUDA 固定的内存中Falsebool
    drop_last设置为 True 删除最后一个不完整的批次,如果该数据集大小不能被该批次大小整除。如果 False 并且数据集的大小不能被批次大小整除,那么最后一批将较小Falsebool
    timeout如果为正,则为从 worker 收集 batch 的超时值,应始终为非负数。超过这个时间还没读取到数据的话就会报错0numeric
    worker_init_fn如果不为 None,它将会被每个 worker 子进程调用,以 worker id ([0, num_workers - 1] 内的整形) 为输入Nonecallable
    prefetch_facto每个 worker 提前加载 的 sample 数量2int
    persistent_workers如果为 True,dataloader 将不会终止 worker 进程,直到 dataset 迭代完成Falsebool

    3.1 三者关系 (Dataset, Sampler, Dataloader)

    • 设置 Dataset,将数据 data source 包装成 Dataset 类,暴露提取接口。
    • 设置 Sampler,决定采样方式。我们是能从 Dataset 中提取元素了,还是需要设置 Sampler 告诉程序提取 Dataset 的策略。
    • 将设置好的 Dataset 和 Sampler 传入 DataLoader,同时可以设置 shuffle, batch_size 等参数。使用 DataLoader 对象可以方便快捷地在数据集上遍历。

    总结来说,即 Dataloader 负责总的调度,命令 Sampler 定义遍历索引的方式,然后用索引去 Dataset 中提取元素。于是就实现了对给定数据集的遍历。

    3.2 批处理

    3.2.1 自动批处理(默认)

    DataLoader 支持通过参数batch_size, drop_last, batch_sampler,自动地把取出的数据整理 (collate) 成批次样本 (batch)

    batch_size 和 drop_last 参数用于指定 DataLoader 如何获取 dataset 的 key。特别地,对于 map-style 类型的 dataset,用户可以选择指定 batch_sample参数,一次就生成一个 keys list

    在使用 sampler 产生的 indices 获取采样到的数据时,DataLoader 使用 collate_fn 参数将样本列表整理成 batch。抽象这个过程,其表示方式大致如下

    # For Map-style
    for indices in batch_sampler:
        yield collate_fn([dataset[i] for i in indices])
    
    # For Iterable-style
    dataset_iter = iter(dataset)
    for indices in batch_sampler:
        yield collate_fn([next(dataset_iter) for _ in indices])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    3.2.3 collate_fn

    当关闭自动批处理 (automatic batching) 时,collate_fn 作用于单个数据样本,只是在 PyTorch 张量中转换 NumPy 数组。

    当开启自动批处理 (automatic batching) 时,collate_fn 作用于数据样本列表,将输入样本整理为一个 batch,一般做下面 3 件事情

    • 添加新的批次维度(一般是第一维)
    • 它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量
    • 它保留数据结构,例如,如果每个样本都是 dict,则输出具有相同键集但批处理过的张量作为值的字典(或list,当不能转换的时候)。list, tuples, namedtuples 同样适用

    自定义 collate_fn 可用于自定义排序规则,例如,将顺序数据填充到批处理的最大长度,添加对自定义数据类型的支持等。

    3.3 多进程处理 (multi-process)

    为了避免在加载数据时阻塞计算代码,PyTorch 提供了一个简单的开关,只需将参数设置 num_workers 为正整数即可执行多进程数据加载,设置为 0 时执行单线程数据加载。

    4 预取 (prefetch)

    DataLoader 通过指定 prefetch_factor (默认为 2)来进行数据的预取。

    class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
        def __init__(self, loader):
            ...
            self._reset(loader, first_iter=True)
    
        def _reset(self, loader, first_iter=False):
            ...
            # prime the prefetch loop
            for _ in range(self._prefetch_factor * self._num_workers):
                self._try_put_index()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    通过源码可以看到,prefetch 功能仅适用于 多进程 加载中(下面会由多进程 dataloader 的代码分析)

    5 代码详解

    来看看具体的代码调用流程:

    for data, label in train_loader:
        ......
    
    • 1
    • 2

    for 循环会调用 dataloader 的 iter(self) 方法,以此获得迭代器来遍历 dataset

    class DataLoader(Generic[T_co]):
        ...
        def __iter__(self) -> '_BaseDataLoaderIter':
    
            if self.persistent_workers and self.num_workers > 0:
                if self._iterator is None:
                    self._iterator = self._get_iterator()
                else:
                    self._iterator._reset(self)
                return self._iterator
            else:
                return self._get_iterator()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    iter(self) 方法中,dataloader 调用了self._get_iterator() 方法,根据 num_worker 获得迭代器,并指示进行单进程还是多进程

    class DataLoader(Generic[T_co]):
        ...
        def _get_iterator(self) -> '_BaseDataLoaderIter':
            if self.num_workers == 0:
                return _SingleProcessDataLoaderIter(self)
            else:
                self.check_worker_number_rationality()
                return _MultiProcessingDataLoaderIter(self)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    为了描述清晰,我们只考虑单进程的代码。下面是 class _SingleProcessDataLoaderIter(_BaseDataLoaderIter) ,以及其父类 class _BaseDataLoaderIter(object): 的重点代码片段:

    class _BaseDataLoaderIter(object):
        def __init__(self, loader: DataLoader) -> None:
            # 初始化赋值一些 DataLoader 参数,
            # 以及用户输入合法性进行校验
            self._dataset = loader.dataset
            self._dataset_kind = loader._dataset_kind
            self._index_sampler = loader._index_sampler
            ...
    
        def __iter__(self) -> '_BaseDataLoaderIter':
            return self
    
        def _reset(self, loader, first_iter=False):
            self._sampler_iter = iter(self._index_sampler)
            self._num_yielded = 0
            self._IterableDataset_len_called = loader._IterableDataset_len_called
    
        def _next_index(self):
            return next(self._sampler_iter)  # may raise StopIteration
    
        def _next_data(self):
            raise NotImplementedError
    
        def __next__(self) -> Any:
            with torch.autograd.profiler.record_function(self._profile_name):
                if self._sampler_iter is None:
                    self._reset()
                data = self._next_data() # 重点代码行,通过此获取数据
                self._num_yielded += 1
                ...
                return data
    
        next = __next__  # Python 2 compatibility
    
        def __len__(self) -> int:
            return len(self._index_sampler) # len(_BaseDataLoaderIter) == len(self._index_sampler)
    
        def __getstate__(self):
            raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39

    _BaseDataLoaderIter 是所有 DataLoaderIter 的父类。dataloader获得了迭代器之后,for 循环需要调用 __next__() 来获得下一个对象,从而实现遍历。通过 __next__ 方法调用 _next_data() 获取数据

    class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
        def __init__(self, loader):
            super(_SingleProcessDataLoaderIter, self).__init__(loader)
            assert self._timeout == 0
            assert self._num_workers == 0
    
            self._dataset_fetcher = _DatasetKind.create_fetcher(
                self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
    
        def _next_data(self):
            index = self._next_index()  # may raise StopIteration
            data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
            if self._pin_memory:
                data = _utils.pin_memory.pin_memory(data)
            return data
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    从 _SingleProcessDataLoaderIter 的初始化参数可以看到,其在父类 _BaseDataLoaderIter 的基础上定义了 _dataset_fetcher, 并传入 _dataset, _auto_collation, _collate_fn 等参数,用于定义获取数据的方式。其具体实现会在稍后解释。

    在 _next_data() 被调用后,其需要 next_index() 获取 index,并通过获得的 index 传入 _dataset_fetcher 中获取对应样本

    class DataLoader(Generic[T_co]):
        ...
        @property
        def _auto_collation(self):
            return self.batch_sampler is not None
    
        @property
        def _index_sampler(self):
            if self._auto_collation:
                return self.batch_sampler
            else:
                return self.sampler
    
    class _BaseDataLoaderIter(object):
        ...
        def _reset(self, loader, first_iter=False):
            self._sampler_iter = iter(self._index_sampler)
            ...
    
        def _next_index(self):
            # sampler_iter 来自于 index_sampler
            return next(self._sampler_iter)  # may raise StopIteration
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    从这里看出,dataloader 提供了 sampler (可以是batch_sampler 或者是其他 sampler 子类),然后 _SingleProcessDataLoaderIter 迭代sampler获得索引

    下面我们来看看 fetcher,fetcher 需要 index 来获取元素,并同时支持 Map-style dataset(对应 _MapDatasetFetcher)和 Iterable-style dataset(对应 _IterableDatasetFetcher),使其在Dataloader内能使用相同的接口 fetch,代码更加简洁。

    • 对于 Map-style:直接输入索引 index,作为 map 的 key,获得对应的样本(即 value)
    class _MapDatasetFetcher(_BaseDatasetFetcher):
        def __init__(self, dataset, auto_collation, collate_fn, drop_last):
            super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
    
        def fetch(self, possibly_batched_index):
            if self.auto_collation:
                # 有batch_sampler,_auto_collation就为True,
                # 就优先使用batch_sampler,对应在fetcher中传入的就是一个batch的索引
                data = [self.dataset[idx] for idx in possibly_batched_index]
            else:
                data = self.dataset[possibly_batched_index]
            return self.collate_fn(data)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 对于 Iterable-style: init 方法内设置了 dataset 初始的迭代器,fetch 方法内获取元素,index 其实已经没有多大作用了
    class _IterableDatasetFetcher(_BaseDatasetFetcher):
        def __init__(self, dataset, auto_collation, collate_fn, drop_last):
            super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
            self.dataset_iter = iter(dataset)
    
        def fetch(self, possibly_batched_index):
            if self.auto_collation:
                # 对于batch_sampler(即auto_collation==True)
                # 直接使用往后遍历并提取len(possibly_batched_index)个样本(即1个batch的样本)
                data = []
                for _ in possibly_batched_index:
                    try:
                        data.append(next(self.dataset_iter))
                    except StopIteration:
                        break
                if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                    raise StopIteration
            else:
                # 对于sampler,直接往后遍历并提取1个样本
                data = next(self.dataset_iter)
            return self.collate_fn(data)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    最后,我们通过索引传入 fetcher,fetch 得到想要的样本。
    因此,整个过程调用关系总结 如下:

    loader.__iter__ --> self._get_iterator() --> class _SingleProcessDataLoaderIter --> class _BaseDataLoaderIter --> __next__() --> self._next_data() --> self._next_index() -->next(self._sampler_iter)next(iter(self._index_sampler)) --> 获得 index --> self._dataset_fetcher.fetch(index) --> 获得 data

    参考链接:PyTorch 源码解读之 torch.utils.data:解析数据处理全流程

    问题记录:
    1.pytorch 中的Dataset这个类为什么可以调用__getitem__?
    特殊方法名称
    在这里插入图片描述

  • 相关阅读:
    深入docker-swarm overlay网络模型
    Centos 6.5 升级到Centos7指导手册
    Go切片排序
    netapp3210存储更换故障硬盘过程
    JDK、JRE、JVM 三者关系
    Web-监听器
    安卓文件资源中,一个字串包含引用其他字串的写法
    2022-10-28 开源会议分享开场词
    Linux知识点 -- 高级IO(二)
    计算机毕业设计Java房产销售系统(源码+系统+mysql数据库+lw文档)
  • 原文地址:https://blog.csdn.net/Highlight_Jin/article/details/126206958