• TensorFlow入门(二十一、softmax算法与损失函数)


    在实际使用softmax计算loss时,有一些关键地方与具体用法需要注意:

            交叉熵是十分常用的,且在TensorFlow中被封装成了多个版本。多版本中,有的公式里直接带了交叉熵,有的需要自己单独手写公式求出。如果区分不清楚,在构建模型时,一旦出现问题将很难分析是模型的问题还是交叉熵的使用问题。

    示例代码如下:

    1. import tensorflow as tf
    2. #labels和logits的shape一样
    3. #定义one-hot标签数据
    4. labels = [[0,0,1],[0,1,0]]
    5. #定义预测数据
    6. logits = [[2,0.5,6],[0.1,0,3]]
    7. #对预测数据求一次softmax值
    8. logits_scaled = tf.nn.softmax(logits)
    9. #在求交叉熵的基础上求第二次的softmax值
    10. logits_scaled2 = tf.nn.softmax(logits_scaled)
    11. #使用API求交叉熵
    12. #对预测数据与标签数据计算交叉熵
    13. result1 = tf.nn.softmax_cross_entropy_with_logits(labels = labels, logits = logits)
    14. #对第一次的softmax值与标签数据计算交叉熵
    15. result2 = tf.nn.softmax_cross_entropy_with_logits(labels = labels, logits = logits_scaled)
    16. result3 = tf.nn.softmax_cross_entropy_with_logits(labels = labels, logits = logits_scaled2)
    17. #使用公式求交叉熵
    18. result4 = -tf.reduce_sum(labels*tf.compat.v1.log(logits_scaled),1)
    19. #标签数据各元素的总和为1
    20. labels2 = [[0.4,0.1,0.5],[0.3,0.6,0.1]]
    21. result5 = tf.nn.softmax_cross_entropy_with_logits(labels = labels2, logits = logits)
    22. #非one-hot标签
    23. labels3 = [2,1]#等价于labels3==[tf.argmax(label,0),tf.argmax(label,1)]
    24. #使用sparse交叉熵函数计算
    25. result6 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels = labels3, logits = logits)
    26. print("logits_scaled=",logits_scaled)
    27. print("logits_scaled2=",logits_scaled2)
    28. print("result1=",result1)
    29. print("result2=",result2)
    30. print("result3=",result3)
    31. print("result4=",result4)
    32. print("result5=",result5)
    33. print("result6=",result6)

    总结:

            使用softmax交叉熵函数计算损失值时,如果传入的实参logits是神经网络前向传播完成后的计算结果,则不需要对logits应用softmax算法,因为softmax交叉熵函数会自带计算softmax

            使用sparse交叉熵函数计算损失值时,样本真实值与预测结果不需要one-hot编码,传给参数labels的是标签数数组中元素值为1的位置

            由于交叉熵的损失函数只和分类正确的预测结果有关系,因此交叉熵的计算适用于分类问题上,不适用于回归问题。而均方差(MES)的损失函数由于对每一个输出结果都非常重视,不仅让正确的预测结果变大,还让错误的分类变得平均,更适用于回归问题,不适用于分类问题

            当使用Sigmoid作为激活函数的时候,常用交叉熵损失函数而不是均方差(MES)损失函数,以避免均方差损失函数学习速率降低的问题。

  • 相关阅读:
    【Harmony OS】【JAVA UI】鸿蒙应用如何集成OKHttp网络三方库
    GIS原理篇 线性参照
    懵了,面试官问我Redis怎么测,我哪知道!
    计算机网络
    攻防世界 简单的base编码
    一文掌握Lambda表达式(下)
    Apache Doris 巨大飞跃:存算分离新架构
    干货 | 精准化测试原理简介与实践探索
    RN:Error: /xxx/android/gradlew exited with non-zero code: 1
    【电路参考】缓启动电路
  • 原文地址:https://blog.csdn.net/Victor_Li_/article/details/133778792