• A Simple MMDetection3D Tutorial: Model Definition, Registration, and Building


        A note up front: I am still in the process of learning MMDetection3D, so there may be misunderstandings here; corrections are welcome.

        In MMDetection3D, defining a custom model requires registering its class.

        This part assumes some basic Python knowledge (class inheritance and the @ decorator syntax); if you are not familiar with these, you can refer to this article. A tiny sketch of both features is also shown below.
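
        As a quick, self-contained illustration of those two features (a toy example, not MMDetection3D code), the snippet below shows class inheritance together with a class decorator that records classes in a dictionary, which is essentially what a registry is:

    # Toy example only: a decorator that stores classes in a dict,
    # which is the core idea behind MMDetection3D's registries.
    REGISTRY = {}

    def register_module(cls):
        REGISTRY[cls.__name__] = cls  # record the class under its name
        return cls                    # return the class unchanged

    class Base:
        def hello(self):
            return 'base'

    @register_module
    class Child(Base):  # inherits hello() from Base
        pass

    assert REGISTRY['Child'] is Child
    assert Child().hello() == 'base'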

    1. Defining a model from scratch

        First, let's look at the official SECOND code:

    # Copyright (c) OpenMMLab. All rights reserved.
    import warnings

    from mmcv.cnn import build_conv_layer, build_norm_layer
    from mmcv.runner import BaseModule
    from torch import nn as nn

    from ..builder import BACKBONES


    @BACKBONES.register_module()
    class SECOND(BaseModule):
        """Backbone network for SECOND/PointPillars/PartA2/MVXNet.

        Args:
            in_channels (int): Input channels.
            out_channels (list[int]): Output channels for multi-scale feature maps.
            layer_nums (list[int]): Number of layers in each stage.
            layer_strides (list[int]): Strides of each stage.
            norm_cfg (dict): Config dict of normalization layers.
            conv_cfg (dict): Config dict of convolutional layers.
        """

        def __init__(self,
                     in_channels=128,
                     out_channels=[128, 128, 256],
                     layer_nums=[3, 5, 5],
                     layer_strides=[2, 2, 2],
                     norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
                     conv_cfg=dict(type='Conv2d', bias=False),
                     init_cfg=None,
                     pretrained=None):
            super(SECOND, self).__init__(init_cfg=init_cfg)
            assert len(layer_strides) == len(layer_nums)
            assert len(out_channels) == len(layer_nums)

            in_filters = [in_channels, *out_channels[:-1]]
            # note that when stride > 1, conv2d with same padding isn't
            # equal to pad-conv2d. we should use pad-conv2d.
            blocks = []
            for i, layer_num in enumerate(layer_nums):
                block = [
                    build_conv_layer(
                        conv_cfg,
                        in_filters[i],
                        out_channels[i],
                        3,
                        stride=layer_strides[i],
                        padding=1),
                    build_norm_layer(norm_cfg, out_channels[i])[1],
                    nn.ReLU(inplace=True),
                ]
                for j in range(layer_num):
                    block.append(
                        build_conv_layer(
                            conv_cfg,
                            out_channels[i],
                            out_channels[i],
                            3,
                            padding=1))
                    block.append(build_norm_layer(norm_cfg, out_channels[i])[1])
                    block.append(nn.ReLU(inplace=True))

                block = nn.Sequential(*block)
                blocks.append(block)

            self.blocks = nn.ModuleList(blocks)

            assert not (init_cfg and pretrained), \
                'init_cfg and pretrained cannot be setting at the same time'
            if isinstance(pretrained, str):
                warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                              'please use "init_cfg" instead')
                self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
            else:
                self.init_cfg = dict(type='Kaiming', layer='Conv2d')

        def forward(self, x):
            """Forward function.

            Args:
                x (torch.Tensor): Input with shape (N, C, H, W).

            Returns:
                tuple[torch.Tensor]: Multi-scale features.
            """
            outs = []
            for i in range(len(self.blocks)):
                x = self.blocks[i](x)
                outs.append(x)
            return tuple(outs)

        As you can see, the implementation is basically the same as a regular PyTorch module: it contains an __init__ function and a forward function. The @BACKBONES.register_module() statement at the top registers the model as a backbone (i.e. adds it to the registry) without the class ever needing to be instantiated.

        Note that the parent class is BaseModule, which indicates that this model is built from scratch.
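
        As a minimal sketch of the same pattern for a custom backbone (the class MyBackbone and its layers are hypothetical, and the import path assumes the mmdet3d 1.0.x layout used by the snippet above, where BACKBONES lives in mmdet3d.models.builder):

    # Hypothetical custom backbone, registered the same way SECOND is.
    from mmcv.runner import BaseModule
    from torch import nn

    from mmdet3d.models.builder import BACKBONES  # assumes the mmdet3d 1.0.x layout


    @BACKBONES.register_module()
    class MyBackbone(BaseModule):
        def __init__(self, in_channels=64, out_channels=128, init_cfg=None):
            super().__init__(init_cfg=init_cfg)
            self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
            self.relu = nn.ReLU(inplace=True)

        def forward(self, x):
            # Backbones are expected to return a tuple of feature maps.
            return (self.relu(self.conv(x)), )

        Once the module is registered and imported (for example via the package's __init__.py), it can be selected in a config simply by setting type='MyBackbone'.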

    2. Defining a model based on an existing model

        To build a model on top of an existing one, the parent class should be the class of that base model.

        Take the official DynamicVoxelNet code as an example:

    # Copyright (c) OpenMMLab. All rights reserved.
    import torch
    from mmcv.runner import force_fp32
    from torch.nn import functional as F

    from ..builder import DETECTORS
    from .voxelnet import VoxelNet


    @DETECTORS.register_module()
    class DynamicVoxelNet(VoxelNet):
        r"""VoxelNet using `dynamic voxelization`_."""

        def __init__(self,
                     voxel_layer,
                     voxel_encoder,
                     middle_encoder,
                     backbone,
                     neck=None,
                     bbox_head=None,
                     train_cfg=None,
                     test_cfg=None,
                     pretrained=None,
                     init_cfg=None):
            super(DynamicVoxelNet, self).__init__(  # reuse the parent class (VoxelNet) initialization
                voxel_layer=voxel_layer,
                voxel_encoder=voxel_encoder,
                middle_encoder=middle_encoder,
                backbone=backbone,
                neck=neck,
                bbox_head=bbox_head,
                train_cfg=train_cfg,
                test_cfg=test_cfg,
                pretrained=pretrained,
                init_cfg=init_cfg)

        def extract_feat(self, points, img_metas):
            """Extract features from points."""
            ...  # omitted

        @torch.no_grad()
        @force_fp32()
        def voxelize(self, points):
            """Apply dynamic voxelization to points.

            Args:
                points (list[torch.Tensor]): Points of each sample.

            Returns:
                tuple[torch.Tensor]: Concatenated points and coordinates.
            """
            ...  # omitted

        Notice that the code above has no forward-related functions; the class DynamicVoxelNet inherits them from its parent class VoxelNet. Now look at the VoxelNet code:

    # Copyright (c) OpenMMLab. All rights reserved.
    import torch
    from mmcv.ops import Voxelization
    from mmcv.runner import force_fp32
    from torch.nn import functional as F

    from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
    from .. import builder
    from ..builder import DETECTORS
    from .single_stage import SingleStage3DDetector


    @DETECTORS.register_module()
    class VoxelNet(SingleStage3DDetector):
        r"""`VoxelNet`_ for 3D detection."""

        def __init__(self,
                     voxel_layer,
                     voxel_encoder,
                     middle_encoder,
                     backbone,
                     neck=None,
                     bbox_head=None,
                     train_cfg=None,
                     test_cfg=None,
                     init_cfg=None,
                     pretrained=None):
            super(VoxelNet, self).__init__(
                backbone=backbone,
                neck=neck,
                bbox_head=bbox_head,
                train_cfg=train_cfg,
                test_cfg=test_cfg,
                init_cfg=init_cfg,
                pretrained=pretrained)
            self.voxel_layer = Voxelization(**voxel_layer)
            self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder)
            self.middle_encoder = builder.build_middle_encoder(middle_encoder)

        def extract_feat(self, points, img_metas=None):
            """Extract features from points."""
            ...  # omitted

        @torch.no_grad()
        @force_fp32()
        def voxelize(self, points):
            """Apply hard voxelization to points."""
            ...  # omitted

        def forward_train(self,
                          points,
                          img_metas,
                          gt_bboxes_3d,
                          gt_labels_3d,
                          gt_bboxes_ignore=None):
            """Training forward function.

            Args:
                points (list[torch.Tensor]): Point cloud of each sample.
                img_metas (list[dict]): Meta information of each sample.
                gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                    boxes for each sample.
                gt_labels_3d (list[torch.Tensor]): Ground truth labels for
                    boxes of each sample.
                gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                    boxes to be ignored. Defaults to None.

            Returns:
                dict: Losses of each branch.
            """
            x = self.extract_feat(points, img_metas)
            outs = self.bbox_head(x)
            loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
            losses = self.bbox_head.loss(
                *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
            return losses

        def simple_test(self, points, img_metas, imgs=None, rescale=False):
            """Test function without augmentation."""
            ...  # omitted

        def aug_test(self, points, img_metas, imgs=None, rescale=False):
            """Test function with augmentation."""
            ...  # omitted

        It contains the initializer and the extract_feat, voxelize, forward_train, simple_test, and aug_test methods. All of them are inherited by DynamicVoxelNet, which then overrides extract_feat and voxelize. The same pattern applies to your own models, as sketched below.
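
        A minimal sketch of that pattern follows; the class name MyVoxelNet and its extra processing are hypothetical, and the import paths assume the same mmdet3d layout as the snippets above:

    # Hypothetical detector that reuses VoxelNet wholesale and overrides one method.
    from mmdet3d.models.builder import DETECTORS
    from mmdet3d.models.detectors.voxelnet import VoxelNet


    @DETECTORS.register_module()
    class MyVoxelNet(VoxelNet):
        """Inherits __init__, forward_train, simple_test, aug_test, ... from VoxelNet."""

        def extract_feat(self, points, img_metas=None):
            # Reuse the parent's feature extraction, then hook in custom logic.
            x = super().extract_feat(points, img_metas)
            # ... any custom post-processing of the multi-scale features could go here
            return x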

    3. Building the model

        When a model is built, the class name is taken from the corresponding type field of the config file, and the remaining fields are passed to that class as initialization arguments.

        Taking SECOND as an example, in the corresponding config file (e.g. hv_second_secfpn_kitti.py), find the backbone part of the model field:

    backbone=dict(
        type='SECOND',
        in_channels=256,
        layer_nums=[5, 5],
        layer_strides=[1, 2],
        out_channels=[128, 256]),

        As you can see, setting type='SECOND' in the corresponding component of the config file is enough to make the model use SECOND as its backbone; the remaining parameters are exactly the input arguments of SECOND's __init__ function. A rough sketch of what the build step does with this dict follows.
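
        Internally, the build step does roughly the following (a simplified sketch, not MMDetection3D's actual builder): pop the type key, look the class up in the registry, and pass the remaining keys as keyword arguments.

    # Simplified sketch of "build from config"; not the real mmdet3d builder.
    def build_from_cfg_sketch(cfg, registry):
        args = dict(cfg)                  # copy so the original config is untouched
        obj_type = args.pop('type')       # e.g. 'SECOND'
        obj_cls = registry.get(obj_type)  # class registered under that name
        return obj_cls(**args)            # SECOND(in_channels=256, layer_nums=[5, 5], ...)

        In practice you rarely call such a function yourself; the framework's builder helpers (e.g. build_backbone / build_detector in this version of mmdet3d) perform this lookup automatically when the config is parsed.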

  • Original article: https://blog.csdn.net/weixin_45657478/article/details/126614891