这是一篇硬核的优化Transformer的工作。众所周知,Transformer模型的计算量和储存复杂度是 O ( N 2 ) O(N^2) O(N2) 。尽管先前有了大量的优化工作,比如LongFormer、Sparse Transformer、Reformer等等,一定程度上减轻了Transformer的资源消耗,但对Transformer的性能有所折损,且扩展性不强,不能泛化到其它领域、以及复杂结构的叠加。
这篇工作从底层对Transformer的计算和读写进行了优化,主要有三个贡献:

flash attention的思路就是尽量地在SRAM中进行分块计算、算子融合,减少对HBM(即常说的显存)的读写,从加快模型计算,减轻内存墙问题。

# ---------------------
# Tc: K和V的分块数
# Tr: Q的分块数量
# ---------------------
for 1 <= j <= Tc:
for 1 <= i <= Tr:
do....

由于对 Q , K Q, K Q,K矩阵进行了分块,就无法进行全局归一化。我们的最终目的是得到 O O O ,作者这里根据公式推导,不断用当前最新的rowmax和rowsum去更新,直到遍历完最后一块,最终结果就和标准场景下的结果完全一致。




可以看到,flash-attention通过算子融合、分块计算减少了IO,内存墙问题得以缓解。