• (九)mmdetection源码解读:训练过程中训练数据的调用DataLoader


    一、DataLoader创建过程中

    在训练过程train_detector函数中调用build_dataloader函数

    train_detector(model, datasets, cfg, distributed=False, validate=True)
    #train_detector函数中
    data_loaders = [
            build_dataloader(
                ds,
                cfg.data.samples_per_gpu,
                cfg.data.workers_per_gpu,
                # `num_gpus` will be ignored if distributed
                num_gpus=len(cfg.gpu_ids),
                dist=distributed,
                seed=cfg.seed,
                runner_type=runner_type,
                persistent_workers=cfg.data.get('persistent_workers', False))
            for ds in dataset
        ]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    build_dataloader函数中最关键的步骤就是DataLoader类的实例化
    DataLoader:包括 dataset and sampler, an iterable

    def build_dataloader(dataset,
                         samples_per_gpu,
                         workers_per_gpu,
                         num_gpus=1,
                         dist=True,
                         shuffle=True,
                         seed=None,
                         runner_type='EpochBasedRunner',
                         persistent_workers=False,
                         **kwargs):
       
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=num_workers,
            batch_sampler=batch_sampler,
            collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
            pin_memory=False,
            worker_init_fn=init_fn,
            **kwargs)
    
        return data_loader
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23

    可以看到初始化参数里有两种sampler:sampler和batch_sampler,都默认为None。前者的作用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,得到一个又一个batch的index。

    class DataLoader(Generic[T_co]):
      
        def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                     shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                     batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                     num_workers: int = 0, collate_fn: _collate_fn_t = None,
                     pin_memory: bool = False, drop_last: bool = False,
                     timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
                     multiprocessing_context=None, generator=None,
                     *, prefetch_factor: int = 2,
                     persistent_workers: bool = False):
            torch._C._log_api_usage_once("python.data_loader")  # type: ignore
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    二、利用实例化data_loaders进行训练

    runner.run(data_loaders, cfg.workflow)
    
    • 1
  • 相关阅读:
    HoloLens2开发环境搭建及部署app
    解决:vue-cli-service不是内部或外部命令
    RocketMQ 重试机制详解及最佳实践
    电脑重装系统后鼠标动不了该怎么解决
    第一章.线性空间和线性变换
    Fabric.js 自定义选框样式
    云原生Spark UI Service在腾讯云云原生数据湖产品DLC的实践
    【图论】Dijkstra 算法求最短路 - 构建邻接矩阵(带权无向图)
    springboot web 03 多环境配置
    【Leetcode刷题Python】密码校验
  • 原文地址:https://blog.csdn.net/m0_37737957/article/details/132713820