• Transformer原理详解


    Transformer原理详解

    一、前言

    在学习Transformer之前,需要掌握Attention机制、Encoder-Decoder、Self-Attention机制的基础知识,可以参考本人的博客注意力机制Attention详解,在掌握了本文内容之后,可以阅读论文 《Attention is all you need》 ,这是谷歌团队在2017年发表的论文,也是首次提出Transformer,本人在阅读后的笔记:Attention Is All You Need-笔记,有不足之处欢迎大家评论指出【🤝🤝】

    二、Transformer整体架构

    在这里插入图片描述
    从Transformer整体架构中我们可以看到整个模型分成左右两个部分,左侧为Encoder,右侧为Decoder。Encoder与Decoder各自都堆叠了N层,N是可以自定义的数值,在《Attention is all you need》这一论文中,作者定义的数值为6。那么接下来,我将会按照整个架构,从左向右进行分析讲解。

    三、Encoder部分

    在这里插入图片描述
    从图中可以看出,整个Encoder模块被划分为三个小部分:1.输入部分;2.多头注意力机制部分;3.前馈神经网络部分。

    (一)输入部分Inputs

    1.Embedding
    对于要处理的一串文本,我们要让其能够被计算机处理,需要将其转变为词向量,方法有最简单的one-hot,或者有名的Word2Vec等,甚至可以随机初始化词向量。具体的实现细节参考本人博客:Word2vec词向量文本分析详解,经过Embedding后,文本中的每一个字就被转变为一个向量,能够在计算机中表示。《Attention is all you need》这一论文中,作者采用的是512维词向量表示,也就是说,每一个字被一串长度为512的字向量表示。
    2.位置编码Positional Encoding
    首先为什么需要位置编码?我们对于RNN在处理文本时,由于天然的顺序输入,顺序处理,当前输出要等上一步输出处理完后才能进行,因此不会造成文本的字词在顺序上或者先后关系出现问题。具体细节参考本人博客:详细讲解RNN+LSTM+Tree_LSTM(Tree-Long Short Term Memory)基于树状长短期记忆网络。但对于Transformer来说,由于其在处理时是并行执行,虽然加快了速度,但是忽略了词字之间的前后关系或者先后顺序。同时Transformer基于Self-Attention机制,而self-attention不能获取字词的位置信息,即使打乱一句话中字词的位置,每个词还是能与其他词之间计算出attention值,因此我们需要为每一个词向量添加位置编码。
    那么位置编码如何计算?
    这是《Attention is all you need》论文中给出的计算公式。
    在这里插入图片描述
    pos表示字词的位置,2i表示在512维词向量中的偶数位置,2i+1表示在512维词向量中的奇数位置;公式表达的含义是在偶数的位置使用sin函数计算,在奇数的位置使用cos函数计算,如下图。
    在这里插入图片描述
    得到512维的位置编码后,我们将512维的位置编码与512维的词向量相加,得到最终的512维词向量作为最终的Transformer输入。
    在这里插入图片描述
    那么为什么位置嵌入能够发挥作用?
    通过参考如何理解Transformer论文中的positional encoding,和三角函数有什么关系?这篇博客,首先根据三角函数的性质
    在这里插入图片描述
    我们可以将公式推导为:
    在这里插入图片描述
    可以看出,对于pos+k位置的位置编码向量某一维2i或2i+1而言,可以将其表示为,pos位置与k位置的位置编码向量的2i与2i+1维的线性组合,这也意味着绝对位置向量中蕴含了相对位置信息。其实也很通俗:我们自己通过公式计算得到的是绝对位置信息,而用pos位置与k位置的位置编码向量的2i与2i+1维的线性组合得到的是相对位置信息。

    (二)多头注意力机制Multi-Head Attention
    1.回顾Attention

    首先回顾Attention机制,如何获取Q,K,V向量。例如有两个序列“Thinking”,“Machines”,图中X1,X2分别对应于“Thinking”,“Machines”经过添加位置编码后的词向量。X1,X2分别与三个权重矩阵WQ,WK,WV进行矩阵乘法运算,得到q1,k1,v1与q2,k2,v2向量。
    在这里插入图片描述
    得到q与k之后,将q与k做内积,表示词语之间的相互关系,两者内积越大表示越相关;
    接下来除以根号下dk,可以保持方差为1,均值为0(具体细节参考论文笔记);
    接着通过softmax函数归一化,并与v向量相乘得到z。
    在这里插入图片描述
    但是在实际计算过程中,并不是一个一个计算,作者在论文中提到会通过一次矩阵乘法得出结果,方便并行
    在这里插入图片描述
    也就是说每一个序列会以矩阵的形式进行输入,有几个序列X矩阵就有几行,在上面的案例中,有两个序列“Thinking”,“Machines”,那么X矩阵第一行表示“Thinking”,第二行表示“Machines”,X矩阵是2×512维。权重矩阵WQ,WK,WV都是512×64维的【这里的64 = 512 / 8 (8:表示论文中作者定义多头个数)】,因此得到的Query、Keys、Values三个矩阵都是2×64维的。
    得到Query、Keys、Values三个矩阵后,计算Attention Value,步骤如下:

    1. 将Query(2×64)矩阵与Keys矩阵的转置(64×2)相乘,【作者论文中采用的是两个矩阵点积,当然还可以采用其他方式计算相似度得分(余弦相似度,MLP等)】得到相关性得分score(2×2);
    2. 将相关性得分score / 根号下(dk),dk=64是矩阵Keys的维度,这样可以保证梯度稳定;
    3. 进行softmax归一化,将相关性评分映射到0-1之间,得到一个2×2的概率分布矩阵α
    4. 将概率分布矩阵α与Values矩阵进行点积运算,得到(2×2) ⊙ (2×64) = 2×64 维的句子Z。
    2.Multi-Attention多头注意力机制

    那么什么是多头注意力?作者在论文中介绍到:同时使用多组权重矩阵WQ,WK,WV得到多组Query、Keys、Values矩阵。 作者定义的头数Multi-head为 8
    在这里插入图片描述
    经过这样的操作可以让Transformer捕捉到更多子空间的信息,那么每一组最终都能得到一个Z矩阵,最终将8个Z矩阵进行拼接。
    在这里插入图片描述

    3.残差

    下图中X1,X2词向量经过添加位置编码后(黄色的X1,X2)输入到自注意力层,输出Z1,Z2;接着将添加位置编码后(黄色的X1,X2)组成的矩阵与Z1,Z2组成的矩阵对位相加;将对位相加后的结果进行LayerNormalize操作。
    在这里插入图片描述
    为啥要这么做?
    这是一张描述残差的经典图。x作为输入,经过两层网络weight layer,我们将这两层网络看成一个F(x),那么如果没有残差网络,我们可以直接输出F(x)。但是有了残差网络,就相当于将输入的x原封不动的拿到输出位置与F(x)进行相加操作,因此输入为F(x)+x
    在这里插入图片描述
    那么残差网络的作用是啥?下面这张图和上面的其实是一一对应的关系,A相当于输入x,B,C为两层网络,D为输出。
    在这里插入图片描述
    我们可以针对网络进行链式求导,L表示损失函数,XAout表示输入x,C(B(XAout))相当于F(x),我们可以发现最终的结果中包含1加上后面的连乘。在一般情况下,连乘会导致梯度消失,但是在残差网络结构中,最终的连乘还要加上1就确保了最终的梯度不会为0 ,这样就解决了梯度消失的问题。这也是NLP领域中,使用残差网络我们可以让模型结构变深。
    在这里插入图片描述

    4.Layer Normalization

    为什么在Transformer中使用Layer Narmalization而不是Batch Normalization?
    Batch Normalization结构
    在这里插入图片描述
    假设图中的x1,x2,x3,…代表的是不同的样本:小明,小红,小兰…
    每一行代表一个特征:第一行是身高,第二行是体重,第三行是成绩……
    在Batch Normalization中,我们是针对所有样本的同一个特征做Batch Normalization,也就是说,我们针对所有人的身高做Batch Normalization、对所有人的体重做Batch Normalization、对所有人的成绩做Batch Normalization。
    Batch Normalization的假设是采用整个样本的数据的均值和方差来模拟全部数据的均值和方差。 也就是说,假如班级中有60名同学,batch_size = 10时,是用这10个人体重、身高的均值方差来模拟全班60名同学的体重身高的均值和方差。因此它体现了Batch Normalization第一个缺点:当batch_size较小时,效果太差;第二个缺点是:在RNN中效果差。
    其实总结起来,就是因为Batch Normalization效果差,所以不用。

    那为啥用Layer Normalization?
    在这里插入图片描述

    按照上面Batch Normalization的原理,针对不同样本的同一特征做归一化,在这个案例中,“我”和“今”表示的是不同的语义信息,显然不能直接类比。
    而Layer Normalization是针对同一个样本的不同特征做归一化,在这个案例中,对“今天天气真不错”的每一个词做均值和方差。

    (三)前馈神经网络Feed Forward

    前馈全连接层:在Transformer中前馈全连接层就是具有两层线性层的全连接网络。
    作用:考虑到注意力机制可能对复杂过程的拟合程度不够,通过增加两层网络来增强模型的能力。
    在这里插入图片描述
    在这里插入图片描述
    公式中的x表示多头注意力机制输出的Z(2×64),假设W1是64×1024,W2是1024×64,那么经过FFN(x) = (2×64)⊙(64×1024)⊙(1024×64) = 2×64,最终的维度并没有变。这两层的全连接层作用是将输入的Z映射到高维 (2×64)⊙(64×1024)=(2×1024),最后通过ReLu函数变为原来的维度。
    最后再经过一次Add&Normalization,输入到下一个Encoder中,经过6次,就可以输入到Decoder中啦!
    以上就是Encoder部分的全部内容啦!!!

    四、Decoder部分

    首先来看看Decoder模块的整体架构图吧!
    在这里插入图片描述
    同样,Decoder也是堆叠6层,虽然Decoder的模块比Encoder部分多,但很多是重复的因此我们主要分析Decoder独有的部分。

    (一)输入

    这里论文中的结构图采用的是Outputs进行表示,其实是上一部分Encoder的输出作为Decoder的输入。
    Decoder的输入分为两类:
    一类是训练时的输入,一种是预测时的输入。
    训练时的输入就是已经对准备好对应的target数据。例如翻译任务,Encoder输入"Tom chase Jerry",Decoder输入"汤姆追逐杰瑞"。
    预测时的输入,一开始输入的是起始符,然后每次输入是上一时刻Transformer的输出。例如输入"“,输出"汤姆”;输入"汤姆",输出"汤姆追逐";输入"汤姆追逐",输出"汤姆追逐杰瑞";输入"汤姆追逐杰瑞",输出"汤姆追逐杰瑞"结束。

    (二)Masked Multi-Head Attention

    由于解码器采用自回归auto-regressive,即在过去时刻的输出作为当前时刻的输入,也就是说在预测时无法看到之后的输入输出,但是在注意力机制当中,可以看到完整的输入(每一个词都要和其他词做点积,计算相关性),为了避免这种情况的发生,在解码器训练时,在预测t时刻的输出时,不应该能看到t时刻以后的输入。做法是:采用带掩码的Masked注意力机制,从而保证在t时刻无法看到t时刻以后的输入,保证训练和预测时的行为一致性。

    (三)Decoder部分的 Multi-Head Attention

    这一部分的多头注意力机制它的输入Quer来自于Masked Multi-Head Attention的输出,Keys和Values来自于Encoder中最后一层的输出。
    在这里插入图片描述
    Encoders也就是所有的Encoder的输出要和所有的Decoder进行交互,具体交互如下:
    在这里插入图片描述
    Encoder生成Keys,Value矩阵,Decoder生成的是Query矩阵,三者进行交互,生成多头注意力机制。

    最后两个部分:前馈神经网络和Decoder部分相同;输出部分,首先进行一次Linear线性变换,然后Softmax函数输出的概率分布矩阵,最后输出概率最大的对应的单词作为预测输出。

    在这里插入图片描述
    最后看看图回顾以下吧!,以上就是Transformer整体架构的分析,如有不足或者错误欢迎指正🤝🤝
    接下来将会继续分享NLP领域近几年最著名的语言模型Bert,欢迎关注O(∩_∩)O

  • 相关阅读:
    PowerBI工作区连接Log Aanlytics
    Python获取cookie用法介绍
    HTML的学习 Day01
    SparkSQL系列-7、自定义UDF函数?
    Hive UDF array_struct_sort 对Array<Struct>进行排序
    【虹科干货】轻松简化数据库客户端工作,除了Proxy还有谁?
    企业微信获取客户群里用户的unionid;企业微信获取客户详情
    【力扣 - 只出现一次的数字】
    零基础5分钟上手亚马逊云科技AWS核心云开发/云架构知识 - 成本分析篇
    Tessent scan & ATPG(8) Debug low test coverage(低测试覆盖率的原因及debug方法)
  • 原文地址:https://blog.csdn.net/qq_45556665/article/details/127466606