码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • 【深度学习】4-梯度确认时遇bug:写了个糟糕的softmax函数


    🚩 前言

    活动地址:CSDN21天学习挑战赛
    🚀 博主主页:清风莫追
    🌊 希望和大家一起加油,一起进步!

    今天学习《深度学习入门:基于python的理论与实践》时,被梯度确认的问题卡了很久。具体过程就不赘述,最终发现是我写的 softmax 函数的问题,导致的数值微分与反向传播求得的梯度总对不上。
    softmax 问题在于我这里没有考虑好不同维度数据的情况。


    文章目录

    • 🚩 前言
    • 1. softmax函数代码
    • 2. 不能正确处理批量样本
    • 3. 解决方案


    1. softmax函数代码

    import numpy as np
    
    def softmax(a): 
        a -= np.max(a)
        exp_a = np.exp(a)
        return exp_a / np.sum(exp_a)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    2. 不能正确处理批量样本

    我们对比一下处理单个样本和批量样本的情况:

    a = np.array([1, 2, 3])
    b = np.array([ [1, 2, 3], [1, 2, 3]])
    softmax(a), softmax(b)
    
    
    • 1
    • 2
    • 3
    • 4
    输出:
    (array([0.09003057, 0.24472847, 0.66524096]),
     array([[0.04501529, 0.12236424, 0.33262048],
            [0.04501529, 0.12236424, 0.33262048]]))
    
    • 1
    • 2
    • 3
    • 4

    容易发现,对应的值,后者的结果都是前者的一半。问题就在np.sum(exp_a)这里:
    在批量数据(矩阵)情况下,我们本意是每个样本求得一个和值,即按行求和。它却将所有数据相加而仅得到一个和值。

    再验证一下:

    a = np.array([1, 2, 3])
    b = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
    softmax(a), softmax(b)
    
    • 1
    • 2
    • 3
    输出:
    (array([0.09003057, 0.24472847, 0.66524096]),
     array([[0.03001019, 0.08157616, 0.22174699],
            [0.03001019, 0.08157616, 0.22174699],
            [0.03001019, 0.08157616, 0.22174699]]))
    
    • 1
    • 2
    • 3
    • 4
    • 5

    可以看到,样本数增加为 3 个时,后者的输出也相应地变成了前者的三分之一。此时前者输出的 3 个数和为 1 ,而后者是输出的 9 个数和为 1。

    现在单样本和批量的输出之间还具有倍数关系,如果每个样本是数据不同,那么将会更加乱套:

    a = np.array([1, 2, 3])
    b = np.array([[1, 2, 3], [4, 5, 6]])
    softmax(a), softmax(b)
    
    • 1
    • 2
    • 3
    输出:
    (array([0.09003057, 0.24472847, 0.66524096]),
     array([[0.00426978, 0.01160646, 0.03154963],
            [0.08576079, 0.23312201, 0.63369132]]))
    
    • 1
    • 2
    • 3
    • 4

    3. 解决方案

    分情况处理即可:

    def softmax(x):
        if x.ndim == 2:
            x = x.T
            x = x - np.max(x, axis=0)
            y = np.exp(x) / np.sum(np.exp(x), axis=0)
            return y.T 
    
        x = x - np.max(x) 
        return np.exp(x) / np.sum(np.exp(x))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    注:此代码来自前言提到的书本
    我们之前分析了,其实就是求和函数的问题,因此我尝试着能不能写得更加简洁和统一些。

    但是我没有成功,因为会遇到一些细节处理上的问题,例如批量情况下:最大值 max 也应当是按行求得;np.sum得到的是一维数组,为了最后相除时能够进行广播,进行转置是有必要的。

    不过修改代码的尝试倒是加深了我对它的理解。


    感谢阅读

  • 相关阅读:
    【C/C++笔试练习】初始化列表、构造函数、析构函数、两种排序方法、求最小公倍数
    Set接口的实现类---HashSet
    Python 处理 PDF —— PyMuPDF 的安装与使用
    页面滚动那些事儿(滚动条样式自定义,隐藏滚动条,scrollIntoView遇到头部fixed定位滚动被遮挡、vue-scrollto插件的应用)
    【JavaScript】判断变量类型是否是字符串
    《实用软件工程》课程教学大纲(Practicality Software Engineering)
    Cesium 问题:加载 geojson 文件后使用 remove 方法移除,但浏览器内存会持续增长并为得到释放直到浏览器崩掉
    14-vue项目搭建.md
    [ruby on rails] pg触发器trigger的使用
    开源组件 | 一款好用的小程序生成图片库
  • 原文地址:https://blog.csdn.net/m0_63238256/article/details/126411411
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号