• 注意力机制 -自注意力和位置编码


    自注意力和位置编码

    在深度学习中,我们经常使用卷积神经网络(CNN)或循环神经网络(RNN)对序列进行编码。想象一下,有了注意力机制之后,我们将词元序列输入注意力池化后,以便同一组词元同时充当查询、键和值。具体来说,每个查询都会关注所有的键-值对并生成一个注意力输出

    由于查询、键和值来自同一组输入,因此被称为自注意力(self-attention),也被称为内部注意力(intra-attention)。在本节中,我们将使用自注意力进行序列编码,以及如何使用序列的顺序作为补充信息

    import math
    import torch
    from torch import nn
    from d2l import torch as d2l
    
    • 1
    • 2
    • 3
    • 4

    1 - 自注意力

    num_hiddens,num_heads = 100,5
    attention = d2l.MultiHeadAttention(num_hiddens,num_hiddens,num_hiddens,num_hiddens,num_heads,0.5)
    attention.eval()
    
    • 1
    • 2
    • 3
    MultiHeadAttention(
      (attention): DotProductAttention(
        (dropout): Dropout(p=0.5, inplace=False)
      )
      (W_q): Linear(in_features=100, out_features=100, bias=False)
      (W_k): Linear(in_features=100, out_features=100, bias=False)
      (W_v): Linear(in_features=100, out_features=100, bias=False)
      (W_o): Linear(in_features=100, out_features=100, bias=False)
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    batch_size,num_queries,valid_lens = 2,4,torch.tensor([3,2])
    X = torch.ones(batch_size,num_queries,num_hiddens)
    attention(X,X,X,valid_lens).shape
    
    • 1
    • 2
    • 3
    torch.Size([2, 4, 100])
    
    • 1

    2 - 比较卷积神经网络、循环神经网络和自注意力

    让我们⽐较下⾯⼏个架构,⽬标都是将由n个词元组成的序列映射到另⼀个⻓度相等的序列,其中的每个输⼊词元或输出词元都由d维向量表⽰。具体来说,我们将⽐较的是卷积神经⽹络、循环神经⽹络和⾃注意⼒这⼏个架构的计算复杂性、顺序操作和最⼤路径⻓度。请注意,顺序操作会妨碍并⾏计算,⽽任意的序列位置组合之间的路径越短,则能更轻松地学习序列中的远距离依赖关系

    3 - 位置编码

    class PositionalEncoding(nn.Module):
        """位置编码"""
        def __init__(self,num_hiddens,dropout,max_len=1000):
            super(PositionalEncoding,self).__init__()
            self.dropout = nn.Dropout(dropout)
            # 创建一个足够长的P
            self.P = torch.zeros((1,max_len,num_hiddens))
            X = torch.arange(max_len,dtype=torch.float32).reshape(-1,1)/torch.pow(10000,torch.arange(0,num_hiddens,2,dtype=torch.float32) / num_hiddens)
            self.P[:,:,0::2] = torch.sin(X)
            self.P[:,:,1::2] = torch.cos(X)
            
        def forward(self,X):
            X = X + self.P[:,:X.shape[1],:].to(X.device)
            return self.dropout(X)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    在位置嵌入矩阵P中,行代表词元在序列中的位置,列代表位置编码的不同维度。在下面的例子中,我们可以看到位置嵌入矩阵的第6列和第7列的频率高于第8列和第9列。第6列和第7列之间的偏移量(第8列和第9列相同)是由于正弦函数和余弦函数的交替

    encoding_dim,num_steps = 32,60
    pos_encoding = PositionalEncoding(encoding_dim,0)
    pos_encoding.eval()
    
    X = pos_encoding(torch.zeros((1,num_steps,encoding_dim)))
    P = pos_encoding.P[:,:X.shape[1],:]
    d2l.plot(torch.arange(num_steps),P[0,:,6:10].T,xlabel='Row (position)',figsize=(6,2.5),legend=["Col %d" % d for d in torch.arange(6,10)])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7


    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-0GTxej1f-1663075900559)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209132124216.svg)]

    绝对位置信息

    为了明白沿着编码维度单调降低的频率于绝对位置信息的关系,让我们打印出0,1…,7的二进制表示形式。正如我们所看到的,每个数字、每两个数字和每四个数字的比特值在第一个最低位、第二个最低位和第三个最低位上分别交替

    for i in range(8):
        print(f'{i}的二进制是:{i:>03b}')
    
    • 1
    • 2
    0的二进制是:000
    1的二进制是:001
    2的二进制是:010
    3的二进制是:011
    4的二进制是:100
    5的二进制是:101
    6的二进制是:110
    7的二进制是:111
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    在二进制表示中,较高比特位的交替频率低于较低比特位,与下面的热图相似,只是位置编码通过使用三角函数在编码维度上降低频率。由于输出是浮点数,因此此类连续表示比二进制表示法更节省空间

    P = P[0,:,:].unsqueeze(0).unsqueeze(0)
    d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
    
    • 1
    • 2


    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-9Ihh6y9U-1663075900559)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209132124217.svg)]

    相对位置信息

    4 - 小结

    • 在自注意力中,查询、键和值都来自同一组输入
    • 卷积神经网络和自注意力都拥有并行计算的优势,而且自注意力的最大路径长度最短。但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢
    • 为了适应序列的顺序信息,我们可以通过在输入表示中添加位置编码,来注入绝对的或相对的位置信息
  • 相关阅读:
    4.keepalive 与 Idle 监测
    建议收藏丨你想了解的动捕内容全在这儿!
    leetcode 136. 只出现一次的数字(异或!!)
    基于模板匹配的图像拼接技术研究-含Matlab代码
    8.25 学习
    Linux系统下安装和卸载Redis
    检查floating pin
    ipad触控笔有必要买原装吗?ipad2023手写笔推荐
    vue-cli-service: command not found问题解决
    什么是生成式人工智能?人工智能创造
  • 原文地址:https://blog.csdn.net/mynameisgt/article/details/126842259