• OpenMMLab之Registry机制


    Registry
    注册器来管理具有相似功能的不同模块,例如检测网络中的backbone、head和neck。
    什么是Registry
    是一个类或函数到字符串的映射构成的集合,一个注册器中的类或函数通常具有相似的接口,我们可以通过字符串从注册器中返回我们需要使用的类或函数。在OpenMMLab中实现注册器的基础是python中的装饰函数。想了解装饰函数作用的可以查看这个博客
    https://www.cnblogs.com/Moon-Face/p/14582298.html
    简单理解为输入参数为函数的函数
    注意:当模块被导入时,注册机制才会被触发
    接下来我们通过代码,以faster_rcnn为例,分析在执行注册时到底发生了什么
    mmdet/models/builder.py

    1. from mmcv.cnn import MODELS as MMCV_MODELS
    2. from mmcv.utils import Registry
    3. MODELS = Registry('models', parent=MMCV_MODELS) #创建一个注册器MODELS
    4. DETECTORS = MODELS

    mmdet/models/detectors/faster_rcnn.py

    1. @DETECTORS.register_module()#通过该修饰函数进行注册,当该模块被导入时,完成注册操作
    2. class FasterRCNN(TwoStageDetector):
    3. """Implementation of `Faster R-CNN `_"""
    4. def __init__(self,
    5. backbone,
    6. rpn_head,
    7. roi_head,
    8. train_cfg,
    9. test_cfg,
    10. neck=None,
    11. pretrained=None,
    12. init_cfg=None):
    13. super(FasterRCNN, self).__init__(
    14. backbone=backbone,
    15. neck=neck,
    16. rpn_head=rpn_head,
    17. roi_head=roi_head,
    18. train_cfg=train_cfg,
    19. test_cfg=test_cfg,
    20. pretrained=pretrained,
    21. init_cfg=init_cfg)

    mmdet/models/__init__.py
    导入该模块执行注册

    1. from .backbones import * # noqa: F401,F403
    2. from .builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
    3. ROI_EXTRACTORS, SHARED_HEADS, build_backbone,
    4. build_detector, build_head, build_loss, build_neck,
    5. build_roi_extractor, build_shared_head)
    6. from .dense_heads import * # noqa: F401,F403
    7. from .detectors import * # noqa: F401,F403 #导入该模块执行注册
    8. from .losses import * # noqa: F401,F403
    9. from .necks import * # noqa: F401,F403
    10. from .plugins import * # noqa: F401,F403
    11. from .roi_heads import * # noqa: F401,F403
    12. from .seg_heads import * # noqa: F401,F403
    13. __all__ = [
    14. 'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
    15. 'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
    16. 'build_shared_head', 'build_head', 'build_loss', 'build_detector'
    17. ]

    mmdet/models/builder.py
    构建网络,实例化

    1. def build_detector(cfg, train_cfg=None, test_cfg=None):
    2. """Build detector."""
    3. if train_cfg is not None or test_cfg is not None:
    4. warnings.warn(
    5. 'train_cfg and test_cfg is deprecated, '
    6. 'please specify them in model', UserWarning)
    7. assert cfg.get('train_cfg') is None or train_cfg is None, \
    8. 'train_cfg specified in both outer field and model field '
    9. assert cfg.get('test_cfg') is None or test_cfg is None, \
    10. 'test_cfg specified in both outer field and model field '
    11. return DETECTORS.build(
    12. cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
    '
    运行

    执行build时完成的操作,执行build_from_cfg函数,
    1.通过registry.get()返回对应的类obj_cls
    2.通过obj_cls(**kwargs)实现类的实例化,**kwargs表示剩余的参数为字典类型(形参:值)
    由此我们发现注册器的好处在于可以统一管理就有相似功能的模块,仅需要通过@XXX.register_model()装饰函数即可,然后通过上述统一的方式实现调用及实例化
    我们若想加入自己的网络模块,完成注册并使用,
    1.在mmdet/models/detectors搭建网络类,并用利用注册器装饰函数修饰@xxx.register_module()
    2.在mmdet/models/__init__.py导入,执行注册操作
    3.在config配置文件里修改model的type为自定义的类名

    执行build时完成的操作,执行build_from_cfg函数,这一部分详见registry源码:
    mmcv/utils/registry.py

    1. class Registry:
    2. """A registry to map strings to classes or functions.
    3. Registered object could be built from registry. Meanwhile, registered
    4. functions could be called from registry.
    5. Example:
    6. >>> MODELS = Registry('models')
    7. >>> @MODELS.register_module()
    8. >>> class ResNet:
    9. >>> pass
    10. >>> resnet = MODELS.build(dict(type='ResNet'))
    11. >>> @MODELS.register_module()
    12. >>> def resnet50():
    13. >>> pass
    14. >>> resnet = MODELS.build(dict(type='resnet50'))
    15. Please refer to
    16. https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
    17. advanced usage.
    18. Args:
    19. name (str): Registry name.
    20. build_func(func, optional): Build function to construct instance from
    21. Registry, func:`build_from_cfg` is used if neither ``parent`` or
    22. ``build_func`` is specified. If ``parent`` is specified and
    23. ``build_func`` is not given, ``build_func`` will be inherited
    24. from ``parent``. Default: None.
    25. parent (Registry, optional): Parent registry. The class registered in
    26. children registry could be built from parent. Default: None.
    27. scope (str, optional): The scope of registry. It is the key to search
    28. for children registry. If not specified, scope will be the name of
    29. the package where class is defined, e.g. mmdet, mmcls, mmseg.
    30. Default: None.
    31. """
    32. def __init__(self, name, build_func=None, parent=None, scope=None):
    33. self._name = name
    34. self._module_dict = dict()
    35. self._children = dict()
    36. self._scope = self.infer_scope() if scope is None else scope
    37. # self.build_func will be set with the following priority:
    38. # 1. build_func
    39. # 2. parent.build_func
    40. # 3. build_from_cfg
    41. if build_func is None:
    42. if parent is not None:
    43. self.build_func = parent.build_func
    44. else:
    45. self.build_func = build_from_cfg
    46. else:
    47. self.build_func = build_func
    48. if parent is not None:
    49. assert isinstance(parent, Registry)
    50. parent._add_children(self)
    51. self.parent = parent
    52. else:
    53. self.parent = None
    54. def get(self, key):
    55. """Get the registry record.
    56. Args:
    57. key (str): The class name in string format.
    58. Returns:
    59. class: The corresponding class.
    60. """
    61. scope, real_key = self.split_scope_key(key)
    62. if scope is None or scope == self._scope:
    63. # get from self
    64. if real_key in self._module_dict:
    65. return self._module_dict[real_key]
    66. else:
    67. # get from self._children
    68. if scope in self._children:
    69. return self._children[scope].get(real_key)
    70. else:
    71. # goto root
    72. parent = self.parent
    73. while parent.parent is not None:
    74. parent = parent.parent
    75. return parent.get(key)
    76. def build(self, *args, **kwargs):
    77. return self.build_func(*args, **kwargs, registry=self)
    78. def build_from_cfg(cfg: Dict,
    79. registry: 'Registry',
    80. default_args: Optional[Dict] = None) -> Any:
    81. """Build a module from config dict when it is a class configuration, or
    82. call a function from config dict when it is a function configuration.
    83. Example:
    84. >>> MODELS = Registry('models')
    85. >>> @MODELS.register_module()
    86. >>> class ResNet:
    87. >>> pass
    88. >>> resnet = build_from_cfg(dict(type='Resnet'), MODELS)
    89. >>> # Returns an instantiated object
    90. >>> @MODELS.register_module()
    91. >>> def resnet50():
    92. >>> pass
    93. >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
    94. >>> # Return a result of the calling function
    95. Args:
    96. cfg (dict): Config dict. It should at least contain the key "type".
    97. registry (:obj:`Registry`): The registry to search the type from.
    98. default_args (dict, optional): Default initialization arguments.
    99. Returns:
    100. object: The constructed object.
    101. """
    102. if not isinstance(cfg, dict):
    103. raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    104. if 'type' not in cfg:
    105. if default_args is None or 'type' not in default_args:
    106. raise KeyError(
    107. '`cfg` or `default_args` must contain the key "type", '
    108. f'but got {cfg}\n{default_args}')
    109. if not isinstance(registry, Registry):
    110. raise TypeError('registry must be an mmcv.Registry object, '
    111. f'but got {type(registry)}')
    112. if not (isinstance(default_args, dict) or default_args is None):
    113. raise TypeError('default_args must be a dict or None, '
    114. f'but got {type(default_args)}')
    115. args = cfg.copy()
    116. if default_args is not None:
    117. for name, value in default_args.items():
    118. args.setdefault(name, value)
    119. obj_type = args.pop('type')
    120. if isinstance(obj_type, str):
    121. obj_cls = registry.get(obj_type)
    122. if obj_cls is None:
    123. raise KeyError(
    124. f'{obj_type} is not in the {registry.name} registry')
    125. elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
    126. obj_cls = obj_type
    127. else:
    128. raise TypeError(
    129. f'type must be a str or valid type, but got {type(obj_type)}')
    130. try:
    131. return obj_cls(**args)
    132. except Exception as e:
    133. # Normal TypeError does not print class name.
    134. raise type(e)(f'{obj_cls.__name__}: {e}')

  • 相关阅读:
    复习C语言
    类基本用法
    为什么在变频器场合需要安科瑞的电力有源滤波器?
    shell脚本下用plot给iostat和CPU的采样画图的写法
    “秘密入职”字节跳动,百度高级经理一审被判赔107万
    并查集 力扣990等式方程的可满足性
    数据结构之B数
    【5G NAS】5G SUPI 和 SUCI 标识符详解
    Android移动应用开发之使用room实现数据库的增删改查
    互联网+教育时代,线下教培机构的新机遇
  • 原文地址:https://blog.csdn.net/qq_41368074/article/details/127113284