• Attention Is All You Need原理与代码详细解读



    前言

    目前,我研究大模型相关知识,常用到transformer结构,我想到NLP领域开篇之作Attention is all you need论文,论文实际提出transform结构,可与CNN并驾齐驱的结构,该结构利用Q/K/V模式整合全局信息,与CNN提取局部信息有所差别。介于此,我将一年前博客园更新笔记迁入该博客中,本文将介绍transform原理,也根据源码解读,深入介绍transforme经典典结构,并附有代码。


    论文链接:点击这里

    一、Transformer结构的原理

    该部分主要介绍Attention is all you need 结构、模块、公式。暂时不介绍什么Q K V 什么Attention 什么编解码等,后面我将会根据代码解读介绍,让读者更容易理解。

    1、Transform结构

    Transformer由且仅由Attention和Feed Forward Neural Network(也称FFN)组成,其中Attention包含self Attention与Mutil-Head Attention,如下图:
    示例:pandas 是基于NumPy 的一种工具,该工具是为了解决数据分析任务而创建的。
    注:模型一般可有encode与decode组成,encode负责特征编码,decode负责解码。目前,也有论文不使用解码器decode,如swin-transform。

    2、位置编码公式

    位置编码公式(还有很多其它公式,该论文使用此公式),如下:

    在这里插入图片描述

    3、transformer公式

    在这里插入图片描述

    4、FFN结构

    FFN是由nn.Linear线性和激活函数构成,后面代码详细说明。

    二、Encode模块代码解读

    1、编码数据

    编码输入数据介绍:
    enc_input = [
    [1, 3, 4, 1, 2, 3],
    [1, 3, 4, 1, 2, 3],
    [1, 3, 4, 1, 2, 3],
    [1, 3, 4, 1, 2, 3]]
    编码使用输入数据,为4x6行,表示4个句子,每个句子有6个单词,包含标点符号。
    注:至于文本如何表示数字,可参考这里

    2、文本Embedding编码

    文本嵌入embedding:

    self.src_emb = nn.Embedding(vocab_size, d_model) # d_model=128
    
    • 1

    vocab_size:词典的大小尺寸,比如总共出现5000个词,那就输入5000。此时index为(0-4999)

    d_model:嵌入向量的维度,即用多少维来表示一个词或符号

    nn.Embedding()函数可使用torch调用,建议读者百度了解其功能。

    随后可将输入x=enc_input,可将enc_outputs则表示嵌入成功,维度为[4,6,128]分别表示batch为4,词为6,用128维度描述词6

    x = self.src_emb(x)  # 词嵌入
    
    • 1

    3、位置position编码

    位置编码,使用上面公式嵌入,我将不再介绍,其代码如下:

     pe = torch.zeros(max_len, d_model)
             position = torch.arange(0., max_len).unsqueeze(1)
             div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))  # 偶数列
             pe[:, 0::2] = torch.sin(position * div_term) # 奇数列
             pe[:, 1::2] = torch.cos(position * div_term)
             pe = pe.unsqueeze(0)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    将编码进行位置编码后,位置为[1,6,128]+输入编码的[4,6,128],相当于句子已经结合了位置编码信息,作为新新的输入,代码如下:

    x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)  #torch.autograd.Variable 表示有梯度的张量变量
    
    • 1

    4、Attention编码

    在介绍此之前,先普及一个知识,若X与Y相等,则为self attention 否则为cross-attention,因为解码时候X!=Y.
    在这里插入图片描述

    获取Q K V 代码,实际是一个线性变化,将以上输入x变成[4,6,512],然后通过head个数8与对应dv,dk将512拆分[8,64],随后移维度位置,变成[4,8,6,64]

     self.WQ = nn.Linear(d_model, d_k * n_heads)  # 利用线性卷积
     self.WK = nn.Linear(d_model, d_k * n_heads)
     self.WV = nn.Linear(d_model, d_v * n_heads)
    
    • 1
    • 2
    • 3

    变化后的q k v

     q_s = self.WQ(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # 线性卷积后再分组实现head功能
     k_s = self.WK(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
     v_s = self.WV(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
     attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)  # 编导对应的头
    
    • 1
    • 2
    • 3
    • 4

    随后通过以上self公式,将其编码计算

    scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
    attn = nn.Softmax(dim=-1)(scores)
    context = torch.matmul(attn, V)
    
    
    • 1
    • 2
    • 3
    • 4

    以上编码将是encode编码得到结果,我们将得到结果进行还原:

    context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)  # 将其还原
    output = self.linear(context)  # 通过线性又将其变成原来模样维度
    layer_norm(output + Q)  # 这里加Q 实际是对Q寻找
    
    • 1
    • 2
    • 3

    以上将重新得到新的输入x,维度为[4,6,128]

    5、FFN编码

    将以上的输出维度为[4,6,128]进行FFN层变化,实际类似线性残差网络变化,得到最终输出

      class PoswiseFeedForwardNet(nn.Module):
      
          def __init__(self, d_model, d_ff):
              super(PoswiseFeedForwardNet, self).__init__()
              self.l1 = nn.Linear(d_model, d_ff)
              self.l2 = nn.Linear(d_ff, d_model)
      
              self.relu = GELU()
              self.layer_norm = nn.LayerNorm(d_model)
     
         def forward(self, inputs):
             residual = inputs
             output = self.l1(inputs)  # 一层线性卷积
             output = self.relu(output)
             output = self.l2(output)  # 一层线性卷积
             return self.layer_norm(output + residual)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    重复以上顺序编码,即将得到经过FFN变化的输出x,维度为[4,6,128],将其重复步骤③-④,因其编码为6个,可重复5个便是完成相应的编码模块。

    三、Decode模块代码解读

    1、编码数据

    解码输入数据介绍,包含以下数据输入dec_input、enc_input的输入与解码后输出的数据,维度为[4,6,128],而dec_input输入如下:

    dec_input = [
    [1, 0, 0, 0, 0, 0],
    [1, 3, 0, 0, 0, 0],
    [1, 3, 4, 0, 0, 0],
    [1, 3, 4, 1, 0, 0]]

    2、文本Embedding与位置编码

    dec_input的Embedding与位置编码,因其与encode的实现方法一致,只需将enc_input使用dec_input取代,得到dec_outputs,因此这里将不在介绍。

    3、mask编码

    整体编码,代码如下:

      def get_attn_pad_mask(seq_q, seq_k, pad_index):
         batch_size, len_q = seq_q.size()
         batch_size, len_k = seq_k.size()
         pad_attn_mask = seq_k.data.eq(pad_index).unsqueeze(1)
         pad_attn_mask = torch.as_tensor(pad_attn_mask, dtype=torch.int)
         return pad_attn_mask.expand(batch_size, len_q, len_k)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    以上代码实际是将dec_input进行处理,实际变成以下数据:

    [[0, 1, 1, 1, 1, 1],
    [0, 0, 1, 1, 1, 1],
    [0, 0, 0, 1, 1, 1],
    [0, 0, 0, 0, 1, 1]]

    将其增添维度为[4,1,6],并将其扩张为[4,6,6]

    局部代码编写,实际为上三角矩阵:

    [[0. 1. 1. 1. 1. 1.]
    [0. 0. 1. 1. 1. 1.]
    [0. 0. 0. 1. 1. 1.]
    [0. 0. 0. 0. 1. 1.]
    [0. 0. 0. 0. 0. 1.]
    [0. 0. 0. 0. 0. 0.]]
    将以上数据添加维度为[1,6,6],在将扩展变成[4,6,6]
    关于整体mask与局部mask编码,我的理解是整体信息为语句4个词6个,根据解码输入编码整体信息,而局部编码是基于一个语句6*6编码信息,将其扩张重复到4个语句,
    使其mask获得整体信息与局部信息。

             dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index)  # 整体编码的mask
             dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
             dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)  # torch.gt(a,b) a>b 则为1否则为0
             dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs, self.pad_index)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5

    最终将mask整合,获取dec_self_attn_mask信息,同理dec_enc_attn_mask(维度为解码编码词维度)采用dec_self_attn_mask的第一步便可获取。

    4、Attention编码

    编码输入self-Attention,包含2部分,self Attention与cross Attention。

    self attention

    解码输入dec_outputs进行self.Attention:
    实际使用以上Q K V公式,具体实现和编码实现方法一致,唯一不同是在Q*K^T会使用解码maskdec_self_attn_mask,其重要代码为scores.masked_fill_(attn_mask, -1e9),代码如下:

      class ScaledDotProductAttention(nn.Module):
      
          def __init__(self, d_k, device):
              super(ScaledDotProductAttention, self).__init__()
              self.device = device
              self.d_k = d_k
      
          def forward(self, Q, K, V, attn_mask):
              scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
              attn_mask = torch.as_tensor(attn_mask, dtype=torch.bool)
              attn_mask = attn_mask.to(self.device)
              scores.masked_fill_(attn_mask, -1e9)  # it is true give -1e9
              attn = nn.Softmax(dim=-1)(scores)
              context = torch.matmul(attn, V)
              return context, attn
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    以上代码将执行以下代码:

    context, attn = ScaledDotProductAttention(d_k=self.d_k, device=self.device)(Q=q_s, K=k_s, V=v_s,
                                                                                attn_mask=attn_mask)
    context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)  # 将其还原
    output = self.linear(context)  # 通过线性又将其变成原来模样维度
    dec_outputs = self.layer_norm(output + Q)  # 这里加Q 实际是对Q寻找
    
    • 1
    • 2
    • 3
    • 4
    • 5

    到此为止已经完成了解码输入的self-attention模块,输出为dec_outputs实际除了增加mask编码调整Q*K^T以外,其它完全相同。

    cross attention

    编码输出dec_outputs进行Cross Attention:

    dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask) 
    
    • 1

    重点说明enc_outputs来源编码结果,是一直不变的,以上为Cross Attention 过程,以上代码除了Q来源dec_outputs,K V 来源编码输出enc_outputs以外,即论文所说X与Y不等得到的Q K V称为Cross Attention。
    实际以上代码与执行解码self-Attention方法完全一致,仅仅mask更改上文提供的方法,得到输出结果为dec_outputs,因此这里将不在解释了。

    5、FFN编码

    该部分编码与encode的FFN一样,我将不在解释。

    重复步骤上面4与5为n次,便实现解码过程。

    四、源码附件(源码有注释)

    最后,我给出attention is all you need的所有代码,只需简单环境便可使用,整体实现代码如下:

    import json
    import math
    import torch
    import torchvision
    import torch.nn as nn
    import numpy as np
    from pdb import set_trace
    
    from torch.autograd import Variable
    
    
    def get_attn_pad_mask(seq_q, seq_k, pad_index):
        batch_size, len_q = seq_q.size()
        batch_size, len_k = seq_k.size()
        pad_attn_mask = seq_k.data.eq(pad_index).unsqueeze(1)
        pad_attn_mask = torch.as_tensor(pad_attn_mask, dtype=torch.int)
        return pad_attn_mask.expand(batch_size, len_q, len_k)
    
    
    def get_attn_subsequent_mask(seq):
        attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
        subsequent_mask = np.triu(np.ones(attn_shape), k=1)
        subsequent_mask = torch.from_numpy(subsequent_mask).int()
        return subsequent_mask
    
    
    class GELU(nn.Module):
    
        def forward(self, x):
            return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    
    
    class PositionalEncoding(nn.Module):
        "Implement the PE function."
    
        def __init__(self, d_model, dropout, max_len=5000):  #
            super(PositionalEncoding, self).__init__()
            self.dropout = nn.Dropout(p=dropout)
    
            # Compute the positional encodings once in log space.
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0., max_len).unsqueeze(1)
            div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))  # 偶数列
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            pe = pe.unsqueeze(0)
            self.register_buffer('pe', pe)  # 将变量pe保存到内存中,不计算梯度
    
        def forward(self, x):
            x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)  # torch.autograd.Variable 表示有梯度的张量变量
            return self.dropout(x)
    
    
    class ScaledDotProductAttention(nn.Module):
    
        def __init__(self, d_k, device):
            super(ScaledDotProductAttention, self).__init__()
            self.device = device
            self.d_k = d_k
    
        def forward(self, Q, K, V, attn_mask):
            scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
            attn_mask = torch.as_tensor(attn_mask, dtype=torch.bool)
            attn_mask = attn_mask.to(self.device)
            scores.masked_fill_(attn_mask, -1e9)  # it is true give -1e9
            attn = nn.Softmax(dim=-1)(scores)
            context = torch.matmul(attn, V)
            return context, attn
    
    
    class MultiHeadAttention(nn.Module):
    
        def __init__(self, d_model, d_k, d_v, n_heads, device):
            super(MultiHeadAttention, self).__init__()
            self.WQ = nn.Linear(d_model, d_k * n_heads)  # 利用线性卷积
            self.WK = nn.Linear(d_model, d_k * n_heads)
            self.WV = nn.Linear(d_model, d_v * n_heads)
    
            self.linear = nn.Linear(n_heads * d_v, d_model)
    
            self.layer_norm = nn.LayerNorm(d_model)
            self.device = device
    
            self.d_model = d_model
            self.d_k = d_k
            self.d_v = d_v
            self.n_heads = n_heads
    
        def forward(self, Q, K, V, attn_mask):
            batch_size = Q.shape[0]
            q_s = self.WQ(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)  # 线性卷积后再分组实现head功能
            k_s = self.WK(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
            v_s = self.WV(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
    
            attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)  # 编导对应的头
            context, attn = ScaledDotProductAttention(d_k=self.d_k, device=self.device)(Q=q_s, K=k_s, V=v_s,
                                                                                        attn_mask=attn_mask)
            context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)  # 将其还原
            output = self.linear(context)  # 通过线性又将其变成原来模样维度
            return self.layer_norm(output + Q), attn  # 这里加Q 实际是对Q寻找
    
    
    class PoswiseFeedForwardNet(nn.Module):
    
        def __init__(self, d_model, d_ff):
            super(PoswiseFeedForwardNet, self).__init__()
            self.l1 = nn.Linear(d_model, d_ff)
            self.l2 = nn.Linear(d_ff, d_model)
    
            self.relu = GELU()
            self.layer_norm = nn.LayerNorm(d_model)
    
        def forward(self, inputs):
            residual = inputs
            output = self.l1(inputs)  # 一层线性卷积
            output = self.relu(output)
            output = self.l2(output)  # 一层线性卷积
            return self.layer_norm(output + residual)
    
    
    class EncoderLayer(nn.Module):
    
        def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
            super(EncoderLayer, self).__init__()
            self.enc_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
            self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)
    
        def forward(self, enc_inputs, enc_self_attn_mask):
            enc_outputs, attn = self.enc_self_attn(Q=enc_inputs, K=enc_inputs, V=enc_inputs, attn_mask=enc_self_attn_mask)
            # X=Y 因此Q K V相等
            enc_outputs = self.pos_ffn(enc_outputs)  #
            return enc_outputs, attn
    
    
    class Encoder(nn.Module):
    
        def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
            #                   4        128     256   64   64     8        4          0
            super(Encoder, self).__init__()
            self.device = device
            self.pad_index = pad_index
            self.src_emb = nn.Embedding(vocab_size, d_model)
            # vocab_size:词典的大小尺寸,比如总共出现5000个词,那就输入5000。此时index为(0-4999) d_model:嵌入向量的维度,即用多少维来表示一个符号
            self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)
    
            self.layers = []
            for _ in range(n_layers):
                encoder_layer = EncoderLayer(d_model=d_model, d_ff=d_ff, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
                self.layers.append(encoder_layer)
            self.layers = nn.ModuleList(self.layers)
    
        def forward(self, x):
            enc_outputs = self.src_emb(x)  # 词嵌入
            enc_outputs = self.pos_emb(enc_outputs)  # pos+matx
            enc_self_attn_mask = get_attn_pad_mask(x, x, self.pad_index)
    
            enc_self_attns = []
            for layer in self.layers:
                enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
                enc_self_attns.append(enc_self_attn)
    
            enc_self_attns = torch.stack(enc_self_attns)
            enc_self_attns = enc_self_attns.permute([1, 0, 2, 3, 4])
            return enc_outputs, enc_self_attns
    
    
    class DecoderLayer(nn.Module):
    
        def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
            super(DecoderLayer, self).__init__()
            self.dec_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
            self.dec_enc_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
            self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)
    
        def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
            dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
            dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
            dec_outputs = self.pos_ffn(dec_outputs)
            return dec_outputs, dec_self_attn, dec_enc_attn
    
    
    class Decoder(nn.Module):
    
        def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
            super(Decoder, self).__init__()
            self.pad_index = pad_index
            self.device = device
            self.tgt_emb = nn.Embedding(vocab_size, d_model)
            self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)
            self.layers = []
            for _ in range(n_layers):
                decoder_layer = DecoderLayer(d_model=d_model, d_ff=d_ff, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
                self.layers.append(decoder_layer)
            self.layers = nn.ModuleList(self.layers)
    
        def forward(self, dec_inputs, enc_inputs, enc_outputs):
            dec_outputs = self.tgt_emb(dec_inputs)
            dec_outputs = self.pos_emb(dec_outputs)
    
            dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index)
            dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
            dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
            dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs, self.pad_index)
    
            dec_self_attns, dec_enc_attns = [], []
            for layer in self.layers:
                dec_outputs, dec_self_attn, dec_enc_attn = layer(
                    dec_inputs=dec_outputs,
                    enc_outputs=enc_outputs,
                    dec_self_attn_mask=dec_self_attn_mask,
                    dec_enc_attn_mask=dec_enc_attn_mask)
                dec_self_attns.append(dec_self_attn)
                dec_enc_attns.append(dec_enc_attn)
            dec_self_attns = torch.stack(dec_self_attns)
            dec_enc_attns = torch.stack(dec_enc_attns)
    
            dec_self_attns = dec_self_attns.permute([1, 0, 2, 3, 4])
            dec_enc_attns = dec_enc_attns.permute([1, 0, 2, 3, 4])
    
            return dec_outputs, dec_self_attns, dec_enc_attns
    
    
    class MaskedDecoderLayer(nn.Module):
    
        def __init__(self, d_model, d_ff, d_k, d_v, n_heads, device):
            super(MaskedDecoderLayer, self).__init__()
            self.dec_self_attn = MultiHeadAttention(d_model=d_model, d_k=d_k, d_v=d_v, n_heads=n_heads, device=device)
            self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, d_ff=d_ff)
    
        def forward(self, dec_inputs, dec_self_attn_mask):
            dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
            dec_outputs = self.pos_ffn(dec_outputs)
            return dec_outputs, dec_self_attn
    
    
    class MaskedDecoder(nn.Module):
    
        def __init__(self, vocab_size, d_model, d_ff, d_k,
                     d_v, n_heads, n_layers, pad_index, device):
            super(MaskedDecoder, self).__init__()
            self.pad_index = pad_index
            self.tgt_emb = nn.Embedding(vocab_size, d_model)
            self.pos_emb = PositionalEncoding(d_model=d_model, dropout=0)
    
            self.layers = []
            for _ in range(n_layers):
                decoder_layer = MaskedDecoderLayer(
                    d_model=d_model, d_ff=d_ff,
                    d_k=d_k, d_v=d_v, n_heads=n_heads,
                    device=device)
                self.layers.append(decoder_layer)
            self.layers = nn.ModuleList(self.layers)
    
        def forward(self, dec_inputs):
            dec_outputs = self.tgt_emb(dec_inputs)
            dec_outputs = self.pos_emb(dec_outputs)
    
            dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs, self.pad_index)
            dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
            dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)
            dec_self_attns = []
            for layer in self.layers:
                dec_outputs, dec_self_attn = layer(
                    dec_inputs=dec_outputs,
                    dec_self_attn_mask=dec_self_attn_mask)
                dec_self_attns.append(dec_self_attn)
            dec_self_attns = torch.stack(dec_self_attns)
            dec_self_attns = dec_self_attns.permute([1, 0, 2, 3, 4])
            return dec_outputs, dec_self_attns
    
    
    class BertModel(nn.Module):
    
        def __init__(self, vocab_size, d_model, d_ff, d_k, d_v, n_heads, n_layers, pad_index, device):
            super(BertModel, self).__init__()
            self.tok_embed = nn.Embedding(vocab_size, d_model)
            self.pos_embed = PositionalEncoding(d_model=d_model, dropout=0)
            self.seg_embed = nn.Embedding(2, d_model)
    
            self.layers = []
            for _ in range(n_layers):
                encoder_layer = EncoderLayer(
                    d_model=d_model, d_ff=d_ff,
                    d_k=d_k, d_v=d_v, n_heads=n_heads,
                    device=device)
                self.layers.append(encoder_layer)
            self.layers = nn.ModuleList(self.layers)
    
            self.pad_index = pad_index
    
            self.fc = nn.Linear(d_model, d_model)
            self.active1 = nn.Tanh()
            self.classifier = nn.Linear(d_model, 2)
    
            self.linear = nn.Linear(d_model, d_model)
            self.active2 = GELU()
            self.norm = nn.LayerNorm(d_model)
    
            self.decoder = nn.Linear(d_model, vocab_size, bias=False)
            self.decoder.weight = self.tok_embed.weight
            self.decoder_bias = nn.Parameter(torch.zeros(vocab_size))
    
        def forward(self, input_ids, segment_ids, masked_pos):
            output = self.tok_embed(input_ids) + self.seg_embed(segment_ids)
            output = self.pos_embed(output)
            enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.pad_index)
    
            for layer in self.layers:
                output, enc_self_attn = layer(output, enc_self_attn_mask)
    
            h_pooled = self.active1(self.fc(output[:, 0]))
            logits_clsf = self.classifier(h_pooled)
    
            masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1))
            h_masked = torch.gather(output, 1, masked_pos)
            h_masked = self.norm(self.active2(self.linear(h_masked)))
            logits_lm = self.decoder(h_masked) + self.decoder_bias
    
            return logits_lm, logits_clsf, output
    
    
    class GPTModel(nn.Module):
    
        def __init__(self, vocab_size, d_model, d_ff,
                     d_k, d_v, n_heads, n_layers, pad_index,
                     device):
            super(GPTModel, self).__init__()
            self.decoder = MaskedDecoder(
                vocab_size=vocab_size,
                d_model=d_model, d_ff=d_ff,
                d_k=d_k, d_v=d_v, n_heads=n_heads,
                n_layers=n_layers, pad_index=pad_index,
                device=device)
            self.projection = nn.Linear(d_model, vocab_size, bias=False)
    
        def forward(self, dec_inputs):
            dec_outputs, dec_self_attns = self.decoder(dec_inputs)
            dec_logits = self.projection(dec_outputs)
            return dec_logits, dec_self_attns
    
    
    class Classifier(nn.Module):
    
        def __init__(self, vocab_size, d_model, d_ff,
                     d_k, d_v, n_heads, n_layers,
                     pad_index, device, num_classes):
            super(Classifier, self).__init__()
            self.encoder = Encoder(
                vocab_size=vocab_size,
                d_model=d_model, d_ff=d_ff,
                d_k=d_k, d_v=d_v, n_heads=n_heads,
                n_layers=n_layers, pad_index=pad_index,
                device=device)
            self.projection = nn.Linear(d_model, num_classes)
    
        def forward(self, enc_inputs):
            enc_outputs, enc_self_attns = self.encoder(enc_inputs)
            mean_enc_outputs = torch.mean(enc_outputs, dim=1)
            logits = self.projection(mean_enc_outputs)
            return logits, enc_self_attns
    
    
    class Translation(nn.Module):
    
        def __init__(self, src_vocab_size, tgt_vocab_size, d_model,
                     d_ff, d_k, d_v, n_heads, n_layers, src_pad_index,
                     tgt_pad_index, device):
            super(Translation, self).__init__()
            self.encoder = Encoder(
                vocab_size=src_vocab_size,  # 5
                d_model=d_model, d_ff=d_ff,  # 128  256
                d_k=d_k, d_v=d_v, n_heads=n_heads,  # 64 64  8
                n_layers=n_layers, pad_index=src_pad_index,  # 4  0
                device=device)
            self.decoder = Decoder(
                vocab_size=tgt_vocab_size,  # 5
                d_model=d_model, d_ff=d_ff,  # 128  256
                d_k=d_k, d_v=d_v, n_heads=n_heads,  # 64 64  8
                n_layers=n_layers, pad_index=tgt_pad_index,  # 4  0
                device=device)
            self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)
    
        # def forward(self, enc_inputs, dec_inputs, decode_lengths):
        #     enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        #     dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        #     dec_logits = self.projection(dec_outputs)
        #     return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns, decode_lengths
    
        def forward(self, enc_inputs, dec_inputs):
            enc_outputs, enc_self_attns = self.encoder(enc_inputs)
            dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
            dec_logits = self.projection(dec_outputs)
            return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns
    
    
    if __name__ == '__main__':
        enc_input = [
            [1, 3, 4, 1, 2, 3],
            [1, 3, 4, 1, 2, 3],
            [1, 3, 4, 1, 2, 3],
            [1, 3, 4, 1, 2, 3]]
        dec_input = [
            [1, 0, 0, 0, 0, 0],
            [1, 3, 0, 0, 0, 0],
            [1, 3, 4, 0, 0, 0],
            [1, 3, 4, 1, 0, 0]]
        enc_input = torch.as_tensor(enc_input, dtype=torch.long).to(torch.device('cpu'))
        dec_input = torch.as_tensor(dec_input, dtype=torch.long).to(torch.device('cpu'))
        model = Translation(
            src_vocab_size=5, tgt_vocab_size=5, d_model=128,
            d_ff=256, d_k=64, d_v=64, n_heads=8, n_layers=4, src_pad_index=0,
            tgt_pad_index=0, device=torch.device('cpu'))
    
        logits, _, _, _ = model(enc_input, dec_input)
        print(logits)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254
    • 255
    • 256
    • 257
    • 258
    • 259
    • 260
    • 261
    • 262
    • 263
    • 264
    • 265
    • 266
    • 267
    • 268
    • 269
    • 270
    • 271
    • 272
    • 273
    • 274
    • 275
    • 276
    • 277
    • 278
    • 279
    • 280
    • 281
    • 282
    • 283
    • 284
    • 285
    • 286
    • 287
    • 288
    • 289
    • 290
    • 291
    • 292
    • 293
    • 294
    • 295
    • 296
    • 297
    • 298
    • 299
    • 300
    • 301
    • 302
    • 303
    • 304
    • 305
    • 306
    • 307
    • 308
    • 309
    • 310
    • 311
    • 312
    • 313
    • 314
    • 315
    • 316
    • 317
    • 318
    • 319
    • 320
    • 321
    • 322
    • 323
    • 324
    • 325
    • 326
    • 327
    • 328
    • 329
    • 330
    • 331
    • 332
    • 333
    • 334
    • 335
    • 336
    • 337
    • 338
    • 339
    • 340
    • 341
    • 342
    • 343
    • 344
    • 345
    • 346
    • 347
    • 348
    • 349
    • 350
    • 351
    • 352
    • 353
    • 354
    • 355
    • 356
    • 357
    • 358
    • 359
    • 360
    • 361
    • 362
    • 363
    • 364
    • 365
    • 366
    • 367
    • 368
    • 369
    • 370
    • 371
    • 372
    • 373
    • 374
    • 375
    • 376
    • 377
    • 378
    • 379
    • 380
    • 381
    • 382
    • 383
    • 384
    • 385
    • 386
    • 387
    • 388
    • 389
    • 390
    • 391
    • 392
    • 393
    • 394
    • 395
    • 396
    • 397
    • 398
    • 399
    • 400
    • 401
    • 402
    • 403
    • 404
    • 405
    • 406
    • 407
    • 408
    • 409
    • 410
    • 411
    • 412
    • 413
    • 414
    • 415
    • 416

    总结

    本文已全部介绍完transformer结构原理及代码,但我个人有以下几点说明:
    编码传递K V 解码传递Q;
    self-attention 和 cross attention本质是X与Y值不同,即得到Q 和 K V 数据来源不同,但实现方法一致;
    transformer重点模块为attention(一般是mutil-head attention)、FFN、位置编码、mask编码;

  • 相关阅读:
    2022 Android 高级进阶学习资料与高频精选面试题精讲(圆梦大厂)
    业务:财务会计业务知识
    .Net Core `RabbitMQ`封装
    SpringBoot使用配置文件若干方式
    数据可视化工具 ,不会写 SQL 代码也能做数据分析
    品优购项目案例制作需要注意的内容笔记
    BT - Unet:生物医学图像分割的自监督学习框架
    线性代数的本质(个人笔记)
    Linux基础指令(三)
    qt中d指针
  • 原文地址:https://blog.csdn.net/weixin_38252409/article/details/133840760