Here we use MobileNet as an example to show how to develop a new component.
Create a new file `mmrotate/models/backbones/mobilenet.py`.
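```python
import torch.nn as nn

from mmrotate.models.builder import ROTATED_BACKBONES


@ROTATED_BACKBONES.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        # Initialize nn.Module before defining any layers.
        super().__init__()
        # TODO: build the layers from arg1 and arg2 here.

    def forward(self, x):  # should return a tuple of feature maps
        pass
```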
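The `@ROTATED_BACKBONES.register_module()` decorator is what lets a config refer to the class by its name. The sketch below is a simplified toy, not MMCV's actual `Registry`. The hypothetical `ToyRegistry` records the decorated class under its `__name__` at the moment the class definition executes. This is also why the new file must be imported somewhere before the model is built.

```python
class ToyRegistry:
    """Hypothetical, simplified stand-in for an MMCV-style registry."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        # Returns a decorator. The decorator runs when the class body is
        # executed, i.e. when the module defining the class is imported.
        def _register(cls):
            key = cls.__name__
            if key in self._module_dict:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._module_dict[key] = cls
            return cls  # the class itself is returned unchanged
        return _register

    def get(self, key):
        # Returns None for names that were never registered.
        return self._module_dict.get(key)


TOY_BACKBONES = ToyRegistry('toy_backbones')


@TOY_BACKBONES.register_module()
class MobileNet:
    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2


print(TOY_BACKBONES.get('MobileNet') is MobileNet)  # True
```

Because the decorator returns the class unchanged, `MobileNet` can still be used directly in Python code as well as by name from a config.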
You can either add the following line to `mmrotate/models/backbones/__init__.py`:

```python
from .mobilenet import MobileNet
```
or alternatively add

```python
custom_imports = dict(
    imports=['mmrotate.models.backbones.mobilenet'],
    allow_failed_imports=False)
```

to the config file to avoid modifying the original code.
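When the config file is loaded, MMCV's config loader imports each module listed in `custom_imports`. That import is what runs the `@register_module()` decorator. The function below is a hypothetical re-implementation of the idea using `importlib`; `import_modules` is an illustrative name, not an MMCV API. It also shows what `allow_failed_imports` controls. With `False`, a missing module raises `ImportError`. With `True`, it only emits a warning and the entry is skipped.

```python
import importlib
import warnings


def import_modules(imports, allow_failed_imports=False):
    """Hypothetical sketch of handling a `custom_imports` entry."""
    if isinstance(imports, str):
        imports = [imports]
    imported = []
    for name in imports:
        try:
            # Importing executes the module's top-level code, including
            # any @register_module() decorators it contains.
            imported.append(importlib.import_module(name))
        except ImportError:
            if not allow_failed_imports:
                raise
            warnings.warn(f'{name} failed to import and is ignored.', UserWarning)
            imported.append(None)
    return imported


mods = import_modules(['json', 'no_such_module_xyz'], allow_failed_imports=True)
print([m.__name__ if m else None for m in mods])  # ['json', None]
```

Keeping `allow_failed_imports=False`, as in the config above, makes a typo in the module path fail loudly instead of surfacing later as an unregistered `type`.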
Then use the new backbone in your config file:

```python
model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...
)
```
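When the model is built, the `backbone` dict is turned into an instance. The `type` key selects the registered class, and every other key is passed to that class's `__init__` as a keyword argument. So `arg1` and `arg2` in the config must match the constructor's parameter names. The sketch below illustrates this with a hypothetical `toy_build_from_cfg` and a plain dict acting as the registry. The values `1` and `2` merely stand in for the `xxx` placeholders.

```python
import copy


class MobileNet:
    """Hypothetical stand-in for the backbone class defined earlier."""

    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2


# A plain dict plays the role of the registry in this sketch.
BACKBONES = {'MobileNet': MobileNet}


def toy_build_from_cfg(cfg, registry):
    """Look up cfg['type'] and pass the remaining keys as keyword arguments."""
    args = copy.deepcopy(cfg)  # do not mutate the caller's config
    obj_type = args.pop('type')
    if obj_type not in registry:
        raise KeyError(f'{obj_type} is not registered')
    return registry[obj_type](**args)


backbone_cfg = dict(type='MobileNet', arg1=1, arg2=2)  # placeholder values
backbone = toy_build_from_cfg(backbone_cfg, BACKBONES)
print(type(backbone).__name__, backbone.arg1, backbone.arg2)  # MobileNet 1 2
```

A misspelled key such as `arg3=...` would raise a `TypeError` from `__init__`, which is the usual symptom of a config that does not match the constructor.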