• Multihead Attention - 多头注意力


    多头注意力

    在实践中,当给定 相同的查询、键和值的集合 时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依赖关系)。因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces) 可能是有益的。

    为此,与其只使用单独一个注意力汇聚,我们可以用独立学习得到的 h h h 组不同的线性投影(linear projections) 来变换查询、键和值。然后,这 h h h 组变换后的查询、键和值将并行地送到注意力汇聚中。最后,将这 h h h 个注意力汇聚的输出拼接在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。这种设计被称为多头注意力(multihead attention)。对于 h h h 个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)

    本质地讲,自注意力机制是:通过某种运算来直接计算得到句子在编码过程中每个位置上的注意力权重;然后再以权重和的形式来计算得到整个句子的隐含向量表示。

    自注意力机制的缺陷是:模型在对当前位置的信息进行编码时,会过度的将注意力集中于自身的位置, 因此作者提出了通过多头注意力机制来解决这一问题。

    下图展示了使用全连接层来实现可学习的线性变换的多头注意力。

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-R7BJtkT1-1667357320669)(attachment:QQ%E6%88%AA%E5%9B%BE20221031074721.png)]

    模型

    在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。给定查询 q ∈ R d q \mathbf{q} \in \mathbb{R}^{d_q} qRdq、键 k ∈ R d k \mathbf{k} \in \mathbb{R}^{d_k} kRdk和值 v ∈ R d v \mathbf{v} \in \mathbb{R}^{d_v} vRdv,每个注意力头 h i \mathbf{h}_i hi i = 1 , … , h i = 1, \ldots, h i=1,,h)的计算方法为:

    h i = f ( W i ( q ) q , W i ( k ) k , W i ( v ) v ) ∈ R p v , \mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v}, hi=f(Wi(q)q,Wi(k)k,Wi(v)v)Rpv,

    其中,可学习的参数包括 W i ( q ) ∈ R p q × d q \mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q} Wi(q)Rpq×dq W i ( k ) ∈ R p k × d k \mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k} Wi(k)Rpk×dk W i ( v ) ∈ R p v × d v \mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v} Wi(v)Rpv×dv,以及代表注意力汇聚的函数 f f f
    f f f 可以是之前学习的加性注意力缩放点积注意力。多头注意力的输出需要经过另一个线性转换,它对应着 h h h 个头连结后的结果,因此其可学习参数是 W o ∈ R p o × h p v \mathbf W_o\in\mathbb R^{p_o\times h p_v} WoRpo×hpv

    W o [ h 1 ⋮ h h ] ∈ R p o . \mathbf W_o

    [h1hh]" role="presentation">[h1hh]
    \in \mathbb{R}^{p_o}. Wo h1hh Rpo.

    基于这种设计,每个头都可能会关注输入的不同部分,可以表示比简单加权平均值更复杂的函数。

    import math
    import torch
    from torch import nn
    from d2l import torch as d2l
    
    • 1
    • 2
    • 3
    • 4

    实现

    在实现过程中,我们选择缩放点积注意力作为每一个注意力头。为了避免计算代价和参数代价的大幅增长,我们设定 p q = p k = p v = p o / h p_q = p_k = p_v = p_o / h pq=pk=pv=po/h。值得注意的是,如果我们将查询、键和值的线性变换的输出数量设置为 p q h = p k h = p v h = p o p_q h = p_k h = p_v h = p_o pqh=pkh=pvh=po,则可以并行计算 h h h 个头。在下面的实现中, p o p_o po是通过参数 num_hiddens 指定的。

    class MultiHeadAttention(nn.Module):
        """多头注意力"""
        def __init__(self, key_size, query_size, value_size, num_hiddens,
                    num_heads, dropout, bias=False, **kwargs):
            super(MultiHeadAttention, self).__init__(**kwargs)
            self.num_heads = num_heads
            self.attention = d2l.DotProductAttention(dropout)
            
            self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
            self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
            self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
            self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
            
        def forward(self, queries, keys, values, valid_lens):
            # queries, keys, values的形状:
            # (batch_size,查询或“键-值”对的个数,num_hiddens)
            # valid_len 的形状:
            # (batch_size,)或(batch_size,查询的个数)
            # 经过变换后,输出的queries,keys,values的形状:
            # (batch_size*num_heads,查询或“键-值”个数,num_hiddens/num_head)
            
            queries = transpose_qkv(self.W_q(queries), self.num_heads)
            keys = transpose_qkv(self.W_k(keys), self.num_heads)
            values = transpose_qkv(self.W_v(values), self.num_heads)
            
            if valid_lens is not None:
                # 在轴0,将第一项(标量或矢量) 复制 num_heads次,
                # 然后如此复制第二项,然后诸如此类
                valid_lens = torch.repeat_interleave(valid_lens,
                                                    repeats=self.num_heads,
                                                    dim=0)
            
            
            # output的形状:(batch_size*num_heads, 查询个数,num_hiddens/num_head)
            output = self.attention(queries, keys, values, valid_lens)
            # output_concat的形状:(batch_size, 查询个数,num_hiddens)
            output_concat = transpose_output(output, self.num_heads)
            return self.W_o(output_concat)
                
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39

    为了能够使多个头并行计算,上面的 MultiHeadAttention 类将使用下面定义的两个转置函数。具体来说,transpose_output 函数反转了 transpose_qkv 函数的操作。

    def transpose_qkv(X, num_heads):
        """为了多头注意力的并行计算而变换形状"""
        # 输入X的形状(batch_size, 查询或”键-值“对的个数,num_hiddens)
        # 输出X的形状(batch_size,查询或”键-值“对的个数,
        # num_heads,num_hiddens/num_heads)
        X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
        
        # 输出X的形状(batch_size,
        # num_heads,查询或”键-值“对的个数,num_hiddens/num_heads)
        X = X.permute(0, 2, 1, 3)
        
        # 输出X的形状(batch_size*num_heads,
        # 查询或”键-值“对的个数,num_hiddens/num_heads)
        return X.reshape(-1, X.shape[2], X.shape[3])
    
    
    def transpose_output(X, num_heads):
        """逆转transpose_qkv函数的操作"""
        # 输入X的形状(batch_size*num_heads,
        # 查询或”键-值“对的个数,num_hiddens/num_heads)
        
        # 输出X的形状(batch_size,
        # num_heads,查询或”键-值“对的个数,num_hiddens/num_heads)
        X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
        
        # 输出X的形状(batch_size,查询或”键-值“对的个数,
        # num_heads,num_hiddens/num_heads)
        X = X.permute(0, 2, 1, 3)
        
        # 输出X的形状(batch_size,查询或”键-值“对的个数,num_hiddens)
        return X.reshape(X.shape[0], X.shape[1], -1)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31

    下面我们使用键和值相同的小例子来测试我们编写的 MultiHeadAttention 类。多头注意力输出的形状是 (batch_size,num_queries, num_hiddens)。

    num_hiddens, num_heads = 100, 5
    attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                  num_hiddens, num_heads, 0.5)
    attention.eval()
    
    • 1
    • 2
    • 3
    • 4
    MultiHeadAttention(
      (attention): DotProductAttention(
        (dropout): Dropout(p=0.5, inplace=False)
      )
      (W_q): Linear(in_features=100, out_features=100, bias=False)
      (W_k): Linear(in_features=100, out_features=100, bias=False)
      (W_v): Linear(in_features=100, out_features=100, bias=False)
      (W_o): Linear(in_features=100, out_features=100, bias=False)
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    batch_size, num_queries = 2, 4
    num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
    X = torch.ones((batch_size, num_queries, num_hiddens))
    Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
    attention(X, Y, Y, valid_lens).shape
    
    • 1
    • 2
    • 3
    • 4
    • 5
    torch.Size([2, 4, 100])
    
    • 1

    小结

    1、多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。

    2、基于适当的张量操作,可以实现多头注意力的并行计算。

  • 相关阅读:
    Visual Assist v10.9.2471.0 Crack
    简单认识:结构体的嵌套,结构体的传参
    MIPI CSI接口调试方法: data rate计算
    Git设置初始化默认分支为main
    cartographer中的扫描匹配
    第一类曲面积分:曲面微元dσ与其投影面积微元dxdy之间的关系推导
    数字验证学习笔记——UVM学习1
    【Java分享客栈】一文搞定京东零售开源的AsyncTool,彻底解决异步编排问题。
    WebRTC系列-SDP之setLocalDescription(2)
    记研二首次组会汇报暨进组后第三次组会汇报
  • 原文地址:https://blog.csdn.net/weixin_43479947/article/details/127647636