1个矩阵计算量:对于输入I,首先计算
Q
=
I
∗
W
q
Q =I * W^{q}
Q=I∗Wq,
K
=
I
∗
W
k
K = I * W^{k}
K=I∗Wk,
V
=
I
∗
W
v
V = I * W^{v}
V=I∗Wv,假设输入I的形状为 [b, s, d],1个矩阵乘法的输入和输出形状为[b, s, d] × [d, d] = [b, s, d],计算量为
2
b
s
d
2
2bsd^{2}
2bsd2
3个矩阵计算量:
6
b
s
d
2
6bsd^{2}
6bsd2
3.1.2 attention计算
Q
K
T
QK^{T}
QKT
矩阵乘法的输入形状[b, h, s, d] × [b, h, s, d],输出形状为 [b, h, s, s],h维度是concat,没有计算量,因此该步骤的计算量为
2
b
s
2
d
2bs^{2}d
2bs2d 。
score*V加权 输入形状为[b, h, s, s] × [b, h, s, d],输出形状为[b, h, s, d], h维度是concat,没有计算量,因此该步骤的计算量为
2
b
s
2
d
2bs^{2}d
2bs2d 。
3.1.3 MultiHeadAttention输出线性映射
所有head都concat,输入形状为[b, s, d] × [d, d]
(
W
O
)
(W^{O})
(WO),输出形状为[b, s, d],计算量
2
b
s
d
2
2bsd^{2}
2bsd2
3.1.4 MultiHeadAttention总计算量
MultiHeadAttention总计算量为上面三部分之和,
6
b
s
d
2
6bsd^{2}
6bsd2+
2
b
s
2
d
2bs^{2}d
2bs2d+
2
b
s
2
d
2bs^{2}d
2bs2d+
2
b
s
d
2
2bsd^{2}
2bsd2=
4
b
s
2
d
4bs^{2}d
4bs2d+
8
b
s
d
2
8bsd^{2}
8bsd2
3.2 MLP
MLP内包含2个线性层:
第一个线性层,矩阵乘法输入形状为[b, s, d] × [d, 4d],输出形状为[b, s, 4d],计算量
8
b
s
d
2
8bsd^{2}
8bsd2 。
第二个线性层,矩阵乘法输入形状为[b, s, 4d] × [4d, d],输出形状为[b, s, d],计算量
8
b
s
d
2
8bsd^{2}
8bsd2
MLP总计算量为
8
b
s
d
2
8bsd^{2}
8bsd2+
8
b
s
d
2
8bsd^{2}
8bsd2=
16
b
s
d
2
16bsd^{2}
16bsd2
3.3 projection输出
logits的计算,将隐藏向量映射为词表大小。矩阵乘法输入形状为[b, s, d] × [d, v],输出形状为[b, s, v],计算量
2
b
s
d
v
2bsdv
2bsdv。
Transformer的输出为1个projection 将上面3部分累加,计算量为N*(
4
b
s
2
d
4bs^{2}d
4bs2d+
8
b
s
d
2
8bsd^{2}
8bsd2+
16
b
s
d
2
16bsd^{2}
16bsd2)+N*(2*(
4
b
s
2
d
4bs^{2}d
4bs2d+
8
b
s
d
2
8bsd^{2}
8bsd2)+
16
b
s
d
2
16bsd^{2}
16bsd2)+
2
b
s
d
v
2bsdv
2bsdv=
12
N
b
s
2
d
12Nbs^{2}d
12Nbs2d+
56
N
b
s
d
2
56Nbsd^{2}
56Nbsd2+
2
b
s
d
v
2bsdv
2bsdv