transformer作为一种特征提取器,在NLP中有广泛的应用。但是Trm需要对输入序列设置一个固定的长度,比如在BERT中,默认长度是512。如果文本序列长度短于固定长度,可以通过填充的方式来解决。如果序列长度超过固定长度,处理起来就比较麻烦。一种处理方式,就是将文本划分为多个segments。训练的时候,对每个segment单独处理,segments之间没有联系,如下图(a)所示。这存在两个问题,1)因为segments之间独立训练,所以不同的token之间,最长的依赖关系,就取决于segment的长度;2)出于效率的考虑,在划分segments的时候,不考虑句子的自然边界,而是根据固定的长度来划分序列,导致分割出来的segments在语义上是不完整的。

在预测的时候,会对固定长度的segment做计算,一般取最后一个位置的隐向量作为输出。为了充分利用上下文关系,在每做完一次预测之后,就对整个序列向右移动一个位置,再做一次计算,如上图(b)所示,这导致计算效率非常低。
为了解决上面提到的问题,在Trm的基础上,Trm-XL提出了一个改进,在对当前segment进行处理的时候,缓存并利用上一个segment中所有layer的隐向量序列,而且上一个segment的所有隐向量序列只参与前向计算,不再进行反向传播,这就是所谓的segment-level Recurrence。
Trm本身是可以设置multi-heads,但是在后文中为了简化描述采用单个head。将两个连续的segments表示为
L是序列长度假设整个模型中,包含N层Trm,那么每个segment中就有N组长度为L的隐向量序列,将第
τ
\tau
τ个segment的第n层隐向量序列表示为
h
τ
n
∈
R
L
×
d
h_{\tau}^{n}\in R^{L\times d}
hτn∈RL×d,d是隐向量维度.那么第
τ
+
1
\tau+1
τ+1个segment的第n层隐向量序列,可以由下面的一组公式计算得出。
h
~
τ
+
1
n
−
1
=
[
S
G
(
h
τ
n
−
1
)
,
h
τ
+
1
n
−
1
]
(
表示对两个向量的拼接
,
拼接后为
2
L
×
d
)
q
τ
+
1
n
,
k
τ
+
1
n
,
v
τ
+
1
n
=
h
τ
+
1
n
−
1
W
q
T
,
h
~
τ
+
1
n
−
1
W
k
T
,
h
~
τ
+
1
n
−
1
W
v
T
h
τ
+
1
n
−
1
=
T
r
a
n
s
f
o
r
m
e
r
L
a
y
e
r
(
q
τ
+
1
n
,
k
τ
+
1
n
,
v
τ
+
1
n
)
\tilde{h}_{\tau+1}^{n-1} = [SG(h_{\tau}^{n-1}),h_{\tau+1}^{n-1}] \qquad (表示对两个向量的拼接,拼接后为2L\times d) \\ \qquad \\ q_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n = h_{\tau+1}^{n-1}W_q^T,\tilde{h}_{\tau+1}^{n-1}W_k^T,\tilde{h}_{\tau+1}^{n-1}W_v^T \\ \qquad \\ {h}_{\tau+1}^{n-1} = Transformer\quad Layer(q_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n)
h~τ+1n−1=[SG(hτn−1),hτ+1n−1](表示对两个向量的拼接,拼接后为2L×d)qτ+1n,kτ+1n,vτ+1n=hτ+1n−1WqT,h~τ+1n−1WkT,h~τ+1n−1WvThτ+1n−1=TransformerLayer(qτ+1n,kτ+1n,vτ+1n)
注意q的计算方式不变,只使用当前segment中的隐向量,计算得到的q序列长度仍然是L。k和v采用拼接之后的
h
~
\tilde{h}
h~来计算,计算出来的序列长度是2L。之后的计算就是标准的Transformer计算。计算出来的第n层隐向量序列长度仍然是L,而不是2L。Trm的输出隐向量序列长度取决于query的序列长度,而不是key和value。
推导一下:
训练和预测过程如下图所示。这张图上有一个点需要注意,在当前segment中,第n层的每个隐向量的计算,都是利用下一层中包括当前位置在内的,连续前L个长度的隐向量,这是在上面的公式组中没有体现出来的,也是文中没有明说的。每一个位置的隐向量,除了自己的位置,都跟下一层中前(L-1)个位置的token存在依赖关系,而且每往下走一层,依赖关系长度会增加(L-1),如下图中Evaluation phase所示,所以最长的依赖关系长度是N(L-1),N是模型中layer的数量。N通常要比L小很多,比如在BERT中,N=12或者24,L=512,依赖关系长度可以近似为 O ( N × L ) O(N\times L) O(N×L) 。在对长文本进行计算的时候,可以缓存上一个segment的隐向量的结果,不必重复计算,大幅提高计算效率。

