• 图像分割项目中损失函数的选择


    前言

    在图像分割领域,最基础、最常见的损失当然是交叉熵损失 —— Cross entropy。随着不断的研究,涌现出了许多优于交叉熵损失的,并且在实际场景中,也往往不会在单单使用交叉熵损失了。

    场景:实际项目中,通常会有一个常见的问题:样本不均衡

    一、focal loss

    focal loss从样本难易分类角度出发,解决样本非平衡带来的模型训练问题。
      通常情况下,样本不均衡所带来的问题是少样本难以区分(当然也会存在一些本身就很难区分或分割的样本),因此focal loss聚焦于难分样本,在梯度求导时,让难分类样本占主导,因此训练学习过程更加聚焦在难分样本。

    思考

       focal loss在训练过程中本身是一个动态选择,并不稳定,这也是为什么有些情形下使用focal loss还不如原本的CE loss。通常来说,为了防止难易样本的频繁变化,应当选取小的学习率

    代码如下(示例):

    class FocalLoss(nn.Module):
        """
        copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
        This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
        'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
            Focal_Loss= -1*alpha*(1-pt)*log(pt)
        :param num_class:
        :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
        :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
                        focus on hard misclassified example
        :param smooth: (float,double) smooth value when cross entropy
        :param balance_index: (int) balance class index, should be specific when alpha is float
        :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
        """
    
        def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-1, size_average=True):
            super(FocalLoss, self).__init__()
            self.apply_nonlin = apply_nonlin
            self.alpha = alpha
            self.gamma = gamma
            self.balance_index = balance_index
            self.smooth = smooth
            self.size_average = size_average
    
            if self.smooth is not None:
                if self.smooth < 0 or self.smooth > 1.0:
                    raise ValueError('smooth value should be in [0,1]')
    
        def forward(self, logit, target):
            N=logit.shape[1]
            self.alpha = enet_weighing(target, N).cuda()
    
            logit = F.softmax(logit, dim=1)
            if self.apply_nonlin is not None:
                logit = self.apply_nonlin(logit)
            num_class = logit.shape[1]
            if logit.dim() > 2:
                # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
                logit = logit.view(logit.size(0), logit.size(1), -1)
                logit = logit.permute(0, 2, 1).contiguous()
                logit = logit.view(-1, logit.size(-1))
            target = torch.squeeze(target, 1)
            target = target.view(-1, 1)
            # print(logit.shape, target.shape)
            #
            alpha = self.alpha
    
            if alpha is None:
                alpha = torch.ones(num_class, 1)
            elif isinstance(alpha, (list, np.ndarray)):
                assert len(alpha) == num_class
                alpha = torch.FloatTensor(alpha).view(num_class, 1)
                alpha = alpha / alpha.sum()
            elif isinstance(alpha, float):
                alpha = torch.ones(num_class, 1)
                alpha = alpha * (1 - self.alpha)
                alpha[self.balance_index] = self.alpha
    
            # else:
            #     raise TypeError('Not support alpha type')
    
            if alpha.device != logit.device:
                alpha = alpha.to(logit.device)
    
            idx = target.cpu().long()
    
            one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
            one_hot_key = one_hot_key.scatter_(1, idx, 1)
            if one_hot_key.device != logit.device:
                one_hot_key = one_hot_key.to(logit.device)
    
            if self.smooth:
                one_hot_key = torch.clamp(
                    one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
            pt = (one_hot_key * logit).sum(1) + self.smooth
            logpt = pt.log()
    
            gamma = self.gamma
    
            alpha = alpha[idx]
            alpha = torch.squeeze(alpha)
            loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
    
            if self.size_average:
                loss = loss.mean()
            else:
                loss = loss.sum()
            return loss
    
    # 训练过程
    focal = FocalLoss()
    FocalLoss1 = focal(out, label) # out:模型输出  label:标签
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92

    二、Dice loss

    Dice loss适用于样本极度不平衡的情况,一般情况下使用Dice Loss会对反向传播不利,使得训练不稳定(注:在使用DICE loss时,对小目标是十分不利的,因为在只有前景和背景的情况下,小目标一旦有部分像素预测错误,那么就会导致Dice大幅度的变动,从而导致梯度变化剧烈,训练不稳定)。因为,通常是将Dice loss作为辅助损失函数来和主损失函数一起训练,如Dice loss+CE loss 或 Dice loss + Focal loss

    代码如下(示例):

    import torch
    from torch import Tensor
    import torch.nn.functional as F
    
    def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
        # Average of Dice coefficient for all batches, or for a single mask
        assert input.size() == target.size()
        if input.dim() == 2 and reduce_batch_first:
            raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')
    
        if input.dim() == 2 or reduce_batch_first:
            inter = torch.dot(input.reshape(-1), target.reshape(-1))
            sets_sum = torch.sum(input) + torch.sum(target)
            if sets_sum.item() == 0:
                sets_sum = 2 * inter
    
            return (2 * inter + epsilon) / (sets_sum + epsilon)
        else:
            # compute and average metric for each batch element
            dice = 0
            for i in range(input.shape[0]):
                dice += dice_coeff(input[i, ...], target[i, ...])
            return dice / input.shape[0]
    
    
    def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
        # Average of Dice coefficient for all classes
        assert input.size() == target.size()
        dice = 0
        for channel in range(input.shape[1]):
            dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
    
        return dice / input.shape[1]
    
    
    def dice_loss(input: Tensor, target: Tensor, multiclass: bool = True):
        # Dice loss (objective to minimize) between 0 and 1
        assert input.size() == target.size()
        fn = multiclass_dice_coeff if multiclass else dice_coeff
        return 1 - fn(input, target, reduce_batch_first=True)
    
    # 训练过程
    lossp = dice_loss(F.softmax(out, dim=1).float(),
                     F.one_hot(lb, n_classes).permute(0, 3,1,2).contiguous().float(),  multiclass=True)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44

    三、二分类

    图像分割二分类任务一般有两种方式:
    (1)和多分类任务一样,只是最后的输出通道num_class设置为2,所以输出的是一个二通道图。二分类标签label是一个单通道图,数值只有0和1两者。为了让模型的输出图不断逼近于abel,会让输出图先经过一个softmax函数,使其数值归一化到(0,1)之间,即让同一位置上两个通道的值加起来等于1。而对于label,会使用onehot编码,转换成了 num_class=2 个通道的图像。然后就可以让输出图和label进行对应的损失计算了。大致流程如下图所示:
    在这里插入图片描述
    注:

    1)二分类任务,经过softmax后,是同一位置的两个通道值之和为1,若是多分类任务,也就是多个通道之和为1。

    2)二分类label经过one-hot编码,0变为[0,1],1变为[1,0];若是多分类任务,假设为4分类,那label图里就是 [0,1,2,3] 这四个像素值。则one-hot编码如下:
    0 —— 【0,0,0,1】
    1 —— 【0,0,1,0】
    2 —— 【0,1,0,0】
    3 —— 【1,0,0,0】

    3)对于CrossEntropyLoss和FocalLoss,其函数内部自带有处理方式,所以无需改动,直接将输出图和label传进去即可,如上面代码:

    focal = FocalLoss()
    FocalLoss1 = focal(out, label) # out:模型输出  label:标签
    
    loss = torch.nn.CrossEntropyLoss()
    loss = loss(out, label)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    对于Dice loss,需要自己改动输入方式,如上面代码:

    lossp = dice_loss(F.softmax(out, dim=1).float(),
     F.one_hot(lb, n_classes).permute(0, 3, 1, 2).contiguous().float(), multiclass=True)
    
    • 1
    • 2

    (2)第二种方式,是显著性目标检测任务中常用的,只输出单通道,即num_class=1。这时是使用sigmoid函数来对输出图进行归一化到(0,1)之间,由于输出图和label都是单通道图,所以可以直接计算损失。可以参考显著性目标检测论文中常用的损失函数:BCE + IOU (BCE关注像素,IOU关注整体结构,两者一起用其实相当于 CE+Dice)

    注:使用torch.nn.BCELoss(),需要自己对输出图使用sigmoid处理;若使用BCEWithLogitsLoss(),其函数内部有sigmoid处理,就不需要自己加了。

    未完待续

    持续记录以后项目中用到的损失函数

  • 相关阅读:
    【华为机试真题 JAVA】事件推送-100
    SpringBoot之整合WebSocket服务并兼容IE8浏览器的方式
    【MySQL】MySQL的IFNULL()、ISNULL()、NULLIF()函数用法说明
    【STM32】OLED
    嵌入式实时操作系统的设计与开发(内存资源池存储管理)
    Himall商城Web帮助类获得上次请求的url、获得请求的方式、获得请求的主机部分、获取请求的端口号、 获得请求的ip、获得请求的原始url
    我如何才能保护我的私钥?
    C++程序设计-第四/五章 函数和类和对象【期末复习|考研复习】
    SQLAlchemy学习-10. validates()校验器
    代理模式
  • 原文地址:https://blog.csdn.net/qq_43199575/article/details/134284008