• 聊一聊损失函数


    聊一聊损失函数

    前言

    损失函数,具体来说就是计算神经网络每次迭代的前向计算结果与真实值的差距,从而指导下一步的训练向正确的方向进行。下面主要介绍一些常见的损失函数:均方差损失函数交叉熵损失函数

    均方差损失函数

    均方误差损失(Mean Square Error,MSE)又称为二次损失、L2 损失,常用于回归预测任务中。均方误差函数通过计算预测值和实际值之间距离(即误差)的平方来衡量模型优劣。即预测值和真实值越接近,两者的均方差就越小。

    均方差函数常用于线性回归(linear regrWession),即函数拟合(function fitting)。公式如下:

    l o s s = 1 2 ( a − y ) 2 (单样本) loss = {1 \over 2}(a-y)^2 \tag{单样本} loss=21(ay)2(单样本)

    J = 1 2 m ∑ i = 1 m ( a i − y i ) 2 (多样本) J=\frac{1}{2m} \sum_{i=1}^m (a_i-y_i)^2 \tag{多样本} J=2m1i=1m(aiyi)2(多样本)

    均方差函数比较简单,也较为常见,这里就不多说了。

    交叉熵损失函数

    交叉熵(Cross Entropy)是 Shannon 信息论中一个重要概念,主要用于度量两个概率分布间的差异性信息。在信息论中,交叉熵是表示两个概率分布 p , q p,q p,q 的差异,其中 p p p 表示真实分布, q q q 表示预测分布,那么 H ( p , q ) H(p,q) H(p,q) 就称为交叉熵:

    H ( p , q ) = ∑ i p i ⋅ ln ⁡ 1 q i = − ∑ i p i ln ⁡ q i H(p,q)=\sum_i p_i \cdot \ln \frac{1}{q_i} = -\sum_i p_i \ln q_i H(p,q)=ipilnqi1=ipilnqi

    交叉熵可在神经网络中作为损失函数, p p p 表示真实标记的分布, q q q 则为训练后的模型的预测标记分布,交叉熵损失函数可以衡量 p p p q q q 的相似性。

    交叉熵函数常用于逻辑回归(logistic regression),也就是分类(classification)。

    信息量

    信息量来衡量一个事件的不确定性,一个事件发生的概率越大,不确定性越小,则其携带的信息量就越小。在信息论中,可以通过如下方式表示:

    I ( x j ) = − ln ⁡ ( p ( x j ) ) I(x_j) = -\ln(p(x_j)) I(xj)=ln(p(xj))

    其中 x j x_j xj表示一个事件, p ( x j ) p(x_j) p(xj)表示 x j x_j xj发生的概率。

    举个例子,对于下面这三个事件,可以通过概率计算其信息量:

    事件编号事件概率信息量
    x 1 x_1 x1优秀 p = 0.7 p=0.7 p=0.7 I = − ln ⁡ ( 0.7 ) = 0.36 I=-\ln(0.7)=0.36 I=ln(0.7)=0.36
    x 2 x_2 x2及格 p = 0.2 p=0.2 p=0.2 I = − ln ⁡ ( 0.2 ) = 1.61 I=-\ln(0.2)=1.61 I=ln(0.2)=1.61
    x 3 x_3 x3不及格 p = 0.1 p=0.1 p=0.1 I = − ln ⁡ ( 0.1 ) = 2.30 I=-\ln(0.1)=2.30 I=ln(0.1)=2.30

    事件发生的概率越小,其信息量越大。

    熵用来衡量一个系统的混乱程度,代表系统中信息量的总和;熵值越大,表明这个系统的不确定性就越大。具体来说:

    H ( p ) = − ∑ j n p ( x j ) ln ⁡ ( p ( x j ) ) H(p)=-\sum_j^n p(x_j)\ln(p(x_j)) H(p)=jnp(xj)ln(p(xj))

    其中 p ( x j ) p(x_j) p(xj)表示 x j x_j xj发生的概率, − ln ⁡ ( p ( x j ) ) -\ln(p(x_j)) ln(p(xj))表示事件的信息量。

    信息量是衡量某个事件的不确定性,而熵是衡量一个系统(所有事件)的不确定性。

    对于上面的例子,我们可以计算其熵:

    H ( p ) = − [ p ( x 1 ) ln ⁡ p ( x 1 ) + p ( x 2 ) ln ⁡ p ( x 2 ) + p ( x 3 ) ln ⁡ p ( x 3 ) ] = 0.7 × 0.36 + 0.2 × 1.61 + 0.1 × 2.30 = 0.804

    H(p)=[p(x1)lnp(x1)+p(x2)lnp(x2)+p(x3)lnp(x3)]=0.7×0.36+0.2×1.61+0.1×2.30=0.804" role="presentation" style="position: relative;">H(p)=[p(x1)lnp(x1)+p(x2)lnp(x2)+p(x3)lnp(x3)]=0.7×0.36+0.2×1.61+0.1×2.30=0.804
    H(p)=[p(x1)lnp(x1)+p(x2)lnp(x2)+p(x3)lnp(x3)]=0.7×0.36+0.2×1.61+0.1×2.30=0.804

    相对熵(KL 散度)

    相对熵也称为 KL 散度(Kullback-Leibler divergence),表示同一个随机变量的两个不同分布间的距离,相当于信息论范畴的均方差。

    p ( x ) , q ( x ) p(x),q(x) p(x),q(x)分别是随机变量 x x x的两个概率分布,则 p p p q q q的相对熵计算如下:

    D K L ( p ∥ q ) = ∑ j = 1 n p ( x j ) ln ⁡ p ( x j ) q ( x j ) D_{KL}(p \Vert q) = \sum_{j=1}^n p(x_j) \ln \frac{p(x_j)}{q(x_j)} DKL(pq)=j=1np(xj)lnq(xj)p(xj)

    其中 n n n为事件的所有可能性。相对熵 D D D的值越小,表示两个分布越接近。在实际应用中,假如 p ( x ) p(x) p(x)是目标真实的分布,而 q ( x ) q(x) q(x)是预测得来的分布,为了让这两个分布尽可能的相同的,就需要最小化 KL 散度。

    交叉熵

    将上述公式变形:

    D K L ( p ∥ q ) = ∑ j = 1 n p ( x j ) ln ⁡ p ( x j ) q ( x j ) = ∑ j = 1 n p ( x j ) ln ⁡ p ( x j ) − ∑ j = 1 n p ( x j ) ln ⁡ q ( x j ) = − H ( p ( x ) ) + H ( p , q )

    DKL(pq)=j=1np(xj)lnp(xj)q(xj)=j=1np(xj)lnp(xj)j=1np(xj)lnq(xj)=H(p(x))+H(p,q)" role="presentation" style="position: relative;">DKL(pq)=j=1np(xj)lnp(xj)q(xj)=j=1np(xj)lnp(xj)j=1np(xj)lnq(xj)=H(p(x))+H(p,q)
    DKL(pq)=j=1np(xj)lnq(xj)p(xj)=j=1np(xj)lnp(xj)j=1np(xj)lnq(xj)=H(p(x))+H(p,q)

    其中,等式的前一部分就是 p p p的熵,后一部分就是交叉熵:

    H ( p , q ) = − ∑ j = 1 n p ( x j ) ln ⁡ q ( x j ) H(p,q) = -\sum_{j=1}^n p(x_j) \ln q(x_j) H(p,q)=j=1np(xj)lnq(xj)

    在机器学习中,我们需要评估标签值 y y y和预测值 a a a之间的差距,就可以计算 D K L ( p ∥ q ) D_{KL}(p \Vert q) DKL(pq),由于 H ( y ) H(y) H(y)不变,因此在优化过程中只需要考虑交叉熵即可。对于单样本计算如下:

    l o s s = − ∑ j = 1 n y j ln ⁡ a j loss = -\sum_{j=1}^n y_j \ln a_j loss=j=1nyjlnaj

    对于批量样本的交叉熵计算如下:

    J = − ∑ i = 1 m ∑ j = 1 n y i j ln ⁡ a i J= -\sum_{i=1}^m\sum_{j=1}^n y_{ij} \ln a_{i} J=i=1mj=1nyijlnai

    其中 m m m为样本数, n n n为分类数。

    二分类问题交叉熵

    在二分的情况下,通常使用sigmoid 将输出映射为正样本的概率,对于每个类别我们的预测的到的概率为 a a a 1 − a 1-a 1a,所以交叉熵可以简化为:

    l o s s = − [ y ln ⁡ a + ( 1 − y ) ln ⁡ ( 1 − a ) ] loss = -[y\ln a + (1-y)\ln(1-a)] loss=[ylna+(1y)ln(1a)]

    二分类对于批量样本的交叉熵计算公式:

    J = − ∑ i = 1 m [ y i ln ⁡ a i + ( 1 − y i ) ln ⁡ ( 1 − a i ) ] J = -\sum_{i=1}^m [y_i\ln a_i + (1-y_i)\ln(1-a_i)] J=i=1m[yilnai+(1yi)ln(1ai)]

    简单分析一下公式,可以发现,当 y = 1 y=1 y=1时为正样本, l o s s = − ln ⁡ ( a ) loss=-\ln(a) loss=ln(a);当 y = 0 y=0 y=0时为负样本, l o s s = − ln ⁡ ( 1 − a ) loss=-\ln(1-a) loss=ln(1a)

    事件编号预测值 a a a真实值 y y y
    x 1 x_1 x10.61
    x 2 x_2 x20.71

    举个例子,对于上面的情况,我们分别计算其交叉熵损失:

    l o s s 1 = − ( 1 × ln ⁡ 0.6 + ( 1 − 1 ) × ln ⁡ ( 1 − 0.6 ) ) = 0.51 loss_1 = -(1 \times \ln 0.6 + (1-1) \times \ln (1-0.6)) = 0.51 loss1=(1×ln0.6+(11)×ln(10.6))=0.51

    l o s s 2 = − ( 1 × ln ⁡ 0.7 + ( 1 − 1 ) × ln ⁡ ( 1 − 0.7 ) ) = 0.36 loss_2 = -(1 \times \ln 0.7 + (1-1) \times \ln (1-0.7)) = 0.36 loss2=(1×ln0.7+(11)×ln(10.7))=0.36

    计算得到 l o s s 1 > l o s s 2 loss_1 > loss_2 loss1>loss2,相应的 l o s s 2 loss_2 loss2反向传播的力度也会小。

    多分类问题交叉熵

    多分类问题也是类似的,考虑下面的优秀、及格、不及格分类:

    事件编号 p ( x 1 ) = 优秀 p(x_1)=优秀 p(x1)=优秀 p ( x 1 ) = 及格 p(x_1)=及格 p(x1)=及格 p ( x 1 ) = 不及格 p(x_1)=不及格 p(x1)=不及格真实值 y y y
    x 1 x_1 x10.20.50.3不及格
    x 2 x_2 x20.20.20.6不及格

    举个例子,对于上面的情况,我们分别计算其交叉熵损失:

    l o s s 1 = − ( 0 × ln ⁡ 0.2 + 0 × ln ⁡ 0.5 + 1 × ln ⁡ 0.3 ) = 1.2 loss_1 = -(0 \times \ln 0.2 + 0 \times \ln 0.5 + 1 \times \ln 0.3) = 1.2 loss1=(0×ln0.2+0×ln0.5+1×ln0.3)=1.2

    l o s s 2 = − ( 0 × ln ⁡ 0.2 + 0 × ln ⁡ 0.2 + 1 × ln ⁡ 0.6 ) = 0.51 loss_2 = -(0 \times \ln 0.2 + 0 \times \ln 0.2 + 1 \times \ln 0.6) = 0.51 loss2=(0×ln0.2+0×ln0.2+1×ln0.6)=0.51

    计算得到 l o s s 1 > l o s s 2 loss_1 > loss_2 loss1>loss2,相应的 l o s s 2 loss_2 loss2反向传播的力度也会小。

    PyTorch 实现

    在 PyTorch 中,常用的损失函数我们可以直接调用:

    • nn.MSELoss()
    • nn.CrossEntropyLoss()

    但有时我们会需要自定义损失函数,这时我们可以将其当作神经网络的一层来对待,同样地,我们的损失函数类就需要继承自nn.Module类。

    import torch
    import torch.nn as nn
    
    class myLoss(nn.Module):
        def __init__(self,parameters)
            self.params = self.parameters
    
        def forward(self)
            loss = cal_loss(self.params)
            return loss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    参考资料

  • 相关阅读:
    使用pytorch将三维图像分块(patches)
    K8s集群调度
    MYSQL逻辑架构
    《vector和list 的对比》
    2024年宝鸡市国家级、省级、市级科技企业孵化器申报奖励补贴标准及申报条件
    asp毕业设计——基于asp+access的学生管理系统设计与实现(毕业论文+程序源码)——学生管理系统
    读新乌合之众
    REDIS04_主从复制概述及搭建、反客为主、薪火相传、原理、哨兵模式、集群搭建
    图像处理——图像增强
    【Java中23种面试常考的设计模式之桥接模式(Bridge)---结构型模式】
  • 原文地址:https://blog.csdn.net/Lamours/article/details/125895288