• Label Smoothing介绍及其代码实现


    一、标签平滑(Label Smoothing)介绍

    标签平滑(Label Smoothing)的原理其实很简单,它大部分的用处用一句话总结就是:

    修改数据集的标签来增加扰动,避免模型的判断过于自信从而陷入过拟合

    标签平滑是一种正则化的技术,常常用在分类任务中。它的具体做法就是为数据集的标签增加扰动,它的具体做法如下所示(以K分类任务为例)。

    对于K分类来说,假设一个样本 x x x属于第2类,那么实际上用来训练模型(或者说用来计算损失函数)的标签是一个独热编码,具体为[0,0,1,0], 即在位置为2处数值为1(代表属于第2类(从第0类开始计数))。此时标签平滑的具体步骤为:

    1. 定义一个小的扰动常量 ϵ \epsilon ϵ
    2. 将独热编码的标签中的0替换为 ϵ / K \epsilon/K ϵ/K
    3. 将独热编码的标签中的1替换为 1 − e p s i l o n / K 1-epsilon/K 1epsilon/K

    由于在现实数据集中,并不是所有标签都是正确标注的,所以直接最大化 log ⁡ p ( y ∣ x ) \log{p}\left(y\mid{x}\right) logp(yx)(即过于自信的把其中一个候选类对应的digit置为1,将其余类的digit置为零), 反而是有害的。这种过于自信的做法不仅仅会使得模型过拟合,而且有可能拟合到错误的例子上去。

    一些实验已经证明,标签平滑能够增加模型的泛化能力(Müller et al., 2020)。

    二、标签平滑的代码实现

    作为一种相对成熟的技术,标签平滑已经有许多开箱即用的实现,在这里我摘取CoinCheung的实现作为示例。

    这个版本主要面向pytorch框架,你可以像使用pytorch中的CrossEntropyLoss一样使用它,无需任何改动。

    # version 1: use torch.autograd
    class LabelSmoothSoftmaxCEV1(nn.Module):
        '''
        This is the autograd version, you can also try the LabelSmoothSoftmaxCEV2 that uses derived gradients
        '''
    
        def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
            super(LabelSmoothSoftmaxCEV1, self).__init__()
            self.lb_smooth = lb_smooth
            self.reduction = reduction
            self.lb_ignore = ignore_index
            self.log_softmax = nn.LogSoftmax(dim=1)
    
        def forward(self, logits, label):
            '''
            Same usage method as nn.CrossEntropyLoss:
                >>> criteria = LabelSmoothSoftmaxCEV1()
                >>> logits = torch.randn(8, 19, 384, 384) # nchw, float/half
                >>> lbs = torch.randint(0, 19, (8, 384, 384)) # nhw, int64_t
                >>> loss = criteria(logits, lbs)
            '''
            # overcome ignored label
            logits = logits.float() # use fp32 to avoid nan
            with torch.no_grad():
                num_classes = logits.size(1)
                label = label.clone().detach()
                ignore = label.eq(self.lb_ignore)
                n_valid = ignore.eq(0).sum()
                label[ignore] = 0
                lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
                lb_one_hot = torch.empty_like(logits).fill_(
                    lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()
    
            logs = self.log_softmax(logits)
            loss = -torch.sum(logs * lb_one_hot, dim=1)
            loss[ignore] = 0
            if self.reduction == 'mean':
                loss = loss.sum() / n_valid
            if self.reduction == 'sum':
                loss = loss.sum()
    
            return loss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42

    三、参考资料

    1. https://paperswithcode.com/method/label-smoothing
    2. https://arxiv.org/pdf/1906.02629.pdf
    3. https://github.com/CoinCheung/pytorch-loss/blob/master/label_smooth.py
  • 相关阅读:
    动画一:过渡(超详细!)
    python的简单爬取
    外贸公司职业保密协议
    python Jupyter程序之Matplotlib数据可视化
    阿桂天山的技术小结:Flask实现对Ztree树状节点的增改删操作
    Elixir-Tuples
    投影坐标系的shp数据,如何过去到它地理坐标系下的经纬度坐标
    [笔记]SSH 端口转发
    SaaSBase:什么是汇思?
    46届世界技能大赛湖北省选拔赛wp 3.0
  • 原文地址:https://blog.csdn.net/qq_40765537/article/details/125107238