在深度学习中,我们经常使用卷积神经网络(CNN)或循环神经网络(RNN)对序列进行编码。想象一下,有了注意力机制之后,我们将词元序列输入注意力池化后,以便同一组词元同时充当查询、键和值。具体来说,每个查询都会关注所有的键-值对并生成一个注意力输出
由于查询、键和值来自同一组输入,因此被称为自注意力(self-attention),也被称为内部注意力(intra-attention)。在本节中,我们将使用自注意力进行序列编码,以及如何使用序列的顺序作为补充信息
import math
import torch
from torch import nn
from d2l import torch as d2l
num_hiddens,num_heads = 100,5
attention = d2l.MultiHeadAttention(num_hiddens,num_hiddens,num_hiddens,num_hiddens,num_heads,0.5)
attention.eval()
MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=100, out_features=100, bias=False)
(W_k): Linear(in_features=100, out_features=100, bias=False)
(W_v): Linear(in_features=100, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size,num_queries,valid_lens = 2,4,torch.tensor([3,2])
X = torch.ones(batch_size,num_queries,num_hiddens)
attention(X,X,X,valid_lens).shape
torch.Size([2, 4, 100])
让我们⽐较下⾯⼏个架构,⽬标都是将由n个词元组成的序列映射到另⼀个⻓度相等的序列,其中的每个输⼊词元或输出词元都由d维向量表⽰。具体来说,我们将⽐较的是卷积神经⽹络、循环神经⽹络和⾃注意⼒这⼏个架构的计算复杂性、顺序操作和最⼤路径⻓度。请注意,顺序操作会妨碍并⾏计算,⽽任意的序列位置组合之间的路径越短,则能更轻松地学习序列中的远距离依赖关系
class PositionalEncoding(nn.Module):
"""位置编码"""
def __init__(self,num_hiddens,dropout,max_len=1000):
super(PositionalEncoding,self).__init__()
self.dropout = nn.Dropout(dropout)
# 创建一个足够长的P
self.P = torch.zeros((1,max_len,num_hiddens))
X = torch.arange(max_len,dtype=torch.float32).reshape(-1,1)/torch.pow(10000,torch.arange(0,num_hiddens,2,dtype=torch.float32) / num_hiddens)
self.P[:,:,0::2] = torch.sin(X)
self.P[:,:,1::2] = torch.cos(X)
def forward(self,X):
X = X + self.P[:,:X.shape[1],:].to(X.device)
return self.dropout(X)
在位置嵌入矩阵P中,行代表词元在序列中的位置,列代表位置编码的不同维度。在下面的例子中,我们可以看到位置嵌入矩阵的第6列和第7列的频率高于第8列和第9列。第6列和第7列之间的偏移量(第8列和第9列相同)是由于正弦函数和余弦函数的交替
encoding_dim,num_steps = 32,60
pos_encoding = PositionalEncoding(encoding_dim,0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1,num_steps,encoding_dim)))
P = pos_encoding.P[:,:X.shape[1],:]
d2l.plot(torch.arange(num_steps),P[0,:,6:10].T,xlabel='Row (position)',figsize=(6,2.5),legend=["Col %d" % d for d in torch.arange(6,10)])
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-0GTxej1f-1663075900559)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209132124216.svg)]
为了明白沿着编码维度单调降低的频率于绝对位置信息的关系,让我们打印出0,1…,7的二进制表示形式。正如我们所看到的,每个数字、每两个数字和每四个数字的比特值在第一个最低位、第二个最低位和第三个最低位上分别交替
for i in range(8):
print(f'{i}的二进制是:{i:>03b}')
0的二进制是:000
1的二进制是:001
2的二进制是:010
3的二进制是:011
4的二进制是:100
5的二进制是:101
6的二进制是:110
7的二进制是:111
在二进制表示中,较高比特位的交替频率低于较低比特位,与下面的热图相似,只是位置编码通过使用三角函数在编码维度上降低频率。由于输出是浮点数,因此此类连续表示比二进制表示法更节省空间
P = P[0,:,:].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-9Ihh6y9U-1663075900559)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209132124217.svg)]