在图像分割领域,最基础、最常见的损失当然是交叉熵损失 —— Cross entropy。随着不断的研究,涌现出了许多优于交叉熵损失的,并且在实际场景中,也往往不会在单单使用交叉熵损失了。
focal loss从样本难易分类角度出发,解决样本非平衡带来的模型训练问题。
通常情况下,样本不均衡所带来的问题是少样本难以区分(当然也会存在一些本身就很难区分或分割的样本),因此focal loss聚焦于难分样本,在梯度求导时,让难分类样本占主导,因此训练学习过程更加聚焦在难分样本。
focal loss在训练过程中本身是一个动态选择,并不稳定,这也是为什么有些情形下使用focal loss还不如原本的CE loss。通常来说,为了防止难易样本的频繁变化,应当选取小的学习率。
代码如下(示例):
class FocalLoss(nn.Module):
"""
copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
Focal_Loss= -1*alpha*(1-pt)*log(pt)
:param num_class:
:param alpha: (tensor) 3D or 4D the scalar factor for this criterion
:param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
focus on hard misclassified example
:param smooth: (float,double) smooth value when cross entropy
:param balance_index: (int) balance class index, should be specific when alpha is float
:param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
"""
def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-1, size_average=True):
super(FocalLoss, self).__init__()
self.apply_nonlin = apply_nonlin
self.alpha = alpha
self.gamma = gamma
self.balance_index = balance_index
self.smooth = smooth
self.size_average = size_average
if self.smooth is not None:
if self.smooth < 0 or self.smooth > 1.0:
raise ValueError('smooth value should be in [0,1]')
def forward(self, logit, target):
N=logit.shape[1]
self.alpha = enet_weighing(target, N).cuda()
logit = F.softmax(logit, dim=1)
if self.apply_nonlin is not None:
logit = self.apply_nonlin(logit)
num_class = logit.shape[1]
if logit.dim() > 2:
# N,C,d1,d2 -> N,C,m (m=d1*d2*...)
logit = logit.view(logit.size(0), logit.size(1), -1)
logit = logit.permute(0, 2, 1).contiguous()
logit = logit.view(-1, logit.size(-1))
target = torch.squeeze(target, 1)
target = target.view(-1, 1)
# print(logit.shape, target.shape)
#
alpha = self.alpha
if alpha is None:
alpha = torch.ones(num_class, 1)
elif isinstance(alpha, (list, np.ndarray)):
assert len(alpha) == num_class
alpha = torch.FloatTensor(alpha).view(num_class, 1)
alpha = alpha / alpha.sum()
elif isinstance(alpha, float):
alpha = torch.ones(num_class, 1)
alpha = alpha * (1 - self.alpha)
alpha[self.balance_index] = self.alpha
# else:
# raise TypeError('Not support alpha type')
if alpha.device != logit.device:
alpha = alpha.to(logit.device)
idx = target.cpu().long()
one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
one_hot_key = one_hot_key.scatter_(1, idx, 1)
if one_hot_key.device != logit.device:
one_hot_key = one_hot_key.to(logit.device)
if self.smooth:
one_hot_key = torch.clamp(
one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
pt = (one_hot_key * logit).sum(1) + self.smooth
logpt = pt.log()
gamma = self.gamma
alpha = alpha[idx]
alpha = torch.squeeze(alpha)
loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
if self.size_average:
loss = loss.mean()
else:
loss = loss.sum()
return loss
# 训练过程
focal = FocalLoss()
FocalLoss1 = focal(out, label) # out:模型输出 label:标签
Dice loss适用于样本极度不平衡的情况,一般情况下使用Dice Loss会对反向传播不利,使得训练不稳定(注:在使用DICE loss时,对小目标是十分不利的,因为在只有前景和背景的情况下,小目标一旦有部分像素预测错误,那么就会导致Dice大幅度的变动,从而导致梯度变化剧烈,训练不稳定)。因为,通常是将Dice loss作为辅助损失函数来和主损失函数一起训练,如Dice loss+CE loss 或 Dice loss + Focal loss
代码如下(示例):
import torch
from torch import Tensor
import torch.nn.functional as F
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
# Average of Dice coefficient for all batches, or for a single mask
assert input.size() == target.size()
if input.dim() == 2 and reduce_batch_first:
raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')
if input.dim() == 2 or reduce_batch_first:
inter = torch.dot(input.reshape(-1), target.reshape(-1))
sets_sum = torch.sum(input) + torch.sum(target)
if sets_sum.item() == 0:
sets_sum = 2 * inter
return (2 * inter + epsilon) / (sets_sum + epsilon)
else:
# compute and average metric for each batch element
dice = 0
for i in range(input.shape[0]):
dice += dice_coeff(input[i, ...], target[i, ...])
return dice / input.shape[0]
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
# Average of Dice coefficient for all classes
assert input.size() == target.size()
dice = 0
for channel in range(input.shape[1]):
dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
return dice / input.shape[1]
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = True):
# Dice loss (objective to minimize) between 0 and 1
assert input.size() == target.size()
fn = multiclass_dice_coeff if multiclass else dice_coeff
return 1 - fn(input, target, reduce_batch_first=True)
# 训练过程
lossp = dice_loss(F.softmax(out, dim=1).float(),
F.one_hot(lb, n_classes).permute(0, 3,1,2).contiguous().float(), multiclass=True)
图像分割二分类任务一般有两种方式:
(1)和多分类任务一样,只是最后的输出通道num_class设置为2,所以输出的是一个二通道图。二分类标签label是一个单通道图,数值只有0和1两者。为了让模型的输出图不断逼近于abel,会让输出图先经过一个softmax函数,使其数值归一化到(0,1)之间,即让同一位置上两个通道的值加起来等于1。而对于label,会使用onehot编码,转换成了 num_class=2 个通道的图像。然后就可以让输出图和label进行对应的损失计算了。大致流程如下图所示:
注:
1)二分类任务,经过softmax后,是同一位置的两个通道值之和为1,若是多分类任务,也就是多个通道之和为1。
2)二分类label经过one-hot编码,0变为[0,1],1变为[1,0];若是多分类任务,假设为4分类,那label图里就是 [0,1,2,3] 这四个像素值。则one-hot编码如下:
0 —— 【0,0,0,1】
1 —— 【0,0,1,0】
2 —— 【0,1,0,0】
3 —— 【1,0,0,0】
3)对于CrossEntropyLoss和FocalLoss,其函数内部自带有处理方式,所以无需改动,直接将输出图和label传进去即可,如上面代码:
focal = FocalLoss()
FocalLoss1 = focal(out, label) # out:模型输出 label:标签
loss = torch.nn.CrossEntropyLoss()
loss = loss(out, label)
对于Dice loss,需要自己改动输入方式,如上面代码:
lossp = dice_loss(F.softmax(out, dim=1).float(),
F.one_hot(lb, n_classes).permute(0, 3, 1, 2).contiguous().float(), multiclass=True)
(2)第二种方式,是显著性目标检测任务中常用的,只输出单通道,即num_class=1。这时是使用sigmoid函数来对输出图进行归一化到(0,1)之间,由于输出图和label都是单通道图,所以可以直接计算损失。可以参考显著性目标检测论文中常用的损失函数:BCE + IOU (BCE关注像素,IOU关注整体结构,两者一起用其实相当于 CE+Dice)
注:使用torch.nn.BCELoss(),需要自己对输出图使用sigmoid处理;若使用BCEWithLogitsLoss(),其函数内部有sigmoid处理,就不需要自己加了。
持续记录以后项目中用到的损失函数