• 多标签分类损失函数/精度 BCEWithLogitsLoss MultiLabelSoftMarginLoss BCELoss


     实现源码 

    1. import torch
    2. import numpy as np
    3. pred = np.array([[-0.4089, -1.2471, 0.5907],
    4. [-0.4897, -0.8267, -0.7349],
    5. [0.5241, -0.1246, -0.4751]])
    6. label = np.array([[0, 1, 1],
    7. [0, 0, 1],
    8. [1, 0, 1]])
    9. pred = torch.from_numpy(pred).float()
    10. label = torch.from_numpy(label).float()
    11. ## 通过BCEWithLogitsLoss直接计算输入值(pick)
    12. crition1 = torch.nn.BCEWithLogitsLoss()
    13. loss1 = crition1(pred, label)
    14. print(loss1)
    15. crition2 = torch.nn.MultiLabelSoftMarginLoss()
    16. loss2 = crition2(pred, label)
    17. print(loss2)
    18. ## 通过BCELoss计算sigmoid处理后的值
    19. crition3 = torch.nn.BCELoss()
    20. loss3 = crition3(torch.sigmoid(pred), label)
    21. print(loss3)

    关于BCEWithLogitsLoss 

    这个东西,本质上和nn.BCELoss()没有区别,只是在BCELoss上加了个logits函数(也就是sigmoid函数),例子如下:

    1. import torch
    2. import torch.nn as nn
    3. label = torch.Tensor([1, 1, 0])
    4. pred = torch.Tensor([3, 2, 1])
    5. pred_sig = torch.sigmoid(pred)
    6. loss = nn.BCELoss()
    7. print(loss(pred_sig, label))
    8. loss = nn.BCEWithLogitsLoss()
    9. print(loss(pred, label))
    10. loss = nn.BCEWithLogitsLoss()
    11. print(loss(pred_sig, label))
    12. 输出结果分别为:
    13. tensor(0.4963)
    14. tensor(0.4963)
    15. tensor(0.5990)

    可以看到,nn.BCEWithLogitsLoss()相当于是在nn.BCELoss()中预测结果pred的基础上先做了个sigmoid,然后继续正常算loss。所以这就涉及到一个比较奇葩的bug,如果网络本身在输出结果的时候已经用sigmoid去处理了,算loss的时候用nn.BCEWithLogitsLoss()…那么就会相当于预测结果算了两次sigmoid,可能会出现各种奇奇怪怪的问题——

    比如网络收敛不了

    原文链接:https://blog.csdn.net/qq_40714949/article/details/120295651

    MultiLabelSoftMarginLoss

    不知道pytorch为什么起这个名字,看loss计算公式,并没有涉及到margin,有可能后面会实现。按照我的理解其实就是多标签交叉熵损失函数,验证之后也和BCEWithLogitsLoss的结果输出一致,使用的torch版本为1.5.0

    原文链接:https://blog.csdn.net/ltochange/article/details/118070885

    1. import torch
    2. import torch.nn.functional as F
    3. import torch.nn as nn
    4. import math
    5. def validate_loss(output, target, weight=None, pos_weight=None):
    6. output = F.sigmoid(output)
    7. # 处理正负样本不均衡问题
    8. if pos_weight is None:
    9. label_size = output.size()[1]
    10. pos_weight = torch.ones(label_size)
    11. # 处理多标签不平衡问题
    12. if weight is None:
    13. label_size = output.size()[1]
    14. weight = torch.ones(label_size)
    15. val = 0
    16. for li_x, li_y in zip(output, target):
    17. for i, xy in enumerate(zip(li_x, li_y)):
    18. x, y = xy
    19. loss_val = pos_weight[i] * y * math.log(x, math.e) + (1 - y) * math.log(1 - x, math.e)
    20. val += weight[i] * loss_val
    21. return -val / (output.size()[0] * output.size(1))
    22. weight = torch.Tensor([0.8, 1, 0.8])
    23. loss = nn.MultiLabelSoftMarginLoss(weight=weight)
    24. x = torch.Tensor([[0.8, 0.9, 0.3], [0.8, 0.9, 0.3], [0.8, 0.9, 0.3], [0.8, 0.9, 0.3]])
    25. y = torch.Tensor([[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0]])
    26. print(x.size())
    27. print(y.size())
    28. loss_val = loss(x, y)
    29. print(loss_val.item())
    30. validate_loss = validate_loss(x, y, weight=weight)
    31. print(validate_loss.item())
    32. loss = torch.nn.BCEWithLogitsLoss(weight=weight)
    33. loss_val = loss(x, y)
    34. print(loss_val.item())
    35. # 输出
    36. torch.Size([4, 3])
    37. torch.Size([4, 3])
    38. 0.4405062198638916
    39. 0.4405062198638916
    40. 0.440506249666214

    BCELoss

    loss函数之BCELoss - 简书 (jianshu.com)

    精度计算

    2 准确率计算
    依然是上面的例子,模型的输出是[0.2,0.6,0.8],真实值是[0,0,1]。准确率该怎么计算呢?

    1. pred = torch.tensor([0.2, 0.6, 0.8])
    2. y = torch.tensor([0, 0, 1])
    3. accuracy = (pred.ge(0.5) == y).all().int().item()
    4. accuracy
    5. # output : 0

    首先ge函数将pred中大于等于0.5的转化为True,小于0.5的转化成False,再比较pred和y(必须所有维度都相同才算分类准确),最后将逻辑值转化为整数输出即可。
    训练时都是按照一个batch计算的,那就写一个循环吧。

    1. pred = torch.tensor([[0.2, 0.5, 0.8], [0.4, 0.7, 0.1]])
    2. y = torch.tensor([[0, 0, 1], [0, 1, 0]])
    3. accuracy = sum(row.all().int().item() for row in (pred.ge(0.5) == y))
    4. accuracy
    5. # output : 1



    原文链接:https://blog.csdn.net/qsmx666/article/details/121718548

  • 相关阅读:
    Android焦点控制和键盘弹出
    电子学会青少年软件编程 Python编程等级考试三级真题解析(选择题)2021年3月
    js中各种数据类型检测与判定
    水一下文章
    dedecms织梦管理系统模板标签代码
    dubbo项目整合nacos注册中心问题记录
    前端性能优化
    力扣+牛客--刷题记录
    vue-生成二维码
    跳表论文解读
  • 原文地址:https://blog.csdn.net/weixin_41803874/article/details/125411154