mmdet Loss Module Explained



    Preface

     This article walks through the loss-function module of mmdet. Notes on how to use the individual mmdet loss functions, along with their caveats, will be added over time.


    1. Overview of the loss-function module in mmdet

    1.1. The Loss registry

     Let's start with a snippet from mmdet/models/builder.py:

    from mmcv.cnn import MODELS as MMCV_MODELS
    from mmcv.utils import Registry
    
    MODELS = Registry('models', parent=MMCV_MODELS)  # note the extra `parent` argument; ignore it for now
    
    BACKBONES = MODELS
    NECKS = MODELS
    ROI_EXTRACTORS = MODELS
    SHARED_HEADS = MODELS
    HEADS = MODELS
    LOSSES = MODELS          # the Loss registry
    DETECTORS = MODELS
    
    Note that the same MODELS registry is assigned to every other module alias (backbones, necks, heads, and so on); why it is organized this way will come up later. A quick sketch below shows how the registry is used in practice.
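     To make the registry mechanism concrete, here is a minimal sketch (ToyZeroLoss is hypothetical, purely for illustration): any nn.Module decorated with @LOSSES.register_module() can later be instantiated from a plain config dict via build_loss.

    import torch.nn as nn
    from mmdet.models.builder import LOSSES, build_loss
    
    @LOSSES.register_module()
    class ToyZeroLoss(nn.Module):
        """Hypothetical loss used only to show the registration wiring."""
        def forward(self, pred, target):
            # always returns zero while keeping the autograd graph intact
            return (pred - target).sum() * 0
    
    obj = build_loss(dict(type='ToyZeroLoss'))  # looked up by class name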

    1.2. Registering L1Loss

    @LOSSES.register_module()
    class L1Loss(nn.Module):
        """L1 loss.
    
        Args:
            reduction (str, optional): The method to reduce the loss.
                Options are "none", "mean" and "sum".
            loss_weight (float, optional): The weight of loss.
        """
    
        def __init__(self, reduction='mean', loss_weight=1.0):
            super(L1Loss, self).__init__()
            self.reduction = reduction
            self.loss_weight = loss_weight
    
        def forward(self,
                    pred,
                    target,
                    weight=None,
                    avg_factor=None,
                    reduction_override=None):
            """Forward function.
    
            Args:
                pred (torch.Tensor): Predictions, e.g. shape [N].
                target (torch.Tensor): Ground-truth values, same shape as pred.
                weight (torch.Tensor, optional): Per-element weight,
                    shape [N]. Defaults to None.
                avg_factor (int, optional): Factor used to average the loss,
                    in place of the element count. Defaults to None.
                reduction_override (str, optional): Overrides self.reduction
                    for this call. Defaults to None.
            """
            assert reduction_override in (None, 'none', 'mean', 'sum')
            reduction = (
                reduction_override if reduction_override else self.reduction)
            loss_bbox = self.loss_weight * l1_loss(
                pred, target, weight, reduction=reduction, avg_factor=avg_factor)
            return loss_bbox
    

     The constructor is simple, with just two parameters: reduction defaults to 'mean' (return the mean of the element-wise losses), and loss_weight scales the final L1 loss. forward takes more: pred and target need no explanation, except that their shapes must match (say [1000, 4] when working with bboxes); weight has the same shape as pred and scales each element's contribution to the total loss; avg_factor and reduction_override are rarely used: reduction_override overrides self.reduction for a single call, and avg_factor replaces the element count as the divisor when averaging, so in most cases you can leave both alone.
     With the parameters understood, let's work through a concrete example:

    import torch
    from mmdet.models import build_loss
    
    loss_bbox = dict(type='L1Loss', loss_weight=1.0)
    obj = build_loss(loss_bbox)
    
    # compute the loss
    pred = torch.Tensor([[0, 2, 3, 0], [0, 2, 3, 0]])   # [2, 4]
    target = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1]]) # [2, 4]
    loss = obj(pred, target)
    print(loss, 9 / 8)
    

     The output matches the hand-computed value. The flow is simple: torch.abs gives the element-wise absolute differences, then .mean() divides their sum by the total number of elements, here 2 * 4 = 8, giving 9 / 8.
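     If you want to check the arithmetic yourself, the same computation can be done with plain tensor ops (a quick sketch reusing pred and target from above):

    diff = torch.abs(pred - target)  # [[1, 1, 2, 0], [1, 1, 2, 1]]
    print(diff.sum())                # tensor(9.)
    print(diff.mean())               # tensor(1.1250) == 9 / 8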
     Next, the same example with a weight tensor:

    import torch
    from mmdet.models import build_loss
    
    loss_bbox = dict(type='L1Loss', loss_weight=1.0)
    obj = build_loss(loss_bbox)
    
    # compute the loss
    pred = torch.Tensor([[0, 2, 3, 0], [0, 2, 3, 0]])   # [2, 4]
    target = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1]]) # [2, 4]
    # weighted version: the last element gets weight 0
    weight = torch.Tensor([[1, 1, 1, 1], [1, 1, 1, 0]]) # [2, 4]
    loss = obj(pred, target, weight)
    print(loss, 8 / 8)
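     The zero weight removes the last element's contribution (9 - 1 = 8) while the divisor stays at 8, so the result is 1.0. For completeness, here is a short sketch of the two rarely used parameters, continuing from the tensors above:

    # avg_factor replaces the element count as the divisor (sum = 9)
    print(obj(pred, target, avg_factor=4))              # tensor(2.2500) == 9 / 4
    # reduction_override changes the reduction for this call only
    print(obj(pred, target, reduction_override='sum'))  # tensor(9.)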
    

    1.3. Internal implementation

     Under the hood, a decorator wraps the raw loss function. The call flow is:
    1) L1Loss.forward is called, which internally calls the l1_loss function;

    @weighted_loss
    def l1_loss(pred, target):
        """L1 loss.
    
        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
    
        Returns:
            torch.Tensor: Calculated loss
        """
        if target.numel() == 0:
            return pred.sum() * 0
    
        assert pred.size() == target.size()
        loss = torch.abs(pred - target)  # element-wise absolute difference
        return loss
    

    2) The @weighted_loss decorator intercepts the call, so the body of l1_loss does not run first; the wrapper below runs instead (from mmdet/models/losses/utils.py):

    import functools
    
    def weighted_loss(loss_func):
        @functools.wraps(loss_func)
        def wrapper(pred,
                    target,
                    weight=None,
                    reduction='mean',
                    avg_factor=None,
                    **kwargs):
            # element-wise loss from the wrapped function
            loss = loss_func(pred, target, **kwargs)
            # then apply weight / reduction / avg_factor
            loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
            return loss
    
        return wrapper
    

     The decorator wraps loss_func (here l1_loss), extending its signature with weight, reduction and avg_factor; when called, it first runs l1_loss to obtain the element-wise loss values.
    3) Finally, weight_reduce_loss turns the element-wise loss into its final form by applying weight, reduction and avg_factor:

    import mmcv
    import torch.nn.functional as F
    
    def reduce_loss(loss, reduction):
        """Reduce loss as specified.
    
        Args:
            loss (Tensor): Elementwise loss tensor.
            reduction (str): Options are "none", "mean" and "sum".
    
        Return:
            Tensor: Reduced loss tensor.
        """
        reduction_enum = F._Reduction.get_enum(reduction)
        # none: 0, elementwise_mean: 1, sum: 2
        if reduction_enum == 0:
            return loss
        elif reduction_enum == 1:
            return loss.mean()
        elif reduction_enum == 2:
            return loss.sum()
    
    @mmcv.jit(derivate=True, coderize=True)
    def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    
        # if weight is specified, apply element-wise weight
        if weight is not None:
            loss = loss * weight
    
        # if avg_factor is not specified, just reduce the loss
        if avg_factor is None:
            loss = reduce_loss(loss, reduction)
        else:
            # if reduction is 'mean', average the loss by avg_factor
            if reduction == 'mean':
                loss = loss.sum() / avg_factor
            # if reduction is 'none', do nothing; otherwise raise an error
            elif reduction != 'none':
                raise ValueError('avg_factor can not be used with reduction="sum"')
        return loss
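    
     The payoff of this pattern is that any new element-wise loss gets weight, reduction and avg_factor handling for free. A minimal sketch (squared_loss is hypothetical, for illustration only):

    @weighted_loss
    def squared_loss(pred, target):
        # element-wise only; weighting and reduction are handled by the wrapper
        return (pred - target) ** 2
    
    # squared_loss(pred, target)                   -> mean over all elements
    # squared_loss(pred, target, weight=w)         -> element-weighted mean
    # squared_loss(pred, target, reduction='sum')  -> summed loss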
    

    1.4. Summary

     Essentially every loss in mmdet follows the flow above. When using L1 Loss you don't need to worry about most of the hyperparameters: just build the loss and pass in pred and target, leaving the remaining arguments at their defaults.
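
     In day-to-day use you rarely call build_loss yourself; the loss dict simply sits inside a detector config and the head builds it. A sketch of the usual pattern (field names follow common mmdet 2.x configs; values are illustrative):

    model = dict(
        bbox_head=dict(
            type='RetinaHead',
            loss_cls=dict(type='FocalLoss', use_sigmoid=True, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)))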

    Conclusion

     To be continued…

  • Original article: https://blog.csdn.net/wulele2/article/details/125469970