Aligning with PyTorch: A Deep Dive into OneFlow's DataLoader Implementation



    Written by | Zhao Luyang

    In the latest OneFlow v0.5.0 release, we added many new features, for example:

    • A new dynamic graph mode: OneFlow now runs in dynamic graph (eager) mode by default, which makes building networks, debugging, and validating algorithms easier than in static graph (graph) mode.

    • An object-oriented dynamic graph interface, nn.Module, so users familiar with PyTorch can get started easily.

    • "Convert between OneFlow and PyTorch networks with one line of code": the number of operators aligned with PyTorch has grown to 200+. This has been verified on a dozen or so common networks such as ResNet50 and AlexNet via import oneflow as torch and import torch as flow. Note: this feature is designed to make migrating from PyTorch to OneFlow easier; it is not a promise of full PyTorch compatibility.

    • An object-oriented static graph interface: the new nn.Graph interface keeps the performance advantages of OneFlow's static graph while bringing the programming threshold of static graphs close to that of dynamic graphs. We hope more algorithm engineers will put OneFlow's high performance to work. There is an example of building ResNet50 with nn.Graph.

    • Easy and efficient distributed training: distributed training is the way forward. The Consistent Tensor introduced in this release lets users operate an entire cluster as if it were a single device on a single machine, and see the effect immediately. The new launch module and DDP module, together with OneFlow's consistent view, make it easy to start distributed training. Data parallelism, model parallelism, and pipeline parallelism are all natively supported by OneFlow, easily and efficiently.

    One of the most important new features is that OneFlow's dynamic graph mode is now almost identical to PyTorch, from Tensor, nn.Module, and autograd to the functional API, including a DataLoader/Dataset design that is almost fully aligned with torch. The author had the privilege of developing this module in OneFlow:

    1. https://github.com/Oneflow-Inc/oneflow/pull/5406
    2. https://github.com/Oneflow-Inc/oneflow/pull/5500
    3. https://github.com/Oneflow-Inc/oneflow/pull/5644
    4. https://github.com/Oneflow-Inc/oneflow/pull/6280

    This article walks through how the DataLoader in OneFlow/PyTorch works and what its workflow looks like:

    • Overview

    • How the DataLoader works

    • DataLoader workflow

    • How the multiprocessing DataLoader works

    1

    Overview

    Simply put, the DataLoader is an indispensable part of deep learning: it is the data loader that drives a Dataset to produce the batch of data and labels for each iteration. As the PyTorch documentation describes it: the DataLoader combines a Sampler and a Dataset and provides an iterable over the given dataset. The DataLoader supports both single-process and multi-process data loading.

    2

    How the DataLoader works


    Core components

    • DataLoader

    • Dataset

    • Sampler

    • Fetcher

    A brief summary of how the DataLoader works:

    1. The DataLoader is the core of data loading; the DataLoaderIter is the concrete executor. On every iteration, the DataLoader delegates the actual data loading work to a DataLoaderIter;

    2. Dataset is the base class for datasets. Any custom dataset must subclass it and override the __getitem__ method to define how a sample is fetched;

    3. The Sampler is the index-related sampler: on every iteration it produces the indices of the samples to draw from the dataset;

    4. The Fetcher acts as the data collector. Given the batch of indices produced by the Sampler, it fetches the corresponding samples from the dataset and, via the collate_fn, packs them into the final usable form returned to the DataLoader. (The sketch after this list shows how these pieces fit together in practice.)
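    To make the division of labor concrete, here is a minimal, self-contained sketch (plain PyTorch API, hypothetical toy data) of what a DataLoader does for you: the Dataset defines per-sample access, the Sampler picks the indices, and the fetch/collate steps turn them into batch Tensors.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # A tiny map-style dataset: 10 samples, 3 features each, with integer labels.
    features = torch.randn(10, 3)
    labels = torch.randint(0, 2, (10,))
    dataset = TensorDataset(features, labels)   # Dataset: defines __getitem__ / __len__

    # DataLoader wires everything together: because shuffle=True it builds a
    # RandomSampler internally, groups the sampled indices into batches of 4,
    # fetches dataset[idx] for each index, and collates them into batch tensors.
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    for x, y in loader:             # each iteration yields one collated batch
        print(x.shape, y.shape)     # torch.Size([4, 3]) torch.Size([4]); the last batch may be smaller

    The rest of this article is essentially a tour of the machinery hidden behind that for loop.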

    Usage examples

    1. MNIST

    The following simple example from the official PyTorch examples, which trains a classification network on the MNIST dataset, illustrates how the DataLoader is used:

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                       transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    As you can see, dataset1 and dataset2 represent the training set and the test set, respectively. In PyTorch they are defined via torchvision.datasets.MNIST. MNIST inherits from VisionDataset, which in turn inherits from torch.utils.data.Dataset. MNIST implements the most important method of a dataset, __getitem__, which fetches the sample corresponding to a given index:

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    In OneFlow, oneflow.utils.data corresponds to torch.utils.data, and flowvision corresponds to torchvision; they are used in almost exactly the same way. For example, the MNIST dataset can be used directly via flowvision.datasets.MNIST.
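    For reference, the MNIST snippet above might look like this in OneFlow. This is a sketch that assumes flowvision mirrors torchvision's datasets/transforms interface as described, with hypothetical batch sizes:

    import oneflow as flow
    from oneflow.utils.data import DataLoader
    from flowvision import datasets, transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # Same dataset definitions as before, just backed by flowvision instead of torchvision.
    dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
    dataset2 = datasets.MNIST('../data', train=False, transform=transform)

    # oneflow.utils.data.DataLoader is used in place of torch.utils.data.DataLoader.
    train_loader = DataLoader(dataset1, batch_size=64, shuffle=True)
    test_loader = DataLoader(dataset2, batch_size=1000, shuffle=False)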

    Once dataset1 and dataset2 are defined, they are passed to the DataLoaders used for training and validation (train_loader and test_loader). After that, in the train/test loop, you simply iterate over the DataLoader to get the data and labels for each iteration:

    def train(args, model, device, train_loader, optimizer, epoch):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            ....

    2. ImageNet

    Again taking the official PyTorch examples, this time training on the ImageNet dataset:

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    As you can see, the overall flow is much the same as the MNIST example above:

    1. First, construct the Dataset, here via datasets.ImageFolder. ImageFolder reads and processes an image dataset stored as folders on disk:

    class ImageFolder(DatasetFolder):
        r"""A generic data loader where the images are arranged in this way by default:

        .. code-block:: shell

            root/dog/xxx.png
            root/dog/xxy.png
            root/dog/[...]/xxz.png

            root/cat/123.png
            root/cat/nsdf3.png
            root/cat/[...]/asd932_.png

        This class inherits from :class:`~vision.datasets.DatasetFolder` so
        the same methods can be overridden to customize the dataset.

        Args:
            root (string): Root directory path.
            transform (callable, optional): A function/transform that takes in an PIL image
                and returns a transformed version. E.g, ``transforms.RandomCrop``
            target_transform (callable, optional): A function/transform that takes in the
                target and transforms it.
            loader (callable, optional): A function to load an image given its path.
            is_valid_file (callable, optional): A function that takes path of an Image file
                and check if the file is a valid file (used to check of corrupt files)

        Attributes:
            classes (list): List of the class names sorted alphabetically.
            class_to_idx (dict): Dict with items (class_name, class_index).
            imgs (list): List of (image path, class_index) tuples
        """

        def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            loader: Callable[[str], Any] = default_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
        ):
            super(ImageFolder, self).__init__(
                root,
                loader,
                IMG_EXTENSIONS if is_valid_file is None else None,
                transform=transform,
                target_transform=target_transform,
                is_valid_file=is_valid_file,
            )
            self.imgs = self.samples

    It inherits from DatasetFolder, and its main constructor arguments are:

    • root: the path of the image folder

    • transform: the transforms applied to the PIL image returned by the loader, such as the Resize and CenterCrop shown above

    • loader: an image loader that loads an image from a given path; the default loader is usually PIL

    DatasetFolder implements Dataset's most important method, __getitem__:

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    __getitem__ defines how the corresponding sample is fetched for a given index.

    2. Second, for multi-machine distributed training, the Sampler must be the DistributedSampler class designed specifically for distributed training (otherwise no special setup is needed and the default is fine). Another detail: the training set and validation set apply different transforms to the dataset. The training set uses RandomResizedCrop and RandomHorizontalFlip, while the validation set uses Resize and CenterCrop; after these transforms, ToTensor finally converts the result into a Tensor. (A sketch of the typical per-epoch DistributedSampler usage follows the training loop snippet below.)

    3. Construct the DataLoaders used for training and validation (train_loader and val_loader). After that, usage is straightforward: just iterate over them in the train/eval loop:

    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            target = target.cuda(args.gpu, non_blocking=True)
        .....
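    To make the DistributedSampler point above concrete: a typical per-epoch pattern (a sketch that reuses the args, train_sampler, and train_loader variables from the example above) calls set_epoch so that every epoch reshuffles, consistently across ranks:

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Each rank sees a different, non-overlapping shard of the dataset;
            # set_epoch changes the shuffling seed so every epoch reshuffles.
            train_sampler.set_epoch(epoch)

        for i, (images, target) in enumerate(train_loader):
            # training step as shown above
            ...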


    3

    DataLoader workflow

    Let's go through the main flow alongside the code:

    Dataset 

    Any custom dataset must inherit from the Dataset class and implement the __getitem__ method, which defines how a sample is fetched for a given index. Optionally, a custom dataset can also override the __len__ method, which reports the size of the dataset.

    class Dataset(Generic[T_co]):
        r"""An abstract class representing a :class:`Dataset`.

        All datasets that represent a map from keys to data samples should subclass
        it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
        data sample for a given key. Subclasses could also optionally overwrite
        :meth:`__len__`, which is expected to return the size of the dataset by many
        :class:`~flow.utils.data.Sampler` implementations and the default options
        of :class:`~flow.utils.data.DataLoader`.

        .. note::
            :class:`~flow.utils.data.DataLoader` by default constructs a index
            sampler that yields integral indices. To make it work with a map-style
            dataset with non-integral indices/keys, a custom sampler must be provided.
        """

        def __getitem__(self, index) -> T_co:
            raise NotImplementedError

        def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
            return ConcatDataset([self, other])
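    As a minimal illustration (a hypothetical SquaresDataset, not from the codebase above), a custom map-style dataset only needs to subclass Dataset and implement __getitem__, optionally __len__:

    import torch
    from torch.utils.data import Dataset

    class SquaresDataset(Dataset):
        """Hypothetical map-style dataset: sample i is the pair (i, i**2)."""

        def __init__(self, n):
            self.n = n

        def __len__(self):                 # optional, but required by most samplers
            return self.n

        def __getitem__(self, index):      # required: how to fetch sample `index`
            x = torch.tensor(float(index))
            y = torch.tensor(float(index) ** 2)
            return x, y

    ds = SquaresDataset(100)
    print(len(ds), ds[3])                  # 100 (tensor(3.), tensor(9.))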

    DataLoader 

    The DataLoader is the core of the whole data pipeline.

    class DataLoader(Generic[T_co]):
        def __init__(
            self,
            dataset: Dataset[T_co],
            batch_size: Optional[int] = 1,
            shuffle: bool = False,
            sampler: Optional[Sampler[int]] = None,
            batch_sampler: Optional[Sampler[Sequence[int]]] = None,
            num_workers: int = 0,
            collate_fn: Optional[_collate_fn_t] = None,
            drop_last: bool = False,
            timeout: float = 0,
            worker_init_fn: Optional[_worker_init_fn_t] = None,
            multiprocessing_context=None,
            generator=None,
            *,
            prefetch_factor: int = 2,
            persistent_workers: bool = False
        ):
            ...
        ...

        # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
        # since '_BaseDataLoaderIter' references 'DataLoader'.
        def __iter__(self) -> "_BaseDataLoaderIter":
            # When using a single worker the returned iterator should be
            # created everytime to avoid reseting its state
            # However, in the case of a multiple workers iterator
            # the iterator is only created once in the lifetime of the
            # DataLoader object so that workers can be reused
            if self.persistent_workers and self.num_workers > 0:
                if self._iterator is None:
                    self._iterator = self._get_iterator()
                else:
                    self._iterator._reset(self)
                return self._iterator
            else:
                return self._get_iterator()

        def _get_iterator(self) -> "_BaseDataLoaderIter":
            if self.num_workers == 0 or self.num_workers == 1:
                return _SingleProcessDataLoaderIter(self)
            else:
                self.check_worker_number_rationality()
                return _MultiProcessingDataLoaderIter(self)

    In each iteration, the most important thing the DataLoader does is fetch the data and labels through the __iter__ method above. __iter__ obtains the appropriate DataLoaderIter instance via the _get_iterator method:

    • in the single-process case, a _SingleProcessDataLoaderIter;

    • in the multi-process case, a _MultiProcessingDataLoaderIter. Both inherit from _BaseDataLoaderIter. (The sketch below shows how the familiar for loop maps onto these iterators.)
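    In other words, the everyday for loop over a DataLoader is just sugar over these iterator calls. The following self-contained sketch (toy TensorDataset; single-process, since num_workers defaults to 0) makes the __iter__/__next__ steps explicit:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    loader = DataLoader(TensorDataset(torch.randn(8, 3), torch.zeros(8)), batch_size=4)

    # `for x, y in loader` is equivalent to:
    it = iter(loader)           # DataLoader.__iter__ -> a single-process DataLoaderIter here
    while True:
        try:
            x, y = next(it)     # __next__ -> _next_data() -> sampler index + fetcher
        except StopIteration:   # raised once the sampler iterator is exhausted
            break
        print(x.shape, y.shape)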

    DataLoaderIter 

    The DataLoaderIter handles the concrete work of the DataLoader within each iteration.

    class _BaseDataLoaderIter(object):
        def __init__(self, loader: DataLoader) -> None:
            self._dataset = loader.dataset
            self._dataset_kind = loader._dataset_kind
            self._IterableDataset_len_called = loader._IterableDataset_len_called
            self._auto_collation = loader._auto_collation
            self._drop_last = loader.drop_last
            self._index_sampler = loader._index_sampler
            self._num_workers = loader.num_workers
            self._prefetch_factor = loader.prefetch_factor
            self._pin_memory = False
            self._timeout = loader.timeout
            self._collate_fn = loader.collate_fn
            self._sampler_iter = iter(self._index_sampler)
            self._base_seed = flow.tensor([0], dtype=flow.int64).uniform_().numpy().item()
            # TODO: flow.empty()
            # self._base_seed = flow.empty((), dtype=flow.int64).random_(generator=loader.generator).item()
            self._persistent_workers = loader.persistent_workers
            self._num_yielded = 0
            self._profile_name = "enumerate(DataLoader)#{}.__next__".format(
                self.__class__.__name__
            )

        def __iter__(self) -> "_BaseDataLoaderIter":
            return self

        def _reset(self, loader, first_iter=False):
            self._sampler_iter = iter(self._index_sampler)
            self._num_yielded = 0
            self._IterableDataset_len_called = loader._IterableDataset_len_called

        def _next_index(self):
            return next(self._sampler_iter)  # may raise StopIteration

        def _next_data(self):
            raise NotImplementedError

        def __next__(self) -> Any:
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data()
            self._num_yielded += 1
            if (
                self._dataset_kind == _DatasetKind.Iterable
                and self._IterableDataset_len_called is not None
                and self._num_yielded > self._IterableDataset_len_called
            ):
                warn_msg = (
                    "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                    "samples have been fetched. "
                ).format(self._dataset, self._IterableDataset_len_called, self._num_yielded)
                if self._num_workers > 1:
                    warn_msg += "Multiprocessing dataloader is not support yet!"
                warnings.warn(warn_msg)
            return data

        def __len__(self) -> int:
            return len(self._index_sampler)

        def __getstate__(self):
            raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)


    class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
        def __init__(self, loader):
            super(_SingleProcessDataLoaderIter, self).__init__(loader)
            assert self._timeout == 0
            assert 0 <= self._num_workers <= 1
            self._dataset_fetcher = _DatasetKind.create_fetcher(
                self._dataset_kind,
                self._dataset,
                self._auto_collation,
                self._collate_fn,
                self._drop_last,
            )

        def _next_data(self):
            index = self._next_index()  # may raise StopIteration
            if self._pin_memory:
                raise NotImplementedError("Dataloader pin memory is not support yet!")
            return self._dataset_fetcher.fetch(index)

    On each iteration, the __next__ method of _BaseDataLoaderIter is called, which in turn calls the _next_data method implemented by the subclass to obtain the data. Taking _SingleProcessDataLoaderIter as an example:

    • index = self._next_index() obtains the dataset indices for this iteration from the Sampler;

    • self._dataset_fetcher.fetch(index) has the Fetcher retrieve the corresponding samples by index.

    Fetcher 

    As the data collector, the Fetcher takes the batch of indices produced by the Sampler, slices the corresponding samples out of the dataset, collects them, and packs them into a complete, usable batch that is returned to the DataLoader.

    class _MapDatasetFetcher(_BaseDatasetFetcher):
        def __init__(self, dataset, auto_collation, collate_fn, drop_last):
            super(_MapDatasetFetcher, self).__init__(
                dataset, auto_collation, collate_fn, drop_last
            )

        def fetch(self, possibly_batched_index):
            if self.auto_collation:
                data = [self.dataset[idx] for idx in possibly_batched_index]
            else:
                data = self.dataset[possibly_batched_index]
            return self.collate_fn(data)

    Like the DataLoaderIter (_BaseDataLoaderIter), the Fetcher also has a base-class implementation, _BaseDatasetFetcher, and dispatches to different subclasses depending on the dataset type. Taking the commonly used _MapDatasetFetcher subclass as an example, let's look at the Fetcher's main job.

    As you can see, it mainly comes down to:

    • data = [self.dataset[idx] for idx in possibly_batched_index]

    • return self.collate_fn(data)

    1. Given the batch of indices passed in, slice the corresponding samples out of the dataset, producing a list of the fetched samples;

    2. Use the provided or custom collate_fn to collect and process this batch of samples and pack them into Tensors that can be used directly for training/validation. (A sketch of a custom collate_fn follows.)
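    As a concrete, hypothetical illustration of step 2: the default collate_fn simply stacks same-shaped samples into batch Tensors, but when samples vary in length you can pass your own collate_fn, such as the padding collate below:

    import torch
    from torch.utils.data import DataLoader, Dataset

    class VarLenDataset(Dataset):
        """Hypothetical dataset whose samples are 1-D sequences of different lengths."""
        def __len__(self):
            return 8

        def __getitem__(self, index):
            return torch.arange(index + 1, dtype=torch.float32)   # length index + 1

    def pad_collate(batch):
        # `batch` is the list produced by the Fetcher: [dataset[i] for i in indices].
        max_len = max(len(seq) for seq in batch)
        padded = torch.zeros(len(batch), max_len)
        for row, seq in enumerate(batch):
            padded[row, : len(seq)] = seq
        return padded

    loader = DataLoader(VarLenDataset(), batch_size=4, collate_fn=pad_collate)
    for batch in loader:
        print(batch.shape)     # torch.Size([4, 4]) then torch.Size([4, 8])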

    4

    How the multiprocessing DataLoader works

    Principle

    With an ordinary single-process DataLoader, the data for each iteration is processed iter-by-iter and synchronously. Constrained by Python's lack of true multithreaded execution (the GIL), a single-process DataLoader is usually slow. A multiprocessing DataLoader uses Python's multiprocessing module to start several worker processes; with 4 workers, for example, it can in principle process 4 iterations' worth of data per unit of time, accelerating data processing and loading.

    With a single-process DataLoader, since processing is iter-by-iter, the next iteration cannot start until the current one finishes. The main difference with the multiprocessing DataLoader is that it can launch multiple worker processes via Python's multiprocessing module to speed this up.

    Take a DataLoader with 4 worker processes as an example:

    The DataLoader's main thread dispatches the current iteration's task to worker1, then the next iteration's task to worker2, and so on, until the 4th iteration's task is dispatched to worker4. This step is implemented at lines L1024-L1026 of dataloader.py:

    # prime the prefetch loop
    for _ in range(self._prefetch_factor * self._num_workers):
        self._try_put_index()

    After the indices have been sent out, the 4 workers run in parallel; as each finishes processing its iteration, it pushes the result into a Queue, and the DataLoader's main thread simply reads the data from that queue, as the user-level sketch below illustrates.
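    From the user's side none of this changes the API; enabling it is just a matter of setting num_workers (and optionally prefetch_factor). A small usage sketch with assumed toy data: with num_workers=4 and the default prefetch_factor=2, the main process pre-dispatches 2 * 4 = 8 index batches before yielding the first batch (the "prime the prefetch loop" step shown above).

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def main():
        dataset = TensorDataset(torch.randn(64, 3), torch.zeros(64))

        # 4 worker processes; 2 * 4 = 8 batches of indices are dispatched up front.
        loader = DataLoader(dataset, batch_size=8, num_workers=4, prefetch_factor=2)
        for x, y in loader:
            print(x.shape)

    if __name__ == "__main__":   # needed on platforms where workers are spawned rather than forked
        main()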

    Each individual worker's workflow is essentially the same as that of the single-process DataLoader. Below we focus on the differences between the multiprocessing and single-process DataLoaders, and on how the workers cooperate.

    Workflow

    _MultiProcessingDataLoaderIter

    def _next_data(self):
        # The DataLoaderIter obtains each iteration's data through this method,
        # which mainly calls _get_data()

    def _get_data(self):
        # _get_data() mainly obtains the data by calling _try_get_data()

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Get data from the main process's _data_queue
        ...
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            ...

    def _process_data(self, data):
        # Main jobs: 1. put the next iteration's index into an active worker
        #               process via _try_put_index()
        #            2. advance _rcvd_idx by 1
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data

    def _try_put_index(self):
        # Main job: iterate over all workers, find the first active one
        # (identified by worker_queue_idx), and put the index together with
        # _send_idx into that worker's index_queue.
        # Each worker owns an independent index_queue and starts working as
        # soon as it receives an entry from it.
        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
        try:
            index = self._next_index()
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return

        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._tasks_outstanding += 1
        self._send_idx += 1

    _next_data()

    ⬇️

    _get_data() ➡️ _try_get_data()

    ⬇️

    _process_data() ➡️ _try_put_index()

    Each worker works independently; the main code lives in the _worker_loop() method in oneflow/python/oneflow/utils/data/_utils/worker.py:

    while watchdog.is_alive():
        try:
            r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            continue
        if isinstance(r, _ResumeIteration):
            # Acknowledge the main process
            data_queue.put((r, None))
            iteration_end = False
            # Recreate the fetcher for worker-reuse policy
            fetcher = _DatasetKind.create_fetcher(
                dataset_kind, dataset, auto_collation, collate_fn, drop_last
            )
            continue
        elif r is None:
            # Received the final signal
            assert done_event.is_set() or iteration_end
            break
        elif done_event.is_set() or iteration_end:
            # `done_event` is set. But I haven't received the final signal
            # (None) yet. I will keep continuing until get it, and skip the
            # processing steps.
            continue
        idx, index = r
        data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
        if init_exception is not None:
            data = init_exception
            init_exception = None
        else:
            try:
                data = fetcher.fetch(index)
            except Exception as e:
                if (
                    isinstance(e, StopIteration)
                    and dataset_kind == _DatasetKind.Iterable
                ):
                    data = _IterableDatasetStopIteration(worker_id)
                    # Set `iteration_end`
                    # (1) to save future `next(...)` calls, and
                    # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                    iteration_end = True
                else:
                    # It is important that we don't store exc_info in a variable.
                    # `ExceptionWrapper` does the correct thing.
                    # See NOTE [ Python Traceback Reference Cycle Problem ]
                    data = ExceptionWrapper(
                        where="in DataLoader worker process {}".format(worker_id)
                    )
        data_queue.put((idx, data))
        del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass

    In its own worker loop, as soon as a worker gets an index entry from its index_queue via r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL), it starts working: idx, index = r followed by data = fetcher.fetch(index). This part is no different from the single-process DataLoader workflow described earlier.

    Once the processed data is ready, the worker puts it into the DataLoader main thread's data_queue with data_queue.put((idx, data)), where it waits for the main thread to pick the result up from the queue.
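    Stripped of the DataLoader machinery, the index_queue / data_queue structure above is a standard multiprocessing producer-consumer pattern. The standalone sketch below (a toy worker, not OneFlow code) has the same shape: the main process sends (send_idx, indices) tasks round-robin, each worker "fetches" and puts (idx, data) back on a shared result queue, and None is the exit signal:

    import multiprocessing as mp

    def worker_loop(index_queue, data_queue):
        # Toy stand-in for _worker_loop: fetch + "collate" for each received task.
        while True:
            r = index_queue.get()
            if r is None:                      # final signal from the main process
                break
            idx, indices = r
            data = [i * i for i in indices]    # pretend this is fetcher.fetch(indices)
            data_queue.put((idx, data))

    def main():
        num_workers = 2
        data_queue = mp.Queue()                                   # shared result queue
        index_queues = [mp.Queue() for _ in range(num_workers)]   # one task queue per worker
        workers = [
            mp.Process(target=worker_loop, args=(q, data_queue)) for q in index_queues
        ]
        for w in workers:
            w.start()

        # Round-robin dispatch of "index batches", like _try_put_index.
        batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
        for send_idx, indices in enumerate(batches):
            index_queues[send_idx % num_workers].put((send_idx, indices))

        # Collect results, like _try_get_data (arrival order may differ from send order).
        for _ in batches:
            idx, data = data_queue.get()
            print(idx, data)

        for q in index_queues:
            q.put(None)                        # tell each worker to exit
        for w in workers:
            w.join()

    if __name__ == "__main__":
        main()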

    That is the main workflow of the multiprocessing DataLoader.

    5

    Conclusion

    This article has walked through DataLoader/Dataset; hopefully it helps readers understand how DataLoader/Dataset works in the dynamic graph modes of OneFlow and PyTorch.

    Aligning DataLoader/Dataset with PyTorch is only the first step; efficiency bottlenecks remain. Even with a multiprocessing DataLoader, image decoding and the various transform ops called into C++ from Python can still hit performance limits in some cases, leaving the GPU under-utilized during training while it waits for the CPU to finish data processing. More efficient solutions (such as DALI) will need to be considered going forward.


  • Original article: https://blog.csdn.net/OneFlow_Official/article/details/121173756