• FlashAttention2原理解析以及面向AIGC的加速实践


    8ce6efb9aa1c2f2e6099288873b64c7d.gif

    FlashAttention-2提出后,便得到了大量关注。本文将具体讲述FlashAttention-2的前世今生,包括FlashAttention1&2的原理解析、加速效果比较以及面向AIGC的加速实践,在这里将相关内容与大家分享~

    f7bbb91c17ededdda1af6d5255b1c971.png

    引言

    将 Transformers 扩展到更长的序列长度一直是过去几年的一个热点问题,这将有助于提高语言建模和高分辨率图像理解的能力,也有利于音频和视频生成方面的新应用场景研发。Attention层是扩展到更长序列的主要瓶颈,因为它的运行时间和内存占用是序列长度的二次方。使用近似计算的Attention方法,可以通过减少FLOP计算次数、甚至于牺牲模型质量来降低计算复杂性,但通常无法实现大比例的加速。

    

    由斯坦福大学提出的FlashAttention方法,让使用更长sequence计算Attention成为可能,并且通过线性级别的增长来节省内存以及加速计算。因为FlashAttention没有进行近似计算,所以也没有精度损失。然而,FlashAttention的实际速度仍然和理论上的运算速度差距较大,仅达到理论最大 FLOPs/s 的 25-40%。效率低下的原因主要是不同线程块和warp之间的工作分区不理想,导致低占用率或不必要的共享内存读/写。为此,2023年7月,论文作者进一步提出了FlashAttention-2,实现了Attention计算速度的大幅度提升。

    dbd8363cae7842820a6453a41522c451.png

    FlashAttention

      主要内容

    FlashAttention主要关注IO-aware,进一步优化GPU显存的读写效率。这是一种 IO 感知的精确Attention算法,它使用tiling(这里可以理解为分块)来减少 GPU 高带宽内存 (HBM) 和 GPU 片上 SRAM 之间的内存读/写次数。这里的HBM可以理解为显存,SRAM可以理解为cache。通过测试IO复杂性,相比标准 Attention,FlashAttention需要更少的 HBM 访问,并且对于不同的SRAM 大小来说都是有效的。除此以外,FlashAttention还可以扩展到block-sparse attention,产生比任何现有近似注意力方法更快的近似注意力算法。

    

    FlashAttention与 MLPerf 1.1 训练速度相比,对于BERT-large(序列长度 512)实现端到端wall-clock加速15%,对于GPT-2(序列长度 1K)加速 3 倍。FlashAttention 和block-sparse FlashAttention 可在 Transformers 中实现更长的上下文,从而产生更高质量的模型,GPT-2 上的困惑度提升0.7,长文档分类的test结果提高 6.4 个点。

      主要操作

    0cb25904f6963be4b5aabd042c8c5388.png

    背景知识:

    上图的左图,表示存储结构,可以简单理解为:SRAM表示缓存,HBM表示显存,DRAM表示内存。

    • tiling

    在不访问整个输入的情况下优化attention计算,并减少相关计算量。重构attention计算,将输入分割成块,并对分块进行多次传递,从而逐步执行attention计算(该步骤称为tiling)。

    

    如上图所示,FlashAttention 使用tiling来防止在相对较慢的 GPU显存上实现大型 𝑁 × 𝑁 注意力矩阵(虚线框)计算。在外部循环(红色箭头)中,FlashAttention 循环遍历 K 和 V 矩阵块,并将它们加载到快速片上 SRAM。在每个块中,FlashAttention 循环遍历 Q 矩阵块(蓝色箭头),将它们加载到 SRAM,并将注意力计算的输出写回 HBM。

    

    将输入Q、K、V矩阵分成很多块,将它们从较慢的HBM加载到较快的SRAM,然后在SRAM计算关于这些块的注意力输出。对每个块的计算结果缩放之后进行add操作,则得到正确的结果,具体伪代码如图:

    e4eba9ea225d6e5d57300f53be96a7ce.png

    • recomputing

    FlashAttention不专门存储用于后向计算的大型中间计算结果。在SRAM中存储前向计算中的 softmax 归一化因子,以便在后向传递计算梯度的时候快速得到中间结果,这比从 HBM 读取中间计算结果的标准方法更快。

    

    FlashAttention不存储前向计算中𝑂(𝑁2)复杂度的中间值,但是后向传递通常需要矩阵 S, P ∈ R𝑁 ×𝑁 来计算相对于 Q、K、V 的梯度。通过存储输出 O 和 softmax 归一化统计量 (𝑚, ℓ),则可以在 SRAM 中的 Q、K、V 块的后向计算中快速重新计算注意力矩阵 S 和 P。这可以看作是checkpoint的一种形式。

    

    如下图所示,FlashAttention由于tiling分块操作和recomputing操作,增加了一些计算次数。但是还是通过使用SRAM减少了显存占用,通过减少hbm访问次数加快了attention计算。

    2db7df6bb2639426b15009394c3f0ac4.png

      Block-Sparse FlashAttention

    b84764a4708e670aa3332acfae3e586a.png

    论文还提出了Block-sparse FlashAttention,其IO复杂度比FlashAttention小,与稀疏度成正比。非0矩阵越少(即0矩阵越多),稀疏化来压缩数据的空间就越大,block-sparse加速就越明显。上图中,IO复杂度与稀疏性成正比,随着稀疏性的增加(非0矩阵增加),Block-sparse FlashAttention的运行时间成比例地提高。

      小结

    总的来说,FlashAttention有如下优点:

    1. hbm访问次数降低,所以计算更快

    2. 在sram中计算attention,并对于后向计算提前保留中间结果,所以显存占用更少

    3. 可以使用更长的sequence,使得模型训练效果更好

    4. 对于attention计算,加速明显。如果加上稀疏化处理,速度会更快。

    b6f52e8556555aaacefd295f3c2b6643.png

    FlashAttention-2


      主要内容

    FlashAttention 的整体速度仍然和单独进行矩阵乘法 (GEMM) 的运算速度差距较大,仅达到理论最大 FLOPs/s 的 25-40%。作者观察到效率低下的原因是不同线程块和warp之间的工作分区不理想,导致低占用率或不必要的共享内存读/写。最新提出 FlashAttention-2,通过更好的工作分区来解决这些问题,主要包含的操作:1.调整算法以减少非矩阵乘运算的计算次数。2.跨不同线程块进行并行化注意力计算。3.在每个线程块内, 在 warp 之间优化工作分配以减少共享内存的通信。

    

    与FlashAttention 相比,FlashAttention-2速度提高了约 2 倍,达到 A100 上理论最大 FLOPs/s 的 50-73%,接近 GEMM 操作的效率。根据经验验证,当使用端到端来训练 GPT 式模型时,FlashAttention-2 的训练速度高达每 A100 GPU 225 TFLOPs/s(模型 FLOPs 利用率为 72%)。不同设置(有或没有causal mask、不同头部尺寸)的测试表明,FlashAttention-2 比 FlashAttention 实现了约 2 倍的加速,在前向传递中达到理论最大吞吐量的 73%,在后向传递中达到理论最大吞吐量的 63%。

      主要操作
    • 减少非矩阵运算

    调整算法以减少非 matmul(矩阵乘法) FLOP 的数量,同时不改变输出。虽然非 matmul FLOP 只占总 FLOP 的一小部分,但它们由于 GPU 具有专门的矩阵乘法单元,非矩阵乘法的运算需要更长的时间来执行,矩阵乘法吞吐量可以比非矩阵乘法吞吐量高出16倍。因此,减少非 matmul FLOP 并尽可能多的进行 matmul FLOP 非常重要。

    背景知识:

    吞吐量是指单位时间内完成的任务数量或数据处理量。在这个上下文中,吞吐量指的是执行矩阵乘法操作时的性能表现,以及执行其他非矩阵乘法操作时的性能表现。这句话的意思是,执行矩阵乘法操作时,系统能够以每单位时间处理更多的任务或数据,其数量可以高达非矩阵乘法操作时的16倍。这表明矩阵乘法操作在性能上比其他操作更加高效。

    Forward pass:优化qkv的softmax计算中非矩阵运算

    b7cd1d26ce8869247566b8bdd521d9d7.png

    e132b00a8872db1a16725da3ec4f8961.png

    Backward pass:FlashAttention-2的后向传递与FlashAttention几乎相同,主要区别在于需要进行梯度计算与更新。这里做了一个小调整,只使用求和结果𝐿,而不是 softmax 中的行式最大值和行式指数和。

    dc713ba4b2dea91bf791eb1b4c6a9f0a.png

    • 增加并行比例

    除了batchsize维度和head数目维度,还在序列长度维度上对前向传播和反向传播进行并行化处理,提高并行性。在序列较长的情况下,提高GPU资源的占用率。FlashAttention对于batchsize和head数目进行并行化处理,FlashAttention2基于序列长度进行并行化。当批量大小和head数量较小时,序列长度上增加的并行性有助于提高占用率(正在使用的 GPU 资源的比例),从而在这种情况下实现加速。

    0a8df0ea2a2bd604544c3c5ba06acc72.png

    Forward pass:对批量维度和头数维度进行并行化,如 FlashAttention 中所做的那样。对于外循环(在序列长度上),将它们调度到不需要彼此通信的不同线程块上,每个工作线程负责关注矩阵的一行block块。外循环每次处理一行block,内循环每次处理这一行中的一列block,这和FlashAttention处理方式是不同的。

    

    Backward pass:不同列块之间唯一共享的计算是算法 2 中更新的dQ,其中我们需要将 dQ从 HBM 加载到 SRAM,然后在片上通过 dQ更新,并写回 HBM。我们使用原子添加在不同线程块之间进行通信以更新 dQ。我们也在序列长度维度上进行并行化,并为后向传递的每一列block块安排 1 个工作线程(和前向传递是反过来的)。

    • 在warp上优化工作划分

    在一个注意力计算的block内,在一个thread block的不同warp之间优化工作划分,以减少通信和共享内存的读/写。

    09a5cf39fab3172ba61853d27d7d976c.png

    在每个线程块内,我们也必须决定如何在不同的 warp 之间划分工作。我们通常每个线程块使用 4 或 8 个 warp,分区如上图所示。

    

    Forward pass:对于每个块,FlashAttention 将 K 和 V 分割到 4 个 warp 上,同时保持 Q 可被所有 warp 访问。每个warp相乘得到 QK⊤ 的slice,然后它们需要与 V 的slice相乘并进行通信以将结果相加。这称为“split-K”方案。然而,这是低效的,因为所有 warp 都需要将其中间结果写入共享内存,进行同步,然后将中间结果相加。这些共享内存读/写会减慢 FlashAttention 中的前向传播速度。在 FlashAttention-2 中,我们将 Q 分成 4 个经线,同时保持所有经线均可访问 K 和 V。在每个扭曲执行矩阵乘法以获得 QK⊤ 切片后,它们只需与共享的 V 切片相乘即可获得相应的输出切片。warp 之间不需要通信。共享内存读/写的减少可以提高速度。

    背景知识:

    warp:由多个thread组成,是编程层面的概念。

    flash1:k和v被分为4个不同的warp,q和k计算、再和v计算,每一次计算的中间结果都要写入共享内存,并在之后被读取。这样就增加了共享内存的读写次数、拖慢了速度。

    flash2:将q分为4个不同的warp,然后计算qk、计算v。但是这里k和v不需要通信,所以计算v的时候,不需要新的内存读写。这样就减少了读写次数、加快了程序。

    Backward pass:对于后向传递,我们选择对warp进行分区以避免“split-K”方案,从而减少共享内存的读/写次数,并再次提高速度。由于所有不同输入和梯度 Q、K、V、O、dO、dQ、dK、dV 之间的依赖性更加复杂,它需要一些同步操作。

      小结

    FlashAttention-2可以加速attention计算。测量FlashAttention-2 在不同序列长度上的运行时间,并与 PyTorch、FlashAttention 和 Triton 中的 FlashAttention 中的标准实现进行比较。FlashAttention-2 比 FlashAttention 快 1.7-3.0 倍,比 Triton 中的 FlashAttention 快 1.3-2.5 倍,比标准注意力实现快 3-10 倍。

    

    FlashAttention-2可以加速端到端训练。当使用端到端在 2k 或 8k 序列长度上训练大小为 1.3B 和 2.7B 的 GPT 型模型时,FlashAttention-2 与 FlashAttention 相比可实现高达 1.3 倍的加速,与基线相比可实现 2.8 倍的加速 没有FlashAttention。每个 A100 GPU 的 FlashAttention-2 速度高达 225 TFLOPs/s(模型 FLOPs 利用率为 72%)。

    

    FlashAttention-2 比 FlashAttention 快 2 倍,可以用之前训练 8k 上下文模型的时间,来训练具有 16k 更长上下文的模型。使用更长的context训练模型,可以更好理解长篇书籍和报告、高分辨率图像、音频和视频。

    193190dc4049fb2d2a3ef4c3a1a0b338.jpeg

    FlashAttention-2加速实践

      时间与显存的优化效果

    对于qkv计算,比较FlashAttention2与custom pytorch、xformers(FlashAttention1)的时间与显存消耗。如果只考虑QKV计算,flash attention2耗时是xformers(flash attention1)的一半,内存节省也更多一些。

    1. flash attention2耗时是xformers(flash attention1)的一半,内存节省也更多一些
    2. test 0 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    3. shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
    4. custom pytorch time: 0.000754, peak memory: 113 MB
    5. flash attention time: 0.000103, speedup: 7.29; peak memory: 45 MB, save: 60%
    6. xformers time: 0.000255, speedup: 2.95; peak memory: 63 MB, save: 44%
    7. test 1 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    8. shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
    9. custom pytorch time: 0.000703, peak memory: 131 MB
    10. flash attention time: 0.000106, speedup: 6.63; peak memory: 57 MB, save: 56%
    11. xformers time: 0.000252, speedup: 2.80; peak memory: 70 MB, save: 46%
    12. test 2 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    13. shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
    14. custom pytorch time: 0.000721, peak memory: 131 MB
    15. flash attention time: 0.000106, speedup: 6.78; peak memory: 57 MB, save: 56%
    16. xformers time: 0.000263, speedup: 2.74; peak memory: 70 MB, save: 46%
    17. test 3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    18. shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
    19. custom pytorch time: 0.000704, peak memory: 131 MB
    20. flash attention time: 0.000105, speedup: 6.71; peak memory: 57 MB, save: 56%
    21. xformers time: 0.000249, speedup: 2.82; peak memory: 70 MB, save: 46%
    22. test 4 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    23. shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
    24. custom pytorch time: 0.000700, peak memory: 131 MB
    25. flash attention time: 0.000110, speedup: 6.35; peak memory: 57 MB, save: 56%
    26. xformers time: 0.000254, speedup: 2.75; peak memory: 70 MB, save: 46%
    27. test 5 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    28. shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
    29. custom pytorch time: 0.000766, peak memory: 131 MB
    30. flash attention time: 0.000106, speedup: 7.25; peak memory: 57 MB, save: 56%
    31. xformers time: 0.000252, speedup: 3.04; peak memory: 70 MB, save: 46%
    32. test 6 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    33. shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
    34. custom pytorch time: 0.000684, peak memory: 131 MB
    35. flash attention time: 0.000101, speedup: 6.77; peak memory: 57 MB, save: 56%
    36. xformers time: 0.000268, speedup: 2.56; peak memory: 70 MB, save: 46%
    37. test 7 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    38. shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
    39. custom pytorch time: 0.000717, peak memory: 131 MB
    40. flash attention time: 0.000110, speedup: 6.52; peak memory: 57 MB, save: 56%
    41. xformers time: 0.000254, speedup: 2.82; peak memory: 70 MB, save: 46%
    42. test 8 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    43. shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
    44. custom pytorch time: 0.000700, peak memory: 131 MB
    45. flash attention time: 0.000100, speedup: 6.98; peak memory: 57 MB, save: 56%
    46. xformers time: 0.000253, speedup: 2.77; peak memory: 70 MB, save: 46%
    47. test 8 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    48. shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
    49. custom pytorch time: 0.000700, peak memory: 131 MB
    50. flash attention time: 0.000100, speedup: 6.98; peak memory: 57 MB, save: 56%
    51. xformers time: 0.000253, speedup: 2.77; peak memory: 70 MB, save: 46%
    52. test 9 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
    53. shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
    54. custom pytorch time: 0.000721, peak memory: 131 MB
    55. flash attention time: 0.000102, speedup: 7.10; peak memory: 57 MB, save: 56%
    56. xformers time: 0.000251, speedup: 2.87; peak memory: 70 MB, save: 46%
      精度损失比较

    计算FlashAttention2对于注意力机制的精度损失,与pytorch的计算精度进行对比。

    

    绝大部分用例都可以通过测试,并且符合要求:

    1. dQ Pytorch mean diff: 0.000698089599609375
    2. dK Pytorch mean diff: 0.0005950927734375
    3. dV Pytorch mean diff: 0.000537872314453125
    4. .Actual dropout fraction: 0.17163611948490143
    5. Output max diff: 0.001953125
    6. Output mean diff: 2.9206275939941406e-05
    7. Pytorch max diff: 0.0029296875
    8. Pytorch mean diff: 8.106231689453125e-05
    9. Attention max diff: 0.000244140625
    10. Attention Pytorch max diff: 0.000732421875
    11. dQ max diff: 0.0025577545166015625
    12. dK max diff: 0.00390625
    13. dV max diff: 0.0078125
    14. dQ mean diff: 3.904104232788086e-05
    15. dK mean diff: 0.0001360177993774414
    16. dV mean diff: 0.0001475811004638672
    17. dQ Pytorch max diff: 0.00390625
    18. dK Pytorch max diff: 0.004150390625
    19. dV Pytorch max diff: 0.0078125
    20. dQ Pytorch mean diff: 8.702278137207031e-05
    21. dK Pytorch mean diff: 0.00025916099548339844
    22. dV Pytorch mean diff: 0.0002474784851074219
    23. .Actual dropout fraction: 0.17163611948490143
    24. Output max diff: 0.015625
    25. Output mean diff: 0.0002346038818359375
    26. Pytorch max diff: 0.015625
    27. Pytorch mean diff: 0.00064849853515625
    28. Attention max diff: 0.001953125
    29. Attention Pytorch max diff: 0.00390625
    30. dQ max diff: 0.01953125
    31. dK max diff: 0.033203125
    32. dV max diff: 0.0625
    33. dQ mean diff: 0.0003108978271484375
    34. dK mean diff: 0.00109100341796875
    35. dV mean diff: 0.0011749267578125
    36. dQ Pytorch max diff: 0.01806640625
    37. dK Pytorch max diff: 0.0390625
    38. dV Pytorch max diff: 0.0625
    39. dQ Pytorch mean diff: 0.00069427490234375
    40. dK Pytorch mean diff: 0.0020751953125
    41. dV Pytorch mean diff: 0.001953125
    42. ...

    少数用例不符合要求,有一定的精度损失:

    1. FAILED tests/test_flash_attn.py::test_flash_attn_race_condition[0.0-128-128-False-dtype0] - assert False
    2. FAILED tests/test_flash_attn.py::test_flash_attn_race_condition[0.0-128-128-True-dtype0] - assert False
    3. FAILED tests/test_flash_attn.py::test_flash_attn_bwd_transpose[128-128-False-dtype0] - AssertionError: assert 236.75 <= (2 * 0.0009765625)
    4. FAILED tests/test_flash_attn.py::test_flash_attn_bwd_transpose[128-128-False-dtype1] - AssertionError: assert 22144.0 <= (2 * 0.0078125)
    5. FAILED tests/test_flash_attn.py::test_flash_attn_bwd_transpose[128-128-True-dtype0] - AssertionError: assert 2.724609375 <= (2 * 0.001953125)
    6. FAILED tests/test_flash_attn.py::test_flash_attn_bwd_transpose[128-128-True-dtype1] - AssertionError: assert 95.5 <= (2 * 0.015625)

    FlashAttention2与参考方法的输出和梯度相比,误差很小并在可控范围内。对于不同的head dimensions, input dtype, sequence length, causal / non-causal,FlashAttention2的最大数值误差最多是 Pytorch的baseline中的数值误差的两倍。

    

    对于前向计算和后向计算。前向计算,是确定性的,每次测试结果可以复现。后向计算,非确定性(没有bit级别的确定性),每次结果可能有略微的不一样(比如输入不变,seed改变)。如果只做推理,则只涉及前向计算,所以计算是确定性的。

    6b6afe21142b34eb123f0c6aad296d15.jpeg

    FlashAttention-2加速AIGC


      环境信息

    NVIDIA A10, CUDA Version: 11.4, webui-1.5.1, eas推理平台

      加速效果

    xformers(flash1):


    文生图(512*512)(batchsize=1)文生图(512*512)(batchsize=4)
    unet耗时(s)

    1

    1

    1

    1

    4

    4

    4

    4

    unet耗时(it/s)

    (step = 20)

    11.11it/s

    11.27it/s

    11.27it/s

    11.27it/s

    4.33it/s

    4.33it/s

    4.33it/s

    4.33it/s

    xformers(flash2):


    文生图(512*512)文生图(512*512)(batchsize=4)
    unet耗时(s)

    1

    1

    1

    1

    4

    4

    4

    4

    unet耗时(it/s)

    (step = 20)

    11.13it/s

    11.75it/s

    11.46it/s

    11.92it/s

    4.69it/s

    4.69it/s

    4.69it/s

    4.68it/s

    相对于xformers(flash1),xformers(flash2)提速:


    unet过程提速
    文生图加速(一次生成1图)(11.57-11.23)/11.23=3%
    文生图加速(一次生成4图)(4.69-4.33)/4.33=8.3%
      精度比较

    xformers(flash1)

    文生图(512*512)_ouput1文生图(512*512)_ouput2
    023125b357e171a9c8c06ff00351eb71.pngcae4eb84711205b0bf1f00e205856712.png

    xformers(flash2)

    文生图(512*512)_ouput1文生图(512*512)_ouput2
    a9111547e95cb1139b99512b31028e85.png6fdbbeb02f8ce8655b4a9f4d2d61615c.png

    使用不同的加速方法,AIGC生成图像,均符合预期,无精度损失。

    注:这里未固定seed,所以图像会有变化,但是生成效果符合预期。

      AIGC加速分析

    使用flash_attention2,对sd加速,相比flash_attention1,加速比例并不高,或者说无法达到论文中那么高的加速比例。

    • SD模型自身特点

    flash_attention2主要是针对qkv计算进行加速,sd的推理过程中还有很多别的计算。推理过程中,进行采样(去噪),具有大量的计算,qkv计算只是推理计算的一部分。对于大图,计算量也更大,qkv的计算比例也更大,所以可以得到更多的加速效果。

    

    SD模型的网络结构:

    6f520c9c5655bbda037a0664afb8ef82.png

    • SD社区代码特点

    目前方法对于sd的提速,主要针对stable diffusion的神经网络本身。但是webui还有别的耗时:1.webui生图以外,还需要进行很多后处理,比如序列化反序列化、图片后处理、转换格式、传递图片等。2.webui是一个社区项目,兼容的功能非常多、而且杂,里面各种判断逻辑。这些操作拖慢了速度,端到端速度表现一般。

    • 显卡性能特殊性

    加速比例,对于不同的GPU效果不一样。3090这张卡比较特殊,计算性能好,但是显存的读取速度很差。所以在batchsize小的时候,性能卡在显存读取速度上,加速比高不起来。如果想看到更高的加速比,可以试试加大batchsize,这样diffusion占用的时间变多,网络和反序列化消耗时间的占比变小。如果用A系列卡效果会好一些。A10的性能比3090差,和他的显存读取速度匹配,加速方法对算法的优化比较符合卡的特点。

    cdddd733c98074900c799022f137a10d.jpeg

    FlashAttention-2与fastunet对于AIGC联合加速

    为了进一步优化aigc生图效率,使用webui更快速地进行加速,我们针对diffusion model特点,通过fastunet与FlashAttention-2结合的方式进行加速,并取得了相对于flash1已有加速效果的大于40%的提速。

      实验环境

    NVIDIA A10, CUDA Version: 11.4, webui-1.5.1, eas推理平台

      加速效果

    xformers(flash2)+fastunet


    文生图(512*512)文生图(512*512)(batchsize=4)
    unet耗时(s)

    1

    1

    1

    1

    3

    3

    3

    3

    unet耗时(it/s)

    (step = 20)

    17.06it/s

    18.22it/s

    17.36it/s

    16.43it/s

    6.26it/s

    6.27it/s

    6.25it/s

    6.25it/s

    相对于xformers(flash1),xformers(flash2)+fastunet提速:


    unet过程提速
    文生图加速(一次生成1图)(17.26-11.23)/11.23=54%
    文生图加速(一次生成4图)(6.26-4.33)/4.33=45%

    加速效果:flash attention2 + fastunet > flash attention2 > flash attention1

      精度比较

    使用xformers(flash2)+fastunet加速方法,AIGC生成图像,结果符合预期,无精度损失。

    文生图(512*512)_ouput1文生图(512*512)_ouput2
    8eb31c5bae113861aabefa53cbdda4de.png29c5bf7ef7d4663cf89db1b5b9e7d463.png

      AIGC加速分析

    生图过程主要有两部分耗时:controlnet与unet

    旧方法:xformers 0.0.20,使用flash attention1加速sd(unet+controlnet)

    新方法:1.当前的fastunet只加速unet里的attention(换为flash attention2)。2.xformers0.0.21加速包括controlnet在内的所有attention(换为flash attention2)。3.fastunett还对其他算子也做了一些fuse操作,也起到了加速效果。

    

    fastunet和xformers0.0.21加速的底层逻辑,都是使用flash attention2优化attention。fastunet和xformers0.0.21叠加使用,可以最大程度起到加速效果。新的加速方法主要针对attention计算进行优化,所以在unet及其attention部分会有更高比例的加速。

    51adb54a9e46443d2271ef394a5fdc66.jpeg

    总结与讨论

    近年来,让 Transformers 能够处理更长的序列长度一直备受关注。这一发展有助于提升语言建模和高分辨率图像理解的能力,并为音频和视频生成等新的应用场景带来了机遇。FlashAttention方法使得使用更长的序列计算注意力成为可能,并通过线性级别的增长来节省内存并加速计算。这一方法为处理长序列的Transformer模型提供了一种有效的解决方案。最新提出的FlashAttention-2,也进一步实现了attention计算速度的大幅度提升。

    

    当我们一直在关注GPU显存大小以及计算能力的时候,FlashAttention关注了GPU显存以外的SRAM,从而优化attention计算。也为我们解决问题提供了思考,即在主流关注的技术点以外,还有一些被忽视的但依旧可以解决问题的思路。面对实际效果与理论效果的差距,FlashAttention-2则进一步找到gap原因,通过关注矩阵运算、序列并行、工作分区等问题,优化计算效果。这也提醒我们,对于性能问题的解决,从软硬件结合的角度出发,才能更充分的解决问题。

    

    在AIGC领域的生图任务中,使用diffusion model进行相关计算,需要大量时间完成生图过程。所以,通过FlashAttention-2等多种加速方法进一步提升AIGC的生图效率,具有深刻意义。我们团队致力于家装行业AIGC进行相关研发,以提高家装AI模型的效果。我们希望与对此方向感兴趣的同学一起探讨和交流。

    a96db5f3d75c355297a398a1f1ab69b6.jpeg

    团队介绍

    我们是淘天集团-场景智能技术团队,一支专注于通过AI和3D技术驱动商业创新的技术团队, 依托大淘宝丰富的业务形态和海量的用户、数据, 致力于为消费者提供创新的场景化导购体验, 为商家提供高效的场景化内容创作工具, 为淘宝打造围绕家的场景的第一消费入口。我们不断探索并实践新的技术, 通过持续的技术创新和突破,创新用户导购体验, 提升商家内容生产力, 让用户享受更好的消费体验, 让商家更高效、低成本地经营。

    ¤ 拓展阅读 ¤

    3DXR技术 | 终端技术 | 音视频技术

    服务端技术 | 技术质量 | 数据算法

  • 相关阅读:
    SAP: 建立HTTPS 连接时,报错 ICM_HTTP_SSL_PEER_CERT_UNTRUSTED
    代码Bug太多?给新人Code Review头都大了?快来试试SpotBugs
    Embedding技术与应用 (2) :神经网络的发展及现代Embedding方法简介
    ML 线性回归原理推导以及灵魂拷问 (面试必考知识点)
    在线录音工具分享,总有一款适合你!
    【ASE入门学习】ASE入门系列二十四——轴向溶解
    net 获取本地ip地址,net mvc + net core 两种
    sql注入学习-知识点大合集
    [LMKD] [Android] 进程OomAdj调整分析:Empty被Kill流程(4)
    golang数据库连接池参数设置
  • 原文地址:https://blog.csdn.net/Taobaojishu/article/details/133366239