• 白话transformer(三):Q K V矩阵代码演示


    在前面文章讲解了QKV矩阵的原理,属于比较主观的解释,下面用简单的代码再过一遍加深下印象。

    B站视频

    白话transformer(三)

    1、生成数据

    我们呢就使用一个句子来做一个测试,

    text1 = "我喜欢的水果是橙子和苹果"
    text2 = "相比苹果我更加喜欢国产的华为"
    
    • 1
    • 2

    比如我们有两个句子,里面都有苹果这个词。我们用text1来走下流程

    1.1 创建词嵌入

    我们使用spacy进行词嵌入生成,代码很简单

    nlp = spacy.load('zh_core_web_sm')
    doc = nlp(text1)
    
    • 1
    • 2

    我们为了简单一点只取前10个维度,实际上spacy默认的词嵌入维度是很高的,我们只是用前十个来过一下流程。

    emd_dim = 10
    
    dics = {}
    for token in doc:
        dics[token.text] = token.vector[:emd_dim]
    X = pd.DataFrame(dics)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    在这里插入图片描述
    这样我们就得到了第一个句子中所有词的embedding表示

    2、初始化 W q W_q Wq, W k W_k Wk, W v W_v Wv

    具体的内容可以查看之前的文章Bert基础(一)–自注意力机制

    为了创建查询矩阵、键矩阵和值矩阵,我们需要先创建另外三个权重矩阵,分别为 W Q 、 W K 、 W V W^Q 、W^K、W^V WQWKWV。用矩阵X分别乘以矩阵 W Q 、 W K 、 W V W^Q 、W^K、W^V WQWKWV,就可以依次创建出查询矩阵Q、键矩阵K和值矩阵V。

    d_k = 6       # QKV向量的维度
    
    • 1

    A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^{T}}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V
    d_k是指公式中的d_k

    Wq = np.random.randn(emd_dim, d_k)
    
    • 1

    在这里插入图片描述
    Wq矩阵的格式,就是10*6

    • 10:是指词嵌入的维度
    • 6:d_k,Q的维度

    Wk, Wv,同样

    3、计算QKV

    Q = X * Wq

    np.dot(X.T, Wq)
    
    • 1

    在这里插入图片描述
    这样就得到了查询矩阵Q,Q其实可以理解为每个词需要查询的内容。

    同样可以计算K和V矩阵

    4、相似矩阵

    计算公式为:
    X W Q ∗ ( W K X ) T XW^Q *(W^KX )^T XWQ(WKX)T

    其实就是我们计算好的Q和K
    Q K T Q K^T QKT
    直接点乘就可以得到每个词和每个词的相似性:
    在这里插入图片描述

    5、点积缩放

    Q@K.T/ np.sqrt(d_k)
    
    • 1

    在这里插入图片描述

    6、Soft Max

    我们自己遍历计算一下即可

    # 计算Softmax
    for i in range(len(df_QK)):
        exp_v = np.exp(df_QK.iloc[i])
        softmax = exp_v / np.sum(exp_v)
        df_QK.iloc[i] = softmax
    
    • 1
    • 2
    • 3
    • 4
    • 5

    在这里插入图片描述
    现在就得到了最后的相似性矩阵

    7、attention

    A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^{T}}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

    根据公示直接将前面计算的结构点乘V
    在这里插入图片描述

  • 相关阅读:
    pandas常用操作
    JAVA:实现二个数字的通用根算法(附完整源码)
    JDK更换版本配置不生效问题
    《Effective C++》条款14
    NIO基础知识
    8086汇编笔记
    为什么要做数据可视化系统
    2023-09-30:用go语言,给你一个整数数组 nums 和一个整数 k 。 nums 仅包含 0 和 1, 每一次移动,你可以选择 相邻 两个数字并将它们交换。 请你返回使 nums 中包含 k
    享元模式
    Vue2项目知识点总结-尚品汇
  • 原文地址:https://blog.csdn.net/Andy_shenzl/article/details/136716637