Registry
OpenMMLab 使用注册器(Registry)来管理具有相似功能的不同模块,例如检测网络中的 backbone、head 和 neck。
什么是Registry
注册器(Registry)是由字符串到类或函数的映射构成的集合。同一个注册器中的类或函数通常具有相似的接口,我们可以通过字符串从注册器中取出需要使用的类或函数。在 OpenMMLab 中,注册器是基于 Python 的装饰器(decorator)实现的。想了解装饰器作用的读者可以查看这篇博客:
https://www.cnblogs.com/Moon-Face/p/14582298.html
装饰器可以简单理解为一个以函数(或类)作为输入参数的函数。
注意:当模块被导入时,注册机制才会被触发
接下来我们通过代码,以faster_rcnn为例,分析在执行注册时到底发生了什么
mmdet/models/builder.py
- from mmcv.cnn import MODELS as MMCV_MODELS
- from mmcv.utils import Registry
-
- MODELS = Registry('models', parent=MMCV_MODELS) #创建一个注册器MODELS
- DETECTORS = MODELS
mmdet/models/detectors/faster_rcnn.py
- @DETECTORS.register_module()#通过该修饰函数进行注册,当该模块被导入时,完成注册操作
- class FasterRCNN(TwoStageDetector):
- """Implementation of `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_"""
-
- def __init__(self,
- backbone,
- rpn_head,
- roi_head,
- train_cfg,
- test_cfg,
- neck=None,
- pretrained=None,
- init_cfg=None):
- super(FasterRCNN, self).__init__(
- backbone=backbone,
- neck=neck,
- rpn_head=rpn_head,
- roi_head=roi_head,
- train_cfg=train_cfg,
- test_cfg=test_cfg,
- pretrained=pretrained,
- init_cfg=init_cfg)
mmdet/models/__init__.py
导入该模块执行注册
- from .backbones import * # noqa: F401,F403
- from .builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
- ROI_EXTRACTORS, SHARED_HEADS, build_backbone,
- build_detector, build_head, build_loss, build_neck,
- build_roi_extractor, build_shared_head)
- from .dense_heads import * # noqa: F401,F403
- from .detectors import * # noqa: F401,F403 #导入该模块执行注册
- from .losses import * # noqa: F401,F403
- from .necks import * # noqa: F401,F403
- from .plugins import * # noqa: F401,F403
- from .roi_heads import * # noqa: F401,F403
- from .seg_heads import * # noqa: F401,F403
-
- __all__ = [
- 'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
- 'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
- 'build_shared_head', 'build_head', 'build_loss', 'build_detector'
- ]
mmdet/models/builder.py
构建网络,实例化
- def build_detector(cfg, train_cfg=None, test_cfg=None):
- """Build detector."""
- if train_cfg is not None or test_cfg is not None:
- warnings.warn(
- 'train_cfg and test_cfg is deprecated, '
- 'please specify them in model', UserWarning)
- assert cfg.get('train_cfg') is None or train_cfg is None, \
- 'train_cfg specified in both outer field and model field '
- assert cfg.get('test_cfg') is None or test_cfg is None, \
- 'test_cfg specified in both outer field and model field '
- return DETECTORS.build(
- cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
执行 build 时,注册器会通过其 build_func 调用 build_from_cfg 函数,主要完成以下两步操作:
1.通过registry.get()返回对应的类obj_cls。
2. 通过 obj_cls(**args) 实现类的实例化。其中 args 是 cfg 去掉 type 之后剩余的配置项(缺省的键由 default_args 补齐),它是一个字典(形参名: 值),** 表示将其解包为关键字参数传入。
由此我们发现,注册器的好处在于可以统一管理具有相似功能的模块:只需使用 @XXX.register_module() 装饰器完成注册,之后即可通过上述统一的方式实现调用及实例化。
我们若想加入自己的网络模块,完成注册并使用,
1. 在 mmdet/models/detectors 下搭建网络类,并用注册器的装饰器 @xxx.register_module() 修饰该类
2. 在 mmdet/models/detectors/__init__.py 中导入该类,使其随 mmdet/models/__init__.py 中的 from .detectors import * 一并被导入,从而执行注册操作
3.在config配置文件里修改model的type为自定义的类名
执行 build 时会调用 build_from_cfg 函数,这一部分的具体实现详见 Registry 源码:
mmcv/utils/registry.py
- class Registry:
- """A registry to map strings to classes or functions.
- Registered object could be built from registry. Meanwhile, registered
- functions could be called from registry.
- Example:
- >>> MODELS = Registry('models')
- >>> @MODELS.register_module()
- >>> class ResNet:
- >>> pass
- >>> resnet = MODELS.build(dict(type='ResNet'))
- >>> @MODELS.register_module()
- >>> def resnet50():
- >>> pass
- >>> resnet = MODELS.build(dict(type='resnet50'))
- Please refer to
- https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
- advanced usage.
- Args:
- name (str): Registry name.
- build_func(func, optional): Build function to construct instance from
- Registry, func:`build_from_cfg` is used if neither ``parent`` or
- ``build_func`` is specified. If ``parent`` is specified and
- ``build_func`` is not given, ``build_func`` will be inherited
- from ``parent``. Default: None.
- parent (Registry, optional): Parent registry. The class registered in
- children registry could be built from parent. Default: None.
- scope (str, optional): The scope of registry. It is the key to search
- for children registry. If not specified, scope will be the name of
- the package where class is defined, e.g. mmdet, mmcls, mmseg.
- Default: None.
- """
-
- def __init__(self, name, build_func=None, parent=None, scope=None):
- self._name = name
- self._module_dict = dict()
- self._children = dict()
- self._scope = self.infer_scope() if scope is None else scope
-
- # self.build_func will be set with the following priority:
- # 1. build_func
- # 2. parent.build_func
- # 3. build_from_cfg
- if build_func is None:
- if parent is not None:
- self.build_func = parent.build_func
- else:
- self.build_func = build_from_cfg
- else:
- self.build_func = build_func
- if parent is not None:
- assert isinstance(parent, Registry)
- parent._add_children(self)
- self.parent = parent
- else:
- self.parent = None
- def get(self, key):
- """Get the registry record.
- Args:
- key (str): The class name in string format.
- Returns:
- class: The corresponding class.
- """
- scope, real_key = self.split_scope_key(key)
- if scope is None or scope == self._scope:
- # get from self
- if real_key in self._module_dict:
- return self._module_dict[real_key]
- else:
- # get from self._children
- if scope in self._children:
- return self._children[scope].get(real_key)
- else:
- # goto root
- parent = self.parent
- while parent.parent is not None:
- parent = parent.parent
- return parent.get(key)
-
- def build(self, *args, **kwargs):
- return self.build_func(*args, **kwargs, registry=self)
-
-
- def build_from_cfg(cfg: Dict,
- registry: 'Registry',
- default_args: Optional[Dict] = None) -> Any:
- """Build a module from config dict when it is a class configuration, or
- call a function from config dict when it is a function configuration.
- Example:
- >>> MODELS = Registry('models')
- >>> @MODELS.register_module()
- >>> class ResNet:
- >>> pass
- >>> resnet = build_from_cfg(dict(type='Resnet'), MODELS)
- >>> # Returns an instantiated object
- >>> @MODELS.register_module()
- >>> def resnet50():
- >>> pass
- >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
- >>> # Return a result of the calling function
- Args:
- cfg (dict): Config dict. It should at least contain the key "type".
- registry (:obj:`Registry`): The registry to search the type from.
- default_args (dict, optional): Default initialization arguments.
- Returns:
- object: The constructed object.
- """
- if not isinstance(cfg, dict):
- raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
- if 'type' not in cfg:
- if default_args is None or 'type' not in default_args:
- raise KeyError(
- '`cfg` or `default_args` must contain the key "type", '
- f'but got {cfg}\n{default_args}')
- if not isinstance(registry, Registry):
- raise TypeError('registry must be an mmcv.Registry object, '
- f'but got {type(registry)}')
- if not (isinstance(default_args, dict) or default_args is None):
- raise TypeError('default_args must be a dict or None, '
- f'but got {type(default_args)}')
-
- args = cfg.copy()
-
- if default_args is not None:
- for name, value in default_args.items():
- args.setdefault(name, value)
-
- obj_type = args.pop('type')
- if isinstance(obj_type, str):
- obj_cls = registry.get(obj_type)
- if obj_cls is None:
- raise KeyError(
- f'{obj_type} is not in the {registry.name} registry')
- elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
- obj_cls = obj_type
- else:
- raise TypeError(
- f'type must be a str or valid type, but got {type(obj_type)}')
- try:
- return obj_cls(**args)
- except Exception as e:
- # Normal TypeError does not print class name.
- raise type(e)(f'{obj_cls.__name__}: {e}')