A note up front: I am still learning MMDetection3D, so there may be misunderstandings in what follows; corrections are welcome.
In MMDetection3D, a custom model must be registered as a class before it can be used.
First, let's look at the official SECOND code:
```python
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule
from torch import nn as nn

from ..builder import BACKBONES


@BACKBONES.register_module()
class SECOND(BaseModule):
    """Backbone network for SECOND/PointPillars/PartA2/MVXNet.

    Args:
        in_channels (int): Input channels.
        out_channels (list[int]): Output channels for multi-scale feature maps.
        layer_nums (list[int]): Number of layers in each stage.
        layer_strides (list[int]): Strides of each stage.
        norm_cfg (dict): Config dict of normalization layers.
        conv_cfg (dict): Config dict of convolutional layers.
    """

    def __init__(self,
                 in_channels=128,
                 out_channels=[128, 128, 256],
                 layer_nums=[3, 5, 5],
                 layer_strides=[2, 2, 2],
                 norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
                 conv_cfg=dict(type='Conv2d', bias=False),
                 init_cfg=None,
                 pretrained=None):
        super(SECOND, self).__init__(init_cfg=init_cfg)
        assert len(layer_strides) == len(layer_nums)
        assert len(out_channels) == len(layer_nums)

        in_filters = [in_channels, *out_channels[:-1]]
        # note that when stride > 1, conv2d with same padding isn't
        # equal to pad-conv2d. we should use pad-conv2d.
        blocks = []
        for i, layer_num in enumerate(layer_nums):
            block = [
                build_conv_layer(
                    conv_cfg,
                    in_filters[i],
                    out_channels[i],
                    3,
                    stride=layer_strides[i],
                    padding=1),
                build_norm_layer(norm_cfg, out_channels[i])[1],
                nn.ReLU(inplace=True),
            ]
            for j in range(layer_num):
                block.append(
                    build_conv_layer(
                        conv_cfg,
                        out_channels[i],
                        out_channels[i],
                        3,
                        padding=1))
                block.append(build_norm_layer(norm_cfg, out_channels[i])[1])
                block.append(nn.ReLU(inplace=True))

            block = nn.Sequential(*block)
            blocks.append(block)

        self.blocks = nn.ModuleList(blocks)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        else:
            self.init_cfg = dict(type='Kaiming', layer='Conv2d')

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input with shape (N, C, H, W).

        Returns:
            tuple[torch.Tensor]: Multi-scale features.
        """
        outs = []
        for i in range(len(self.blocks)):
            x = self.blocks[i](x)
            outs.append(x)
        return tuple(outs)
```
As you can see, the implementation is essentially the same as a regular PyTorch module, with an __init__ method and a forward method. The @BACKBONES.register_module() decorator at the top registers the model as a backbone (i.e., adds it to the BACKBONES registry) without the class ever being instantiated.
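Registering your own backbone follows exactly the same pattern. Below is a minimal sketch, assuming an mmdet3d 1.x (mmcv-based) environment where BACKBONES is importable from mmdet3d.models.builder; the class name MyBackbone and its parameters are made up for illustration:

```python
from mmcv.runner import BaseModule
from torch import nn

from mmdet3d.models.builder import BACKBONES


@BACKBONES.register_module()  # adds MyBackbone to the BACKBONES registry at import time
class MyBackbone(BaseModule):
    """A hypothetical backbone; only the registration pattern matters here."""

    def __init__(self, in_channels=128, out_channels=256, init_cfg=None):
        super(MyBackbone, self).__init__(init_cfg=init_cfg)
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)

    def forward(self, x):
        # Return a tuple so downstream necks/heads can index multi-scale features.
        return (self.conv(x),)
```

Note that the decorator only runs when the module is imported, so the file must be reachable from the package (typically by adding it to the backbones package's __init__.py, or via the config's custom_imports mechanism).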
Note that the parent class is BaseModule, which means the model is built from scratch.
To build a model on top of an existing one, inherit from that model's class instead.
Take the official DynamicVoxelNet code as an example:
```python
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import force_fp32
from torch.nn import functional as F

from ..builder import DETECTORS
from .voxelnet import VoxelNet


@DETECTORS.register_module()
class DynamicVoxelNet(VoxelNet):
    r"""VoxelNet using dynamic voxelization."""

    def __init__(self,
                 voxel_layer,
                 voxel_encoder,
                 middle_encoder,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        super(DynamicVoxelNet, self).__init__(  # reuse the parent (VoxelNet) initialization
            voxel_layer=voxel_layer,
            voxel_encoder=voxel_encoder,
            middle_encoder=middle_encoder,
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)

    def extract_feat(self, points, img_metas):
        """Extract features from points."""
        ...  # omitted

    @torch.no_grad()
    @force_fp32()
    def voxelize(self, points):
        """Apply dynamic voxelization to points.

        Args:
            points (list[torch.Tensor]): Points of each sample.

        Returns:
            tuple[torch.Tensor]: Concatenated points and coordinates.
        """
        ...  # omitted
```
As you can see, the code above contains no forward function: DynamicVoxelNet inherits the forward logic through its parent class VoxelNet (the dispatching forward itself is defined further up the hierarchy, in Base3DDetector). Now look at the VoxelNet code:
```python
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.ops import Voxelization
from mmcv.runner import force_fp32
from torch.nn import functional as F

from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from .. import builder
from ..builder import DETECTORS
from .single_stage import SingleStage3DDetector


@DETECTORS.register_module()
class VoxelNet(SingleStage3DDetector):
    r"""VoxelNet for 3D detection."""

    def __init__(self,
                 voxel_layer,
                 voxel_encoder,
                 middle_encoder,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 pretrained=None):
        super(VoxelNet, self).__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            pretrained=pretrained)
        self.voxel_layer = Voxelization(**voxel_layer)
        self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
        self.middle_encoder = builder.build_middle_encoder(middle_encoder)

    def extract_feat(self, points, img_metas=None):
        """Extract features from points."""
        ...  # omitted

    @torch.no_grad()
    @force_fp32()
    def voxelize(self, points):
        """Apply hard voxelization to points."""
        ...  # omitted

    def forward_train(self,
                      points,
                      img_metas,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      gt_bboxes_ignore=None):
        """Training forward function.

        Args:
            points (list[torch.Tensor]): Point cloud of each sample.
            img_metas (list[dict]): Meta information of each sample.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                boxes for each sample.
            gt_labels_3d (list[torch.Tensor]): Ground truth labels for
                boxes of each sample.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                boxes to be ignored. Defaults to None.

        Returns:
            dict: Losses of each branch.
        """
        x = self.extract_feat(points, img_metas)
        outs = self.bbox_head(x)
        loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
        losses = self.bbox_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

    def simple_test(self, points, img_metas, imgs=None, rescale=False):
        """Test function without augmentation."""
        ...  # omitted

    def aug_test(self, points, img_metas, imgs=None, rescale=False):
        """Test function with augmentation."""
        ...  # omitted
```
VoxelNet implements __init__, extract_feat, voxelize, forward_train, simple_test, and aug_test. DynamicVoxelNet inherits all of these methods and then overrides extract_feat and voxelize, as illustrated in the sketch below.
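The overrides take effect through ordinary Python method resolution: the inherited methods call self.extract_feat(...), so once the subclass overrides extract_feat, every inherited training/test path automatically uses the new version. A minimal, framework-free sketch (the class names here are illustrative, not mmdet3d code):

```python
class Parent:

    def forward_train(self, points):
        # self.extract_feat is resolved dynamically: if a subclass
        # overrides it, this inherited method calls the override.
        feats = self.extract_feat(points)
        return {'loss': sum(feats)}

    def extract_feat(self, points):
        return [p * 2 for p in points]


class Child(Parent):

    def extract_feat(self, points):  # override, as DynamicVoxelNet does
        return [p * 3 for p in points]


print(Parent().forward_train([1, 2]))  # {'loss': 6}
print(Child().forward_train([1, 2]))   # {'loss': 9}: inherited forward_train,
                                       # overridden extract_feat
```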
When a model is built, the builder takes the class name from the type field of the corresponding config dict and passes the remaining fields into that class as initialization arguments.
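Conceptually, the build step boils down to the following simplified sketch; the real logic lives in mmcv's Registry and build_from_cfg, so this is only the core idea, not the actual implementation:

```python
def build_from_cfg_sketch(cfg, registry):
    """Simplified sketch of what mmcv's build_from_cfg does."""
    args = dict(cfg)              # don't mutate the caller's config
    obj_type = args.pop('type')   # e.g. 'SECOND'
    obj_cls = registry[obj_type]  # the registry maps the name to the class
    return obj_cls(**args)        # remaining fields become __init__ kwargs
```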
Taking SECOND as an example again, open the corresponding config file (e.g., hv_second_secfpn_kitti.py) and find the backbone part of the model field:
```python
backbone=dict(
    type='SECOND',
    in_channels=256,
    layer_nums=[5, 5],
    layer_strides=[1, 2],
    out_channels=[128, 256]),
```
As you can see, setting type='SECOND' in the corresponding component of the config file is all it takes to use SECOND as the backbone; the remaining fields are exactly the input arguments of SECOND's __init__ method.
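This mapping can be checked interactively. Assuming an mmdet3d 1.x (mmcv-based) install, which exposes build_backbone from mmdet3d.models, building the backbone straight from the config dict might look like:

```python
import torch

from mmdet3d.models import build_backbone

backbone_cfg = dict(
    type='SECOND',
    in_channels=256,
    layer_nums=[5, 5],
    layer_strides=[1, 2],
    out_channels=[128, 256])

backbone = build_backbone(backbone_cfg)         # resolves 'SECOND' in the registry
feats = backbone(torch.rand(1, 256, 200, 176))  # dummy BEV feature map
print([f.shape for f in feats])                 # one feature map per stage
```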