• 深入浅出PyTorch中的nn.CrossEntropyLoss


    👨‍💻 作者简介:非科班转码,正在不断丰富自己的技术栈
    🗒️ 博客主页:https://raelum.blog.csdn.net
    🎯 主要领域:NLP、RS、GNN
    📢 如果这篇文章有帮助到你,可以关注❤️ + 点赞👍 + 收藏⭐ + 留言💬,这将是我创作的最大动力

    在这里插入图片描述

    一、前言

    nn.CrossEntropyLoss 常用作多分类问题的损失函数(对交叉熵还不了解的读者可以看我的这篇文章),本文将围绕PyTorch的官方文档对重要知识点进行逐一讲解(不会全部讲解)。

    import torch
    import torch.nn as nn
    
    • 1
    • 2

    二、理论基础

    对于 C   ( C > 2 ) C\,(C>2) C(C>2) 分类问题,先不考虑 batch 的情形,设神经网络的输出(还未经过 Softmax)为 { x c } c = 1 C \{x_c\}_{c=1}^C {xc}c=1C,经过 Softmax 后得到

    q i = exp ⁡ ( x i ) ∑ c = 1 C exp ⁡ ( x c ) q_i=\frac{\exp(x_i)}{\sum_{c=1}^C\exp(x_c)} qi=c=1Cexp(xc)exp(xi)

    从而该样本的交叉熵损失为

    H ( p , q ) = − ∑ i = 1 C p i log ⁡ q i = − ∑ i = 1 C p i log ⁡ exp ⁡ ( x i ) ∑ c = 1 C exp ⁡ ( x c ) H(p,q)=-\sum_{i=1}^C p_i\log q_i=-\sum_{i=1}^C p_i\log\frac{\exp(x_i)}{\sum_{c=1}^C\exp(x_c)} H(p,q)=i=1Cpilogqi=i=1Cpilogc=1Cexp(xc)exp(xi)

    其中 ( p 1 , p 2 , ⋯   , p C ) (p_1,p_2,\cdots,p_C) (p1,p2,,pC) 是 One-Hot 向量。

    不妨令 p y = 1   ( y ∈ { 1 , 2 , ⋯   , C } ) p_y=1\,(y\in\{1,2,\cdots,C\}) py=1(y{1,2,,C}),其余为 0 0 0,因此上式变为

    H ( p , q ) = − log ⁡ exp ⁡ ( x y ) ∑ c = 1 C exp ⁡ ( x c ) H(p,q)=-\log\frac{\exp(x_y)}{\sum_{c=1}^C\exp(x_c)} H(p,q)=logc=1Cexp(xc)exp(xy)

    现在考虑有 batch 的情形,不妨设 batch size 为 N N N,神经网络的输出为 { x n c } n c ,    n = 1 , ⋯   , N ,    c = 1 , ⋯   , C \{x_{nc}\}_{nc},\;n=1,\cdots,N,\;c=1,\cdots,C {xnc}nc,n=1,,N,c=1,,C,第 n n n 个样本的真实类别记为 y n   ( y n ∈ { 1 , 2 , ⋯   , C } ) y_n\,(y_n\in\{1,2,\cdots,C\}) yn(yn{1,2,,C}),第 n n n 个样本的交叉熵损失记为 l n l_n ln,则仿照上式就有

    l n = − log ⁡ exp ⁡ ( x n , y n ) ∑ c = 1 C exp ⁡ ( x n c ) l_n=-\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})} ln=logc=1Cexp(xnc)exp(xn,yn)

    接下来我们讨论一些特殊情形。当数据不平衡时(某一类的样本数特别多,另一类的样本数特别少),我们需要为每一类的损失安排一个权重用来平衡。权重为 w = ( w 1 , w 2 , ⋯   , w C ) \boldsymbol{w}=(w_1,w_2,\cdots,w_C) w=(w1,w2,,wC)

    📌 模型容易在样本数最多的一个(或几个)类上过拟合,因此对于那些样本数较少的类,我们需要设置更高的权重,这样模型在预测这些类的标签时一旦出错,就会受到更多的惩罚

    安排了权重后,相应的损失为

    l n = − w y n log ⁡ exp ⁡ ( x n , y n ) ∑ c = 1 C exp ⁡ ( x n c ) l_n=-w_{y_n}\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})} ln=wynlogc=1Cexp(xnc)exp(xn,yn)

    计算完 l 1 , l 2 , ⋯   , l N l_1,l_2,\cdots,l_N l1,l2,,lN 后,我们既可以一次性将它们全部返回(对应 reduction=none),也可以返回它们的均值(对应 reduction=mean),还可以返回它们的(对应 reduction=sum):

    ℓ = { ( l 1 , ⋯   , l N ) , reduction=none ∑ n = 1 N l n / ∑ n = 1 N w y n , reduction=mean ∑ n = 1 N l n , reduction=sum \ell={(l1,,lN),reduction=noneNn=1ln/Nn=1wyn,reduction=meanNn=1ln,reduction=sum

    (l1,,lN),Nn=1ln/Nn=1wyn,Nn=1ln,reduction=nonereduction=meanreduction=sum
    =(l1,,lN),n=1Nln/n=1Nwyn,n=1Nln,reduction=nonereduction=meanreduction=sum

    在 NLP 任务中,我们往往将填充词元添加到每个序列的末尾,这样一来不同长度的序列可以进行批量加载。训练过程中,我们不希望网络预测出的填充词元被算入损失函数中。不妨设填充词元在词表中的索引为 i i i,则此时应对 l n l_n ln 作如下修正:

    l n = − w y n ⋅ I ( y n ≠ i ) ⋅ log ⁡ exp ⁡ ( x n , y n ) ∑ c = 1 C exp ⁡ ( x n c ) , where    I ( x ) = { 1 , x    is True 0 , x    is False l_n=-w_{y_n}\cdot \mathbb{I}(y_n\neq i)\cdot\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})},\qquad \text{where}\; \mathbb{I}(x)= {1,xis True0,xis False

    ln=wynI(yn=i)logc=1Cexp(xnc)exp(xn,yn),whereI(x)={1,0,xis Truexis False

    另外,该场景下的 reduction=mean 对应的损失变为

    ℓ = ∑ n = 1 N l n ∑ n = 1 N w y n ⋅ I ( y n ≠ i ) \ell=\sum_{n=1}^N\frac{l_n}{\sum_{n=1}^Nw_{y_n}\cdot \mathbb{I}(y_n\neq i)} =n=1Nn=1NwynI(yn=i)ln

    📌 需要注意的是,在PyTorch中 y n ∈ { 0 , 1 , ⋯   , C − 1 } y_n\in\{0,1,\cdots,C-1\} yn{0,1,,C1},这里我们之所以用 { 1 , 2 , ⋯   , C } \{1,2,\cdots,C\} {1,2,,C} 是为了更自然地衔接上下文

    三、主要参数

    nn.CrossEntropyLoss 的主要参数如下:

    nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0)
    
    • 1

    ⚠️ size_averagereduce 参数已经弃用,取而代之的是 reduction 参数,所以这里不再讲解


    有了前面的铺垫,我们就可以很容易理解这些参数了:

    • weight:长度为 C C C 的张量,一般在数据不平衡时才会使用;
    • ignore_index:需要忽略的类别的索引,默认为 − 100 -100 100,即不忽略;
    • reduction:决定以何种形式返回损失。为 none 时返回 N N N 个样本的损失,为 mean 时返回 N N N 个样本的损失均值,为 sum 时返回 N N N 个样本的损失的和。默认为 mean
    • label_smoothing:决定是否开启标签平滑(不了解标签平滑的读者可参考这篇文章),数值在 [ 0 , 1 ] [0,1] [0,1] 内。默认为 0 0 0,即不开启。

    3.1 输入与输出

    输入分为 inputtargetinput 通常为 ( N , C ) (N,C) (N,C) 的形状(即 batch_size × num_classes),target 通常为 ( N , ) (N,) (N,) 的形状,其中的每个分量均位于 [ 0 , C − 1 ] ∩ Z [0,C-1] \cap \mathbb{Z} [0,C1]Z 中,代表样本属于的类别。

    📌 inputtarget 还可以是其他类型的输入,但本文只讨论这种使用最为广泛的输入
    📌 input 是神经网络的原始输出(未经过 Softmax),nn.CrossEntropyLoss 会自动对其应用 Softmax

    torch.manual_seed(0)
    batch_size = 3
    num_classes = 5
    criterion_1 = nn.CrossEntropyLoss(reduction='none')
    criterion_2 = nn.CrossEntropyLoss()
    criterion_3 = nn.CrossEntropyLoss(reduction='sum')
    
    inputs = torch.randn(batch_size, num_classes)  # 避免与input关键字冲突(当然这无所谓)
    target = torch.randint(num_classes, size=(batch_size, ))
    
    print(criterion_1(inputs, target))  # 输出3个样本的loss
    # tensor([1.4639, 3.0493, 2.3056])
    print(criterion_2(inputs, target))  # 输出3个样本的loss的均值
    # tensor(2.2729)
    print(criterion_3(inputs, target))  # 输出3个样本的loss的和
    # tensor(6.8188)
    
    print(sum(criterion_1(inputs, target)) == criterion_3(inputs, target))
    # tensor(True)
    print(sum(criterion_1(inputs, target)) / batch_size == criterion_2(inputs, target))
    # tensor(True)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    四、从零开始实现 nn.CrossEntropyLoss

    为了加深理解,接下来我们从零开始实现 nn.CrossEntropyLoss(当然会和官方不同,为了追求可读性会采用傻瓜式实现)。

    首先确定框架(为简便起见这里不考虑 label_smoothing):

    class CrossEntropyLoss(nn.Module):
    
        def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
            super().__init__()
            self.weight = weight
            self.ignore_index = ignore_index
            self.reduction = reduction
            
        def forward(self, inputs, target):
            pass
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    为方便计算,我们对第二章节的损失计算公式进行改写

    l n = w y n ⋅ I ( y n ≠ i ) ⋅ [ − x n , y n + log ⁡ ∑ c = 1 C exp ⁡ ( x n c ) ] l_n=w_{y_n}\cdot \mathbb{I}(y_n\neq i)\cdot[-x_{n,y_n}+\log\sum_{c=1}^C\exp(x_{nc})] ln=wynI(yn=i)[xn,yn+logc=1Cexp(xnc)]

    采用更符合 Python 的表述方式来改写上式

    l n = w [ y n ] ⋅ I ( y n ≠ i ) ⋅ [ − x n [ y n ] + log ⁡ ∑ c = 1 C exp ⁡ ( x n [ c ] ) ] l_n=\boldsymbol{w}[y_n]\cdot \mathbb{I}(y_n\neq i)\cdot[-\boldsymbol{x_n}[y_n]+\log\sum_{c=1}^C\exp(\boldsymbol{x_n}[c])] ln=w[yn]I(yn=i)[xn[yn]+logc=1Cexp(xn[c])]

    其中 w = ( w 1 , ⋯   , w C ) ,    x n = ( x n 1 , ⋯   , x n C ) \boldsymbol{w}=(w_1,\cdots,w_C),\;\boldsymbol{x_n}=(x_{n1},\cdots,x_{nC}) w=(w1,,wC),xn=(xn1,,xnC)。再令 X = ( x 1 ; ⋯   ; x N ) ,    y = ( y 1 , ⋯   , y C ) {\bf X}=(\boldsymbol{x_1};\cdots;\boldsymbol{x_N}),\;\boldsymbol{y}=(y_1,\cdots,y_C) X=(x1;;xN),y=(y1,,yC),则显然 X {\bf X} X 就是我们的 input y \boldsymbol{y} y 就是 target,于是我们可以进行批量计算

    ( l 1 , ⋯   , l N ) = w [ y ] ∗ I ( y ≠ i ) ∗ ( − X [ range ( len ( y ) ) ,   y ] + log ⁡ ( sum ( exp ⁡ ( X ) ,   dim = 1 ) ) ) (l_1,\cdots,l_N)=\boldsymbol{w}[\boldsymbol{y}] *\mathbb{I}(\boldsymbol{y}\neq i)* (-{\bf X}[\text{range}(\text{len}(\boldsymbol{y})),\,\boldsymbol{y}]+\log(\text{sum}(\exp({\bf X}),\,\text{dim}=1))) (l1,,lN)=w[y]I(y=i)(X[range(len(y)),y]+log(sum(exp(X),dim=1)))

    其中 ∗ * 代表按元素相乘。上式采用了广播机制。

    class CrossEntropyLoss(nn.Module):
    
        def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
            super().__init__()
            self.weight = weight
            self.ignore_index = ignore_index
            self.reduction = reduction
    
        def forward(self, inputs, target):
            if self.weight is not None:
                n_samples_weight = self.weight[target]  # 每个样本的权重
            else:
                n_samples_weight = torch.ones_like(target).float()  # 不提供权重则默认全为1
            indicator = (target != self.ignore_index).long().float()  # long()方法可以将布尔型张量转化成0-1张量
            raw_loss = -inputs[torch.arange(len(target)), target] + torch.log(torch.sum(torch.exp(inputs), dim=1))
            result = n_samples_weight * indicator * raw_loss
            if self.reduction == 'mean':
                return torch.sum(result) / n_samples_weight.dot(indicator)
            elif self.reduction == 'sum':
                return torch.sum(result)
            else:
                return result
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    输出结果与 PyTorch 官方的 nn.CrossEntropyLoss 的完全相同,这里不再展示,读者可自行验证。

  • 相关阅读:
    Wpf 使用 Prism 实战开发Day04
    Linux学习之HTTP
    Java面试题(每天10题)-------连载(32)
    前端开发如何更好的避免样式冲突?级联层(CSS@layer)
    应用程序通过 Envoy 代理和 Jaeger 进行分布式追踪(一)
    基于开源库libreDWG+Java实现AutoCad格式DWG转DXF
    APP 开发方式的优缺点有哪些?
    2023最新SSM计算机毕业设计选题大全(附源码+LW)之java计算机专业建设管理系统3286d
    JDK1.8之前与之后 HashMap底层实现原理的差别
    蓝桥杯:等差数列
  • 原文地址:https://blog.csdn.net/raelum/article/details/125588956