上文中,我们只保存了上一个segment,实际操作的时候,可以保存尽可能多的segments,只要内存或者显存放得下。论文中的试验在训练的时候,只缓存一个segment,在预测的时候,会缓存多个segments。
在vanilla Trm中,为了表示序列中token的顺序关系,在模型的输入端,对每个token的输入embedding,加一个位置embedding。位置编码embedding或者采用正弦\余弦函数来生成,或者通过学习得到。在Trm-XL中,这种方法行不通,每个segment都添加相同的位置编码,多个segments之间无法区分位置关系。Trm-XL放弃使用绝对位置编码,而是采用相对位置编码,在计算当前位置隐向量的时候,考虑与之依赖token的相对位置关系。具体操作是,在算attention score的时候,只考虑query向量与key向量的相对位置关系,并且将这种相对位置关系,加入到每一层Trm的attention的计算中。
我们对两种方法做个对比。下面一组公式是vanilla Trm计算attention的方式,
E
x
E_x
Ex表示token的输入embedding,U是绝对位置编码embedding,两个W分别是query矩阵和key矩阵。下面的公式是对
(
E
x
i
+
U
i
)
W
q
T
W
k
(
E
x
j
+
U
j
)
(E_{x_i}+U_i)W_q^TW_k(E_{x_j}+U_j)
(Exi+Ui)WqTWk(Exj+Uj)做了分解。
A
i
,
j
a
b
s
=
E
x
i
T
W
q
T
W
K
E
x
j
+
E
x
i
T
W
q
T
W
K
U
j
+
U
i
T
W
q
T
W
K
E
x
j
+
U
i
T
W
q
T
W
K
U
j
A_{i,j}^{abs} = E_{x_i}^TW_q^TW_KE_{x_j} + E_{x_i}^TW_q^TW_KU_j + U_{i}^TW_q^TW_KE_{x_j} + U_{i}^TW_q^TW_KU_{j}
Ai,jabs=ExiTWqTWKExj+ExiTWqTWKUj+UiTWqTWKExj+UiTWqTWKUj
下面一组公式,是Trm-XL计算attention的方式。首先,将绝对位置编码U,替换成了相对位置编码
R
i
−
j
R_{i-j}
Ri−j 。插一句,因为i只利用之前的序列,所以i-j>=0。并且把
W
k
W_k
Wk矩阵分为
W
k
,
E
和
W
k
,
R
W_{k,E}和W_{k,R}
Wk,E和Wk,R,用于分别生成基于内容的key向量和基于位置的key向量,
A
i
,
j
r
e
l
=
E
x
i
T
W
q
T
W
k
,
E
E
x
j
+
E
x
i
T
W
q
T
W
k
,
R
R
i
−
j
+
U
i
T
W
q
T
W
k
,
E
E
x
j
+
U
j
T
W
q
T
W
k
,
R
R
i
−
j
A_{i,j}^{rel} = E_{x_i}^TW_q^TW_{k,E}E_{x_j} + E_{x_i}^TW_q^TW_{k,R}R_{i-j} + U_{i}^TW_q^TW_{k,E}E_{x_j} + U_{j}^TW_q^TW_{k,R}R_{i-j}
Ai,jrel=ExiTWqTWk,EExj+ExiTWqTWk,RRi−j+UiTWqTWk,EExj+UjTWqTWk,RRi−j
相对位置关系用一个位置编码矩阵 R ∈ R L m a x × d R\in R^{L_{max}\times d} R∈RLmax×d 来表示,第i行表示相对位置间隔为i的位置向量。论文中强调R采用正弦函数生成,而不是通过学习得到的,好处是预测时,可以使用比训练距离更长的位置向量。
最后来看一下Trm-XL的完整计算公式,如下所示,只有前3行与vanilla Trm不同,后3行是一样的。第3行公式中,计算A的时候直接采用query向量,而不再使用 表示。最后需要注意的是,每一层在计算attention的时候,都要包含相对位置编码。而在vanilla Trm中,只有在输入embedding中才包含绝对位置编码,在中间层计算的时候,是不包含位置编码的。
h ~ τ + 1 n − 1 = [ S G ( h τ n − 1 ) , h τ + 1 n − 1 ] q τ + 1 n , k τ + 1 n , v τ + 1 n = h τ + 1 n − 1 W q T , h ~ τ + 1 n − 1 W k T , h ~ τ + 1 n − 1 W v T h τ + 1 n − 1 = T r a n s f o r m e r L a y e r ( q τ + 1 n , k τ + 1 n , v τ + 1 n ) A i , j r e l = E x i T W q T W k , E E x j + E x i T W q T W k , R R i − j + U i T W q T W k , E E x j + U j T W q T W k , R R i − j α τ n = M a s k e d S o f t m a x ( A τ n ) V τ n o τ n = L a y e r N o r m ( L i n e a r ( α τ n ) + h τ + 1 n − 1 ) h τ n = P o s i t i o n w i s e F e e d F o r w a r d ( o τ n ) \tilde{h}_{\tau+1}^{n-1} = [SG(h_{\tau}^{n-1}),h_{\tau+1}^{n-1}] \\ \qquad \\ q_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n = h_{\tau+1}^{n-1}W_q^T,\tilde{h}_{\tau+1}^{n-1}W_k^T,\tilde{h}_{\tau+1}^{n-1}W_v^T \\ \qquad \\ {h}_{\tau+1}^{n-1} = Transformer\quad Layer(q_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n) \\ \qquad \\ A_{i,j}^{rel} = E_{x_i}^TW_q^TW_{k,E}E_{x_j} + E_{x_i}^TW_q^TW_{k,R}R_{i-j} + U_{i}^TW_q^TW_{k,E}E_{x_j} + U_{j}^TW_q^TW_{k,R}R_{i-j} \\ \qquad \\ \alpha_{\tau}^n = Masked\quad Softmax(A_{\tau}^n)V_{\tau}^n \\ \qquad \\ o_{\tau}^n = LayerNorm(Linear(\alpha_{\tau}^n)+{h}_{\tau+1}^{n-1}) \\ \qquad \\ h_{\tau}^n = Positionwise\quad Feed\quad Forward(o_{\tau}^n) h~τ+1n−1=[SG(hτn−1),hτ+1n−1]qτ+1n,kτ+1n,vτ+1n=hτ+1n−1WqT,h~τ+1n−1WkT,h~τ+1n−1WvThτ+1n−1=TransformerLayer(qτ+1n,kτ+1n,vτ+1n)Ai,jrel=ExiTWqTWk,EExj+ExiTWqTWk,RRi−j+UiTWqTWk,EExj+UjTWqTWk,RRi−jατn=MaskedSoftmax(Aτn)Vτnoτn=LayerNorm(Linear(ατn)+hτ+1n−1)hτn=PositionwiseFeedForward(oτn)
总结,Trm-XL为了解决长序列的问题,对上一个segment做了缓存,可供当前segment使用,但是也带来了位置关系问题,为了解决位置问题,又打了个补丁,引入了相对位置编码。