• PyTorch交叉熵理解


    PyTorch 中的交叉熵损失

    CrossEntropyLoss

    PyTorch 中使用CrossEntropyLoss 计算交叉熵损失,常用于分类任务。交叉熵损失衡量了模型输出的概率分布与实际标签分布之间的差异,目标是最小化该损失以优化模型。

    我们通过一个具体的案例来详细说明 CrossEntropyLoss 的计算过程。

    假设我们有一个简单的分类任务,共有 3 个类别。我们有 2 个样本的预测和实际标签。

    输入

    • 模型的预测(logits,未经过 softmax 激活)

    • 实际标签

    import torch
    import torch.nn as nn
    
    # 模型的预测(logits)
    logits = torch.tensor([[2.0, 1.0, 0.1],
                           [0.5, 2.0, 0.3]])
    
    # 实际标签
    labels = torch.tensor([0, 2])
    

    计算步骤

    • 步骤 1: Softmax 激活

    首先,将 logits 通过 softmax 激活函数转换为概率分布

    softmax = nn.Softmax(dim=1)
    probabilities = softmax(logits)
    print(probabilities)
    

    输出

    tensor([[0.6590, 0.2424, 0.0986],
            [0.1587, 0.7113, 0.1299]])
    
    • 步骤 2: 计算交叉熵

    交叉熵损失的计算公式为:

    C r o s s E n t r o p y L o s s = − ∑ i = 1 N log ⁡ ( p i , y i ) CrossEntropyLoss=-\sum_{i=1}^{N}{\log{(}}{{p}_{i,{{y}_{i}}}}) CrossEntropyLoss=i=1Nlog(pi,yi)

    其中 N 是样本数量, p i , y i p_{i,y_i} pi,yi是第 i个样本在实际标签  y i y_i yi 位置上的预测概率。

    我们手动计算每个样本的交叉熵损失:

    • 对于第一个样本,实际标签为 0,预测概率为 0.6590

    l o s s 1 = − log ⁡ ( 0.6590 ) ≈ 0.4171 {{loss}_{1}}=-\log{(}0.6590)\approx 0.4171 loss1=log(0.6590)0.4171

    • 对于第二个样本,实际标签为 2,预测概率为 0.1299

    l o s s 2 = − log ⁡ ( 0.1299 ) ≈ 2.0406 {{loss}_{2}}=-\log{(}0.1299)\approx 2.0406 loss2=log(0.1299)2.0406

    平均损失为:

    m e a n = 0.4171 + 2.0406 2 ≈ 1.2288 mean=\frac{0.4171+2.0406}{2}\approx 1.2288 mean=20.4171+2.04061.2288

    • 步骤 3: 使用 PyTorch 的 CrossEntropyLoss 计算

    我们使用 PyTorch 的 CrossEntropyLoss 函数来验证计算结果:

    criterion = nn.CrossEntropyLoss()
    loss = criterion(logits, labels)
    print(loss.item())
    

    输出

    1.2288230657577515
    
    • 步骤4:依据公式使用 PyTorch 计算

    依据前面的公式使用 PyTorch 计算来验算结果

    neg_log_p = -torch.log(probabilities)
    loss_cal = neg_log_p[torch.arange(neg_log_p.shape[0]), labels].mean()
    print(loss_cal.item())
    

    输出

    1.228823184967041
    

    结果基本一致。

    总结

    1. CrossEntropyLoss 接受未经过 softmax 的 logits 作为输入。

    2. 内部首先对 logits 应用 softmax,将其转换为概率分布。

    3. 然后根据实际标签计算交叉熵损失。

  • 相关阅读:
    Redis 非关系型数据库学习(一) ---- Redis 的安装
    队列的实现
    Mall脚手架总结(三) —— MongoDB存储浏览数据
    [CCS] 没有Runtime Object View(ROV)怎么办?
    高等数学(第七版)同济大学 习题7-7 个人解答
    【实操日记】使用 PyQt5 设计下载远程服务器日志文件程序
    【亲测有效】hadoop hive1,hive2 索引加速查询 hive sql优化 大幅优化查询速度 索引建立
    打造店铺爆款的玩法方式解析
    27-spark各版本对比
    尚硅谷-云尚办公-项目复盘
  • 原文地址:https://blog.csdn.net/mp9105/article/details/139421554