输入矩阵I:首先计算
Q
=
I
∗
W
q
Q =I * W^{q}
Q=I∗Wq,
K
=
I
∗
W
k
K = I * W^{k}
K=I∗Wk,
V
=
I
∗
W
v
V = I * W^{v}
V=I∗Wv,输入I是临时Tensor,假设输入I的形状为 [b, s, d],元素个数为 bsd,占用显存大小为2bytes*bsd=2bsd bytes.
Q
K
T
QK^{T}
QKT:Q和K是临时Tensor,假设形状为 [b, s, d],元素个数为 bsd,占用显存大小为22bytesbsd=4bsd bytes。
softmax:
A
=
Q
K
T
A=QK^{T}
A=QKT,输入形状[b, h, s, d] × [b, h, s, d],A矩阵输出形状为 [b, h, s, s],h是头个数。保存A矩阵占用的显存大小为=2bytes*
b
h
s
2
bhs^{2}
bhs2=
2
b
h
s
2
2bhs^{2}
2bhs2 bytes。
dropout:需要保存一个mask矩阵,mask矩阵的形状与A相同,mask矩阵的元素为0或1,用1个byte表示,占用显存大小为
b
h
s
2
bhs^{2}
bhs2 bytes。
score* V加权:score矩阵的形状与A相同,占用显存大小为
2
b
h
s
2
2bhs^{2}
2bhs2 bytes。V矩阵形状[b, s, d],占用显存大小为2bytes*bsd=2bsd bytes。该步骤占用显存大小为
2
b
h
s
2
+
2
b
s
d
2bhs^{2}+2bsd
2bhs2+2bsd bytes。
W
O
W^{O}
WO输出映射:需要临时保存输入矩阵,形状[b, s, d],占用显存大小为2bytes*bsd=2bsd bytes。
dropout:需要保存一个mask矩阵,mask矩阵的形状为上一步输出形状[b, s, d],mask矩阵的元素为0或1,用1个byte表示,占用显存大小为1bytes*bsd=bsd bytes。 综上步骤,self-attention块的占用显存大小为2bsd+4bsd+
2
b
h
s
2
2bhs^{2}
2bhs2+
2
b
h
s
2
2bhs^{2}
2bhs2+
2
b
h
s
2
+
2
b
s
d
2bhs^{2}+2bsd
2bhs2+2bsd+2bsd+2bsd=11bsd+
5
b
h
s
2
5bhs^{2}
5bhs2