• 11-16 周四 简单代码理解FlashAttention 分块计算softmax


    下面的代码对于2*3进行演示

    #!/usr/bin/env python
    # -*- encoding: utf-8 -*-
    import numpy as np
    
    
    # 定义输入数组
    input_array = np.array([[1, 2, 3], [4, 9, 6]])
    
    print("np.e:", np.e)
    print("1/np.e:", 1/np.e)
    
    # 求出每行的最大值
    max_values = np.max(input_array, axis=1, keepdims=True)
    m1 = max_values[0]
    m2 = max_values[1]
    print(f"m1={m1}, m2={m2}")
    
    # 减去每行的最大值
    input_array = input_array - max_values
    # 计算softmax
    exp_values = np.exp(input_array)
    
    f1=exp_values[0]
    f2=exp_values[1]
    max = np.sum(exp_values, axis=1, keepdims=True)
    sum1 = max[0]
    sum2 = max[1]
    
    print(f"f1={f1}\nf2={f2}")
    print(f"sum1={sum1}, sum2={sum2}")
    print("f1/sum1=", f1 / sum1)
    print("f2/sum2=", f2 / sum2)
    softmax_output = exp_values / np.sum(exp_values, axis=1, keepdims=True)
    print("sum: ", np.sum(exp_values, axis=1, keepdims=True))
    
    print("基础softmax_output:", softmax_output)
    
    L = np.exp(-6)*sum1 + sum2
    
    print(f"L={L}")
    
    con = np.concatenate((np.exp(-6)*f1, f2), axis=0)
    
    print("con: ", con)
    print("result:", con / L)
    
    
    def softmax(input_array):
        # 求出每行的最大值
        max_values = np.max(input_array, axis=1, keepdims=True)
    
        # 减去每行的最大值
        input_array = input_array - max_values
        # 计算e的指数
        exp_values = np.exp(input_array)
        
        softmax_output = exp_values / np.sum(exp_values, axis=1, keepdims=True)
        return softmax_output
    
    
    print("直接计算: ", softmax(np.array([[1, 2, 3, 4, 9, 6]])))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61

     上述的代码过程主要是将张量分成了两块进行计算,最后可以看到采用逐步累加的方式得到的结果与逐步运算是相同的。

    进一步优化了程序,让程序可以自由的变大,并且更加灵活

    #!/usr/bin/env python
    # -*- encoding: utf-8 -*-
    import numpy as np
    
    
    # 定义输入数组
    input_array = np.array([[1, 2, 3, 1], [4, 9, 6, 10], [5, 10, 7, 6], [100, 13, 20, 30]])
    
    print("np.e:", np.e)
    print("1/np.e:", 1/np.e)
    
    # # 求出每行的最大值
    # max_values = np.max(input_array, axis=1, keepdims=True)
    # m1 = max_values[0]
    # m2 = max_values[1]
    # print(f"m1={m1}, m2={m2}")
    
    # 最大值
    m = 0;
    L = 0 
    i = 0
    for arr in input_array:
        print(arr)
        if i == 0:
            m = np.max(arr)
            print(f"m更新为{m}")
            print("arr-m:", arr-m)
            temp = np.exp(arr - m)
            
            L = np.sum(temp)
            result = temp 
            i += 1
            print("result:", result)
            continue
            
        
        m2 = np.max(arr)
        print(f"m2={m2}")
        print(f"arr-m2={arr-m2}")
        temp2 = np.exp(arr - m2)
        
        L2 = np.sum(temp2)
        temp2 = temp2
        print(f"temp2={temp2}")
        m_new = m2 if m < m2 else m
        print(f"L2 = {L2}")
        L = np.exp(m - m_new) * L + np.exp(m2 - m_new) * L2
        print(f"L={L}")
        print(f"m-m_new: {m-m_new}, m2-m_new: {m2-m_new}")
        result = np.concatenate((np.exp(m-m_new)*result, np.exp(m2-m_new)*temp2))
        
        print(f"result={result}")
        m = m_new
        print(f"m更新为: {m}")
    
    print(f"结果为: {result/L}")        
    
    
    def softmax(input_array):
        # 求出每行的最大值
        max_values = np.max(input_array, axis=1, keepdims=True)
    
        # 减去每行的最大值
        input_array = input_array - max_values
        # 计算e的指数
        exp_values = np.exp(input_array)
        
        softmax_output = exp_values / np.sum(exp_values, axis=1, keepdims=True)
        return softmax_output
    
    
    print(input_array.reshape(-1))
    
    print("直接计算: ", softmax([input_array.reshape(-1)]))
    print("直接计算[1, 2, 3]: ", softmax(np.array([[1, 2, 3]])))
    print("直接计算[4, 9, 6]: ", softmax(np.array([[4, 9, 6]])))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76

     运算之后可以得到输出如下:

    python softmax.py
    np.e: 2.718281828459045
    1/np.e: 0.36787944117144233
    [1 2 3 1]
    m更新为3
    arr-m: [-2 -1  0 -2]
    result: [ 0.13533528  0.36787944  1.          0.13533528]
    [ 4  9  6 10]
    m2=10
    arr-m2=[-6 -1 -4  0]
    temp2=[ 0.00247875  0.36787944  0.01831564  1.        ]
    L2 = 1.3886738322368428
    L=1.3901679964384732
    m-m_new: -7, m2-m_new: 0
    result=[  1.23409804e-04   3.35462628e-04   9.11881966e-04   1.23409804e-04
       2.47875218e-03   3.67879441e-01   1.83156389e-02   1.00000000e+00]
    m更新为: 10
    [ 5 10  7  6]
    m2=10
    arr-m2=[-5  0 -3 -4]
    temp2=[ 0.00673795  1.          0.04978707  0.01831564]
    L2 = 1.0748406542556834
    L=2.4650086506941564
    m-m_new: 0, m2-m_new: 0
    result=[  1.23409804e-04   3.35462628e-04   9.11881966e-04   1.23409804e-04
       2.47875218e-03   3.67879441e-01   1.83156389e-02   1.00000000e+00
       6.73794700e-03   1.00000000e+00   4.97870684e-02   1.83156389e-02]
    m更新为: 10
    [100  13  20  30]
    m2=100
    arr-m2=[  0 -87 -80 -70]
    temp2=[  1.00000000e+00   1.64581143e-38   1.80485139e-35   3.97544974e-31]
    L2 = 1.0
    L=1.0
    m-m_new: -90, m2-m_new: 0
    result=[  1.01122149e-43   2.74878501e-43   7.47197234e-43   1.01122149e-43
       2.03109266e-42   3.01440879e-40   1.50078576e-41   8.19401262e-40
       5.52108228e-42   8.19401262e-40   4.07955867e-41   1.50078576e-41
       1.00000000e+00   1.64581143e-38   1.80485139e-35   3.97544974e-31]
    m更新为: 100
    结果为: [  1.01122149e-43   2.74878501e-43   7.47197234e-43   1.01122149e-43
       2.03109266e-42   3.01440879e-40   1.50078576e-41   8.19401262e-40
       5.52108228e-42   8.19401262e-40   4.07955867e-41   1.50078576e-41
       1.00000000e+00   1.64581143e-38   1.80485139e-35   3.97544974e-31]
    [  1   2   3   1   4   9   6  10   5  10   7   6 100  13  20  30]
    直接计算:  [[  1.01122149e-43   2.74878501e-43   7.47197234e-43   1.01122149e-43
        2.03109266e-42   3.01440879e-40   1.50078576e-41   8.19401262e-40
        5.52108228e-42   8.19401262e-40   4.07955867e-41   1.50078576e-41
        1.00000000e+00   1.64581143e-38   1.80485139e-35   3.97544974e-31]]
    直接计算[1, 2, 3]:  [[ 0.09003057  0.24472847  0.66524096]]
    直接计算[4, 9, 6]:  [[ 0.00637746  0.94649912  0.04712342]]
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52

     从上述的日志中,可以看到主键累计的计算结果

    结果为: [  1.01122149e-43   2.74878501e-43   7.47197234e-43   1.01122149e-43
       2.03109266e-42   3.01440879e-40   1.50078576e-41   8.19401262e-40
       5.52108228e-42   8.19401262e-40   4.07955867e-41   1.50078576e-41
       1.00000000e+00   1.64581143e-38   1.80485139e-35   3.97544974e-31]
    
    • 1
    • 2
    • 3
    • 4

    而直接计算的结果为:

    [[  1.01122149e-43   2.74878501e-43   7.47197234e-43   1.01122149e-43
     2.03109266e-42   3.01440879e-40   1.50078576e-41   8.19401262e-40
     5.52108228e-42   8.19401262e-40   4.07955867e-41   1.50078576e-41
     1.00000000e+00   1.64581143e-38   1.80485139e-35   3.97544974e-31]]
    
    • 1
    • 2
    • 3
    • 4

     因此验证了精确的注意力计算

  • 相关阅读:
    Unity--互动组件(Toggle Group)||Unity--互动组件(Slider)
    谷歌『云开发者速查表』;清华3D人体数据集;商汤『通用视觉框架』公开课;Web3极简入门指南;高效深度学习免费书;前沿论文 | ShowMeAI资讯日报
    数据库-第四/五章 数据库安全性和完整性【期末复习|考研复习】
    Linux系统编程——线程的学习
    Tomcat 源码分析 (Tomcat的Session管理) (十一)
    Activiti回退与跳转节点
    vue2vue3生命周期详解
    Flink1.14 Source概念入门讲解与源码解析
    APP性能---用adb命令测试Android中APP的FPS
    el-admin 选择年份查询范围
  • 原文地址:https://blog.csdn.net/lk142500/article/details/134443958