• 论文解析[6] Transformer: Attention Is All You Need


    发表年份:2017
    论文地址:https://arxiv.org/abs/1706.03762
    代码地址:https://github.com/tensorflow/tensor2tensor

    摘要

    主要的序列转录模型是基于复杂的循环或卷积神经网络,包括一个编码器和解码器。表现最好的模型也通过一个注意力模块连接编码器和解码器。

    我们提出了一个新的简单网络结构,transformer,仅仅基于注意力机制,完全摒弃了循环和卷积。

    3 模型结构

    在这里插入图片描述

    3.1 编码器和解码器

    编码器:包括6个一样的层,每层有两个子层,第一个子层是一个多头自注意力模块,第二个子层是一个全连接网络。每个子层的输出为LayerNorm(x + Sublayer(x))。为了方便残差连接,所有的子层和嵌入层输出的维度都是512维。

    解码器:包括6个一样的层,每层有三个子层,相比于编码器的两个子层多出了一个带掩码的多头注意力模块。带掩码是为了保证在t时间看不到之后的输入,从而使训练和预测时的情况一致。

    3.2 注意力

    注意力函数是将query和一些key-value对映射为输出。输出是value的加权之和,是根据query和对应key的相似度来进行计算的。

    3.2.1 Scaled Dot-Product Attention

    在这里插入图片描述

    输入包括三个向量query( d k d_k dk)、key( d k d_k dk)、value( d v d_v dv),计算Q和所有K的点乘,再除以( d k \sqrt{d_k} dk ),最后再用softmax计算权重

    将多组的query、key、value向量组成矩阵Q、K、V
    在这里插入图片描述

    3.2.2 多头注意力

    在这里插入图片描述

    线性层表示对矩阵进行投影到比较低的维度,再进行上面的Scaled Dot-Product Attention运算

    随后将不同的头(h个)的结果连接起来,再投影到一个 W O W^O WO维的空间,

    对每一个头的操作即为把Q、K、V通过对应的W矩阵投影到 d v d_v dv维,再做注意力函数

    在这里插入图片描述

    3.2.3 模型中注意力的使用

    • 在“编码器-解码器注意力”层中,query来自于上一个解码器层,key和value来自于编码器的输出。这使得解码器中每个位置都能访问输入序列中所有位置。
    • 编码器包括自注意力层,在自注意力层中,所有的key、value、query都来自同一个输入,即编码器的上一层。解码器中每个位置都能访问输入序列中所有位置。
    • 相似地,解码器中的自注意力层允许访问解码器中的所有位置,直到当前位置。为了保证自回归性,使用了带掩码的注意力

    3.3 Position-wise Feed-Forward Networks

    在这里插入图片描述

    Position-wise是说对每个词分别进行计算,实际上类似于一个MLP(多层感知机)

    两个矩阵表示对词进行投影,W1投影到2048维,W2投影回512维

    Transformer和RNN的区别:
    1)RNN是每次使用上一个MLP的输出和当前的输入信息作为输入。
    2)Transformer是使用attention来提取全局的信息,后面再使用MLP进行提取每个词

    3.4 词嵌入

    对于一个词,映射为一个长为512的向量来表示。

    3.5 位置编码

    因为attention没有时序信息(输入只是value的一个加权和),所以需要把时序信息加进来

    transformer的做法是在输入里面给词加上位置信息。

    6 结果

    在这里插入图片描述

    7 结论

    在本工作中,提出了Transformer,第一个完全基于注意力的序列转录模型,使用多头自注意力替换了编码器-解码器结构中最常用的循环层。

    对于翻译任务,Transformer的训练可以明显快于基于循环或卷积层的结构。在WMT2014英语翻译德语和WMT2014英语翻译法语的翻译任务中,取得了sota。在之前的任务中,我们最好的模型甚至超过了先前所有的集合。

  • 相关阅读:
    Proteus8仿真:51单片机使用ULN2003A控制步进电机
    【Java线程池ThreadPool(创建线程的第三种方式)】
    7.5模拟赛总结
    java基础·小白入门(一)
    隆云通空气温湿、PM2.5传感器
    MySQL 学习记录 2
    TLS/SSL通信基于NodeJS16
    长期稳定的项目—steam搬砖
    做社交媒体营销应该注意些什么?Shopline卖家的成功秘笈在这里!
    Qt通过正则表达式筛选出字符串中的手机号
  • 原文地址:https://blog.csdn.net/weixin_43772166/article/details/127709955