• python交叉熵nn.CrossEntropyLoss的计算过程及意义解释


    简述计算过程

    参考:https://blog.csdn.net/liangjiu2009/article/details/107769512
    这里引用了参数连接中的图片,图片中求对数的工程应该是ln,自然对数
    在这里插入图片描述
    一、假设一个矩阵

    X = [ x 1 x 2 x 3 x 4 ] \bm X=

    [x1x2x3x4]" role="presentation">[x1x2x3x4]
    X=[x1x3x2x4]
    二、对矩阵数据中的每个进行softmax计算
    每一行的每一个参数求(以第一行为例)
    t i = e x i ∑ i = 1 2 e i x (1) \bm t_i = \dfrac{e^{x_i}}{\sum\limits_{i=1}^2 e^x_i} \tag{1} ti=i=12eixexi(1)
    最终X矩阵转化为,
    y = [ t 1 t 2 t 3 t 4 ] (2) \bm y=
    [t1t2t3t4]" role="presentation">[t1t2t3t4]
    \tag{2}
    y=[t1t3t2t4](2)

    三、对softmax结果的矩阵每个进行log操作
    对每个数据进行自然对数的log计算
    f = [ l n   t 1 l n   t 2 l n   t 3 l n   t 4 ] = [ x 1 − l n ∑ i = 0 4 e i x x 2 − l n ∑ i = 0 4 e i x x 3 − l n ∑ i = 3 4 e i x x 4 − l n ∑ i = 3 4 e i x ] (2) \bm f=
    [ln t1ln t2ln t3ln t4]" role="presentation">[ln t1ln t2ln t3ln t4]
    =
    [x1lni=04eixx2lni=04eixx3lni=34eixx4lni=34eix]" role="presentation">[x1lni=04eixx2lni=04eixx3lni=34eixx4lni=34eix]
    \tag{2}
    f=[ln t1ln t3ln t2ln t4]=x1lni=04eixx3lni=34eixx2lni=04eixx4lni=34eix(2)

    四、根据目标的索引位置提取出矩阵中的位置,取出值然后取反,然后取平均
    如果目标矩阵的索引是[0,1],那么就需要取出 x 1 − l n ∑ i = 1 2 x i x_1- ln{\sum\limits_{i=1}^2 x_i} x1lni=12xi x 4 − n ∑ i = 3 4 x i x_4- n{\sum\limits_{i=3}^4 x_i} x4ni=34xi,然后将两个值进行求平均后取反:
    − ( x 1 − l n ∑ i = 1 2 e i x + x 4 − l n e ∑ i = 3 4 e i x ) / 2 (3) -(x_1- ln{\sum\limits_{i=1}^2 e^x_i}+x_4-ln_e{\sum\limits_{i=3}^4 e^x_i})/2\tag{3} (x1lni=12eix+x4lnei=34eix)/2(3)

    一个简单的实例

    import torch.nn as nn
    import torch
    
    x = torch.tensor([[1.0,2.0],[3.0,4.0]])
    y = torch.tensor([0,1])
    
    loss_func = nn.CrossEntropyLoss()
    print(loss_func(x,y))  # tensor(0.8133)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    在这里插入图片描述
    可以实验一下,是按照每行来计算的,若使用两行三列的输入,
    [ 1 2 3 4 5 6 ] (4)

    [123456]" role="presentation" style="position: relative;">[123456]
    \tag{4} [142536](4)
    从上面的解释可以看出,若按照行这标签中的onehot编码值不能大于2,且只有两个标签值。通过改变标签的数量和大小的报错情况判断是按照行的计算。

    交叉熵的意义

    例如,对两个字符的索引预测,输入就相当于每个可能字符 的概率,标签就相当于正确字符对应的索引。那么**输入的概率必定在01之间**,及每个x的值在01之间,那么下面的函数:(c代表不包含当前位置的求和)
    f ( x m ) = e x m ∑ i = 0 n e i x = e x m e x m + ∑ i ≠ m n e i x = e x m e x m + c (5) f(x_m)= \dfrac{e^{x_m}}{\sum\limits_{i=0}^n e^x_i}= \dfrac{e^{x_m}}{e^{x_m}+\sum\limits_{i\neq m}^n e^x_i}=\dfrac{e^{x_m}}{e^{x_m}+c} \tag{5} f(xm)=i=0neixexm=exm+i=mneixexm=exm+cexm(5)
    将公式5简记为:
    f ( x ) = e x e x + c (6) f(x)=\dfrac{e^{x}}{e^{x}+c} \tag{6} f(x)=ex+cex(6)
    将6式求导可知:
    f ′ ( x ) = e x ( e x + c ) − e x ∗ e x ( e x + c ) 2 ≥ 0 (6) f'(x)=\dfrac{e^{x}(e^{x}+c)-e^x*e^x}{(e^{x}+c)^2}\geq0 \tag{6} f(x)=(ex+c)2ex(ex+c)exex0(6)
    所以,在0~1范围内, f ( x ) f(x) f(x)单调递增,且 0 < f ( x ) < 1 0< f(x)<1 0<f(x)<1对f(x)求自然对数,**也是单调递增,但取的对数值小于0,再取反后是大于0但变为单调递减。**这也就是我们最终交叉熵损失函数的趋势。
    网络中的损失函数都是为了求得最小值,则随着x增大而减小的损失函数刚好满足要求,而x最大也是接近1,其他位置的概率接近0,而pytorch中的交叉熵会对上述结果进行相加求均值

  • 相关阅读:
    【2023 · CANN训练营第一季】基于昇腾910的TF网络脚本训练(ModelArts平台)
    与HTTP相关的各种概念
    java.io.IOException: Server returned HTTP response code: 403 for URL
    web爬虫第二弹 chrome开发者工具
    Shell 相对路径转换为绝对路径
    Windows和Linux环境中安装Zookeeper具体操作
    #前端#scss学习
    Unity DOTS技术(八)状态组件
    Elasticsearch搜索辅助功能解析(十)
    IE停止维护 导致 @vue/cli-plugin-babel 编译失败
  • 原文地址:https://blog.csdn.net/weixin_43794311/article/details/127796286