在MMDetection框架中,我们经常会在forward函数中看到下面的代码,以ATSS为例。
def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor]]:
"""Forward features from the upstream network.
Args:
x (tuple[Tensor]): Features from the upstream network, each is
a 4D-tensor.
Returns:
tuple: Usually a tuple of classification scores and bbox prediction
cls_scores (list[Tensor]): Classification scores for all scale
levels, each is a 4D-tensor, the channels number is
num_anchors * num_classes.
bbox_preds (list[Tensor]): Box energies / deltas for all scale
levels, each is a 4D-tensor, the channels number is
num_anchors * 4.
"""
return multi_apply(self.forward_single, x, self.scales)
forward
函数的作用是得到网络的输出。对于检测任务来说通常是三个输出,分别对应分类分支和回归分支以及置信度分支。
上面函数中X是经过NECK后的向量,经过multi_apply
后ATSS返回的是classification scores
和bbox prediction
以及centerness
。
在正式讲解之前,先看一下函数调用流程
通过上图我们可以清晰的看到multi_apply
调用了forward_single
函数,从而得到我们想要的输出。那multi_apply是怎么调用的呢?
在看本章节的时候,务必确保自己懂了上面的前置知识。
def multi_apply(func, *args, **kwargs):
"""Apply function to a list of arguments.
Note:
This function applies the ``func`` to multiple inputs and
map the multiple outputs of the ``func`` into different
list. Each list contains the same type of outputs corresponding
to different inputs.
Args:
func (Function): A function that will be applied to a list of
arguments
Returns:
tuple(list): A tuple containing multiple list, each list contains \
a kind of returned results by the function
"""
pfunc = partial(func, **kwargs) if kwargs else func
map_results = map(pfunc, *args)
return tuple(map(list, zip(*map_results)))
pfunc = partial(func, **kwargs) if kwargs else func
# func对象是我们传过去的函数,即forward_single,如下图
map_results = map(pfunc, *args)
# *args 就是我们前面传过来的x和scale
# 本句的作用,调用forward_single得到网络的输出。
return tuple(map(list, zip(*map_results)))
# 将网络的输出按组打包。
# 原始的fpn某一层网络输出是(cls,reg,obj),经过zip之后,五层fpn的输出变为了([cls1, cls2, cls3, cls4, cls5], [reg1, reg2, reg3, reg4, reg5], [obj1, obj2, obj3, obj4, obj5])如下图
以上就是对MMDetection中的multi_apply的理解,如有疑问欢迎交流。