    MMCV Study Notes: Basics Part 2 (Runner)

    The Runner class is a core component of MMCV. It is an engine that manages the training loop, and it lets users customize the training process with only a small amount of code through the interfaces it provides. Below I walk through it following the official docs, combined with my own understanding.

    1. Main Features

     The goal of Runner is to give users unified training-process management while supporting flexible, configurable customization (implemented via Hooks). Its main features are:

    • Provides EpochBasedRunner and IterBasedRunner out of the box, which iterate on an epoch basis and an iteration basis respectively, and also lets users implement a custom Runner.
    • Supports custom workflows so that different states can be switched freely during training; currently the two supported workflows are training (train) and validation (val).
    • Works together with various hook functions (Hook) to provide flexible extensibility: by injecting different types of Hooks, extra functionality can be added to the training process in an elegant way (a minimal sketch follows below).
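
     To make the Hook idea concrete, here is a minimal sketch of my own (the class name is made up for illustration): a custom hook subclasses mmcv.runner.Hook and overrides only the callback points it cares about, and the Runner invokes them at the corresponding stages of training. Section 3.4 shows how such a hook gets attached to a runner.

    from mmcv.runner import Hook


    class ToyEpochHook(Hook):
        """A hypothetical hook that prints a message around each train epoch."""

        def before_train_epoch(self, runner):
            print(f'starting epoch {runner.epoch + 1}')

        def after_train_epoch(self, runner):
            print(f'finished epoch {runner.epoch + 1}')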

    2. EpochBasedRunner & IterBasedRunner

     As the name suggests, EpochBasedRunner is a Runner that iterates on an epoch basis. Below we implement a simple example to demonstrate how its workflow control works.

    import torch.nn as nn


    class ToyRunner(nn.Module):
    
        def __init__(self):
            super().__init__()
        
        def train(self, data_loader, **kwargs):
            print(data_loader)
            
    
        def val(self, data_loader, **kwargs):
            print(data_loader)
        
        def run(self):
            # training epochs
            max_epochs = 3
            curr_epoch = 0
            # denotes 2 epochs for training and 1 epoch for validation
            workflow = [("train", 2), ("val", 1)]
            data_loaders = ["dl_train", "dl_val"]
            # the condition to stop training
            while curr_epoch < max_epochs: 
                # workflow
                for i, flow in enumerate(workflow):
                    mode, epochs = flow
                    epoch_func = getattr(self, mode)
                    for _ in range(epochs):
                        if mode == 'train' and curr_epoch >= max_epochs:
                            break
                        epoch_func(f'data_loader: {data_loaders[i]}, epoch={curr_epoch}')
                        if mode == 'train':
                            # validation doesn't affect curr_epoch
                            curr_epoch += 1
    

     Now let's run the ToyRunner:

    runner = ToyRunner()
    runner.run()
    """
    Output:
    data_loader: dl_train, epoch=0
    data_loader: dl_train, epoch=1
    data_loader: dl_val, epoch=2
    data_loader: dl_train, epoch=2
    data_loader: dl_val, epoch=3
    """
    

     The logic above is quite simple; a few points are worth noting:

    • The workflow here means training for 2 epochs and then validating for 1 epoch. Its length is 2, which must match the length of data_loaders: the workflow contains a train flow and a val flow, so the two data loaders dl_train and dl_val are needed to feed them.
    • max_epochs counts training epochs only, so curr_epoch is incremented only when the mode is train.
    • The loop breaks only when the mode is train and curr_epoch >= max_epochs, which guarantees that the final training epoch is always followed by validation.
    • IterBasedRunner works on a similar principle (a rough sketch follows below), and both runners inherit from a common base class, BaseRunner, so I won't go over it again; interested readers can check the source code on GitHub.
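
     For intuition only, here is a rough toy sketch of the iter-based counterpart (my own illustration, not the actual IterBasedRunner source): each step of a flow consumes one batch instead of one full pass over the dataset, and only training iterations advance the counter.

    # toy sketch: the workflow unit is a single iteration, not an epoch
    max_iters = 6
    curr_iter = 0
    workflow = [('train', 2), ('val', 1)]  # 2 train iters, then 1 val iter
    while curr_iter < max_iters:
        for mode, iters in workflow:
            for _ in range(iters):
                if mode == 'train' and curr_iter >= max_iters:
                    break
                print(f'{mode} iteration, curr_iter={curr_iter}')
                if mode == 'train':
                    # validation doesn't advance the iteration counter
                    curr_iter += 1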

    3. A Simple Example

    3.1 Tool Functions

     Next, through a simple example, we will use the Runner class that MMCV provides, following MMCV's conventions. First, we define some utility functions for building the dataloaders:

    import platform
    import random
    from functools import partial
    
    import numpy as np
    import torch
    from mmcv.parallel import collate
    from mmcv.runner import get_dist_info
    from mmcv.utils import digit_version
    from torch.utils.data import DataLoader, IterableDataset
    
    
    if platform.system() != 'Windows':
        # https://github.com/pytorch/pytorch/issues/973
        import resource
        rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
        base_soft_limit = rlimit[0]
        hard_limit = rlimit[1]
        soft_limit = min(max(4096, base_soft_limit), hard_limit)
        resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
    
    
    def build_dataloader(dataset,
                         samples_per_gpu,
                         workers_per_gpu,
                         num_gpus=1,
                         dist=False,
                         shuffle=True,
                         seed=None,
                         drop_last=False,
                         pin_memory=True,
                         persistent_workers=True,
                         **kwargs):
        """Build PyTorch DataLoader.
        In distributed training, each GPU/process has a dataloader.
        In non-distributed training, there is only one dataloader for all GPUs.
        Args:
            dataset (Dataset): A PyTorch dataset.
            samples_per_gpu (int): Number of training samples on each GPU, i.e.,
                batch size of each GPU.
            workers_per_gpu (int): How many subprocesses to use for data loading
                for each GPU.
            num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: False.
            shuffle (bool): Whether to shuffle the data at every epoch.
                Default: True.
            seed (int | None): Seed to be used. Default: None.
            drop_last (bool): Whether to drop the last incomplete batch in epoch.
                Default: False
            pin_memory (bool): Whether to use pin_memory in DataLoader.
                Default: True
        persistent_workers (bool): If True, the data loader will not shut down
            the worker processes after a dataset has been consumed once.
            This keeps the worker Dataset instances alive. The argument is
            only valid when PyTorch>=1.7.0. Default: True.
            kwargs: any keyword argument to be used to initialize DataLoader
        Returns:
            DataLoader: A PyTorch dataloader.
        """
        rank, world_size = get_dist_info()
        if dist and not isinstance(dataset, IterableDataset):
            # distributed training is not supported in this trimmed example;
            # fail loudly instead of leaving batch_size etc. undefined below
            raise NotImplementedError(
                'dist=True with a map-style dataset is not supported here')
        elif dist:
            sampler = None
            shuffle = False
            batch_size = samples_per_gpu
            num_workers = workers_per_gpu
        else:
            sampler = None
            batch_size = num_gpus * samples_per_gpu
            num_workers = num_gpus * workers_per_gpu
    
        init_fn = partial(
            worker_init_fn, num_workers=num_workers, rank=rank,
            seed=seed) if seed is not None else None
    
        if digit_version(torch.__version__) >= digit_version('1.8.0'):
            data_loader = DataLoader(
                dataset,
                batch_size=batch_size,
                sampler=sampler,
                num_workers=num_workers,
                collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
                pin_memory=pin_memory,
                shuffle=shuffle,
                worker_init_fn=init_fn,
                drop_last=drop_last,
                persistent_workers=persistent_workers,
                **kwargs)
        else:
            data_loader = DataLoader(
                dataset,
                batch_size=batch_size,
                sampler=sampler,
                num_workers=num_workers,
                collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
                pin_memory=pin_memory,
                shuffle=shuffle,
                worker_init_fn=init_fn,
                drop_last=drop_last,
                **kwargs)
    
        return data_loader
    
    
    def worker_init_fn(worker_id, num_workers, rank, seed):
        """Worker init func for dataloader.
    The seed of each worker equals num_workers * rank + worker_id + user_seed.
        Args:
            worker_id (int): Worker id.
            num_workers (int): Number of workers.
            rank (int): The rank of current process.
            seed (int): The random seed to use.
        """
    
        worker_seed = num_workers * rank + worker_id + seed
        np.random.seed(worker_seed)
        random.seed(worker_seed)
        torch.manual_seed(worker_seed)
    

     The utility functions above come from MMSegmentation; I trimmed them a bit here to make the demo easier to run.

    3.2 Build Model

     Next we write a simple model that meets the Runner's requirements:

    import torch.nn as nn
    from mmcv.runner import BaseModule
    
    
    class ToyModel(BaseModule):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(10, 2)
            self.criterion = nn.CrossEntropyLoss()
    
        def forward(self, x):
            out = self.backbone(x)
            return out
    
        def train_step(self, data_batch, optimizer, **kwargs):
            labels, imgs = data_batch
            preds = self(imgs)
            loss = self.criterion(preds, labels)
            log_vars = dict(train_loss=loss.item())
            num_samples = len(imgs)
            outputs = dict(loss=loss,
                           preds=preds,
                           log_vars=log_vars,
                           num_samples=num_samples)
            return outputs
    
        def val_step(self, data_batch, optimizer, **kwargs):
            labels, imgs = data_batch
            preds = self(imgs)
            loss = self.criterion(preds, labels)
            log_vars = dict(val_loss=loss.item())
            num_samples = len(imgs)
            outputs = dict(log_vars=log_vars,
                           num_samples=num_samples)
            return outputs
    

     A few points to note about the model code above:

    • When implementing our own Module under the mmcv framework, we need to inherit from BaseModule in mmcv.runner. It has three main attributes/methods: 1) init_cfg, the config that controls weight initialization (a usage sketch follows after this list); 2) init_weights, the weight-initialization function, which records how each parameter was initialized; 3) _params_init_info, a defaultdict that tracks parameter-initialization information; it exists only while init_weights is running and is deleted once all parameters have been initialized.
    • As required by the Runner, our model must implement the two methods train_step and val_step.
    • Both train_step and val_step return a dict. The loss returned by train_step is used later by the optimizer hook for back-propagation, while log_vars and num_samples are the variables the log hook needs for its output.
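
     As a quick, hedged illustration of the init_cfg mechanism described above (the class name here is made up for the demo; the config uses MMCV's standard 'Normal' initializer):

    import torch.nn as nn
    from mmcv.runner import BaseModule


    class ToyInitModel(BaseModule):
        def __init__(self):
            # initialize all Linear layers from a normal distribution N(0, 0.01^2)
            super().__init__(
                init_cfg=dict(type='Normal', layer='Linear', std=0.01))
            self.fc = nn.Linear(10, 2)


    m = ToyInitModel()
    m.init_weights()  # applies init_cfg and records per-parameter init info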

    3.3 Build Dataset

     Next is a simple dataset class with only a single class label:

    class ToyDataset(torch.utils.data.Dataset):
    
        def __init__(self, data) -> None:
            super().__init__()
            self.data = data
    
        def __getitem__(self, idx):
            # items are (label, data) pairs; every sample gets class label 0
            return 0, self.data[idx]
        
        def __len__(self):
            return len(self.data)
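
     A quick sanity check of my own: each item is a (label, data) pair, matching the labels, imgs = data_batch unpacking in ToyModel.

    import torch

    ds = ToyDataset(torch.rand(5, 10))
    label, sample = ds[0]
    print(label, sample.shape)  # 0 torch.Size([10])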
    

    3.4 Running Script

     Let's write a script to run MMCV's Runner and see the result. Below is the code that defines the initialization parameters:

    from mmcv import ConfigDict
    from mmcv.runner import build_optimizer
    
    # initialize
    model = ToyModel()  # model
    cfg = ConfigDict(data=dict(samples_per_gpu=2, workers_per_gpu=1),
                     workflow=[('train', 1), ('val', 1)],
                     optimizer=dict(type='SGD',
                                    lr=0.1,
                                    momentum=0.9,
                                    weight_decay=0.0001),
                     lr_config=dict(policy='step', step=[100, 150]),
                     log_config=dict(interval=1,
                                     hooks=[
                                         dict(type='TextLoggerHook',
                                              by_epoch=True),
                                     ]),
                     runner=dict(type='EpochBasedRunner', max_epochs=3))  # config
    optimizer = build_optimizer(model, cfg.optimizer)  # optimizer
    ds_train = ToyDataset(torch.rand(5, 10))  # training set
    ds_val = ToyDataset(torch.rand(3, 10))  # validation set
    datasets = [ds_train, ds_val]  # datasets
    # data_loaders
    data_loaders = [
        build_dataloader(ds, cfg.data.samples_per_gpu, cfg.data.workers_per_gpu)
        for ds in datasets
    ]
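
     A note on build_optimizer: it looks up the optimizer type in MMCV's OPTIMIZERS registry, which wraps the optimizers in torch.optim. For the config above it should be equivalent to the plain PyTorch call below:

    # equivalent plain-PyTorch construction of the optimizer configured above
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.1,
                                momentum=0.9,
                                weight_decay=0.0001)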
    

     Then we build the runner, register the configs, and can train the model with a single call:

    from mmcv.runner import build_runner
    from mmcv.utils import get_logger
    
    
    # get logger
    logger = get_logger(name='toyproj', log_file='logger.log')
    # initialize runner
    runner = build_runner(cfg.runner,
                          default_args=dict(model=model,
                                            batch_processor=None,
                                            optimizer=optimizer,
                                            work_dir='./',
                                            logger=logger))
    # register default hooks necessary for training
    runner.register_training_hooks(
        # configs of learning rate, it is typically set as:
        lr_config=cfg.lr_config,
        # configuration of logs
        log_config=cfg.log_config)
    # start running
    runner.run(data_loaders, cfg.workflow)
    
    '''
    Output:
    2022-11-23 11:55:50,539 - toyproj - INFO - workflow: [('train', 1), ('val', 1)], max: 3 epochs
    2022-11-23 11:55:55,541 - toyproj - INFO - Epoch [1][1/3]	lr: 1.000e-01, eta: 0:00:39, time: 4.998, data_time: 4.995, memory: 0, train_loss: 0.6041
    2022-11-23 11:55:55,548 - toyproj - INFO - Epoch [1][2/3]	lr: 1.000e-01, eta: 0:00:17, time: 0.009, data_time: 0.009, memory: 0, train_loss: 0.5542
    2022-11-23 11:55:55,552 - toyproj - INFO - Epoch [1][3/3]	lr: 1.000e-01, eta: 0:00:10, time: 0.003, data_time: 0.003, memory: 0, train_loss: 0.5444
    2022-11-23 11:55:57,674 - toyproj - INFO - Epoch(val) [1][2]	val_loss: 0.3617
    2022-11-23 11:55:59,746 - toyproj - INFO - Epoch [2][1/3]	lr: 1.000e-01, eta: 0:00:08, time: 2.067, data_time: 2.066, memory: 0, train_loss: 0.5542
    2022-11-23 11:55:59,752 - toyproj - INFO - Epoch [2][2/3]	lr: 1.000e-01, eta: 0:00:05, time: 0.007, data_time: 0.007, memory: 0, train_loss: 0.6041
    2022-11-23 11:55:59,756 - toyproj - INFO - Epoch [2][3/3]	lr: 1.000e-01, eta: 0:00:03, time: 0.004, data_time: 0.004, memory: 0, train_loss: 0.5444
    2022-11-23 11:56:01,918 - toyproj - INFO - Epoch(val) [2][2]	val_loss: 0.3617
    2022-11-23 11:56:03,991 - toyproj - INFO - Epoch [3][1/3]	lr: 1.000e-01, eta: 0:00:02, time: 2.068, data_time: 2.067, memory: 0, train_loss: 0.6041
    2022-11-23 11:56:03,998 - toyproj - INFO - Epoch [3][2/3]	lr: 1.000e-01, eta: 0:00:01, time: 0.008, data_time: 0.007, memory: 0, train_loss: 0.5897
    2022-11-23 11:56:04,002 - toyproj - INFO - Epoch [3][3/3]	lr: 1.000e-01, eta: 0:00:00, time: 0.004, data_time: 0.004, memory: 0, train_loss: 0.4735
    2022-11-23 11:56:06,181 - toyproj - INFO - Epoch(val) [3][2]	val_loss: 0.3617
    
    '''
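
     One thing worth pointing out: the script above registers only the lr and log hooks, so nothing ever calls loss.backward() and the weights are never actually updated (notice how the same loss values reappear across epochs in the log). A hedged sketch of a fuller registration (the custom hook class is made up; optimizer_config and checkpoint_config build MMCV's standard OptimizerHook and CheckpointHook):

    from mmcv.runner import Hook


    class ToyEpochHook(Hook):  # hypothetical custom hook, as sketched in section 1
        def before_train_epoch(self, runner):
            print(f'starting epoch {runner.epoch + 1}')


    runner.register_training_hooks(
        lr_config=cfg.lr_config,
        # OptimizerHook runs loss.backward() and optimizer.step() each iteration
        optimizer_config=dict(grad_clip=None),
        # CheckpointHook saves a checkpoint to work_dir every `interval` epochs
        checkpoint_config=dict(interval=1),
        log_config=cfg.log_config)
    # custom hooks can be attached directly as well
    runner.register_hook(ToyEpochHook())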
    
  • Original article: https://blog.csdn.net/qq_42718887/article/details/127983839