• vision transformer的计算复杂度


    Vision transformer

    在这里插入图片描述

    假设每个图像有 h ∗ w h*w hw 个patch,维度是 C C C

    输入的图像 X X X ( 大小为 h w ∗ C hw* C hwC ),和三个系数矩阵相乘 ( 大小为 C ∗ C C*C CC ),得到 q k v qkv qkv 三个向量 ( h w ∗ C hw*C hwC ),复杂度为:
    3 h w C 2 3hwC^2 3hwC2

    q q q ( h w ∗ C hw*C hwC ) 和 k T k^T kT ( C ∗ h w C*hw Chw ) 相乘得到矩阵 A A A ( h w ∗ h w hw*hw hwhw ),复杂度为: ( h w ) 2 C (hw)^2C (hw)2C

    A A A ( h w ∗ h w hw*hw hwhw ) 和 v v v ( h w ∗ C hw*C hwC )相乘,得到多头注意力的结果 ( h w ∗ C hw*C hwC ),复杂度为: ( h w ) 2 C (hw)^2C (hw)2C

    经过MLP投影层 ( C ∗ C C*C CC ),得到 ( h w ∗ C hw*C hwC ),复杂度为:
    h w C 2 hwC^2 hwC2

    所以复杂度之和为: 4 h w C 2 + 2 ( h w ) 2 C 4hwC^2 + 2(hw)^2C 4hwC2+2(hw)2C

    Swin transformer

    在这里插入图片描述
    基于滑动窗口的多头注意力,是在每个窗口内计算注意力

    假设每个窗口有 M × M M×M M×M 个patch

    在一个窗口内的复杂度为:

    4 M 2 C 2 + 2 M 4 C 4M^2C^2+2M^4C 4M2C2+2M4C

    共有 h w / M 2 hw /M^2 hw/M2 个窗口,所以复杂度之和为:

    4 h w C 2 + 2 M 2 h w C 4hwC^2+2M^2hwC 4hwC2+2M2hwC

    Convolutional vision Transformer

    使用 s × s s×s s×s 卷积进行卷积投影,有 h w hw hw 个patch,通道维度为 C C C

    输入的图像 X X X ( 大小为 h w ∗ C hw* C hwC ),使用三个标准卷积进行投影 ( 大小为 s ∗ s ∗ C s*s*C ssC ),得到 q k v qkv qkv 三个向量 ( h w ∗ C hw*C hwC ),投影的复杂度为:

    3 h w s 2 C 2 3hws^2C^2 3hws2C2

    使用深度可分离卷积,投影的复杂度为:

    3 h w s 2 C 3hws^2C 3hws2C

    使用步长大于1的卷积进行多头注意力的投影,减小后面注意力的计算花销。

    key和value的步长为2,query的步长为1,key和value的token数量减小了4倍,所以后续的多头注意力计算花销也减小了4倍。

    Cross Attention Transformer

    在这里插入图片描述

    交叉注意力包括IPSA和CPSA,IPSA在单个patch内使用卷积进行投影,CPSA在单个通道计算patch间的注意力

    IPSA的复杂度:

    patch大小为 N N N,通道数为 C C C

    输入的图像 X X X ( 大小为 N 2 ∗ C N^2* C N2C ),使用卷积进行投影 ( 大小为 1 ∗ 1 ∗ C 1*1*C 11C ),得到 q k v qkv qkv 三个向量 ( N 2 ∗ C N^2*C N2C ),复杂度为:
    3 N 2 C 2 3N^2C^2 3N2C2

    q q q ( N 2 ∗ C N^2*C N2C ) 和 k k k ( C ∗ N 2 C*N^2 CN2 ) 相乘得到矩阵 A A A ( N 2 ∗ N 2 N^2*N^2 N2N2 ),复杂度为: N 4 C 2 N^4C^2 N4C2

    A A A ( N 2 ∗ N 2 N^2*N^2 N2N2 ) 和 v v v ( N 2 ∗ C N^2*C N2C )相乘,得到多头注意力的结果 ( N 2 ∗ C N^2*C N2C ),复杂度为: N 4 C 2 N^4C^2 N4C2

    经过MLP投影层 ( C ∗ C C*C CC ),得到 ( N 2 ∗ C N^2*C N2C ),复杂度为:
    N 2 C 2 N^2C^2 N2C2

    单个patch内的复杂度为:

    4 N 2 C 2 + 2 N 4 C 2 4N^2C^2+2N^4C^2 4N2C2+2N4C2

    共有 H W / N 2 HW/N^2 HW/N2 个patch,所以IPSA总复杂度为:
    4 H W C 2 + 2 N 2 H W C 2 4HWC^2+2N^2HWC^2 4HWC2+2N2HWC2

    CPSA的复杂度:

    patch数目为 H W / N 2 HW/N^2 HW/N2,patch大小为 N 2 N^2 N2

    输入的图像 X X X ( 大小为 H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2N2 ),和三个系数矩阵相乘 ( 大小为 N 2 ∗ N 2 N^2*N^2 N2N2 ),得到 q k v qkv qkv 三个向量 ( H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2N2 ),复杂度为:
    3 H W N 2 3HWN^2 3HWN2

    q q q ( H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2N2 ) 和 k k k ( N 2 ∗ H W / N 2 N^2*HW/N^2 N2HW/N2 ) 相乘得到矩阵 A A A ( H W / N 2 ∗ H W / N 2 HW/N^2*HW/N^2 HW/N2HW/N2 ),复杂度为: ( H W ) 2 / N 2 (HW)^2/N^2 (HW)2/N2

    A A A ( H W / N 2 ∗ H W / N 2 HW/N^2*HW/N^2 HW/N2HW/N2 ) 和 v v v ( H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2N2 )相乘,得到多头注意力的结果 ( H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2N2 ),复杂度为: ( H W ) 2 / N 2 (HW)^2/N^2 (HW)2/N2

    经过MLP投影层 ( N 2 ∗ N 2 N^2*N^2 N2N2 ),得到 ( H W / N 2 ∗ N 2 HW/N^2*N^2 HW/N2N2 ),复杂度为:
    H W N 2 HWN^2 HWN2

    单个通道内的复杂度为:

    4 N 2 H W + 2 ( H W / N ) 2 4N^2HW+2(HW/N)^2 4N2HW+2(HW/N)2

    共有 C C C 个通道,所以CPSA总复杂度为:
    4 N 2 H W C + 2 ( H W / N ) 2 C 4N^2HWC+2(HW/N)^2C 4N2HWC+2(HW/N)2C

  • 相关阅读:
    Hbase regionserver频繁突然挂掉的问题处理
    【Python】Python调试器pdb
    深度学习基础网络整理----AlexNet
    create® 3入门教程-设置NTP
    【模糊神经网络】基于模糊神经网络的移动机器人路径规划
    入手不亏,4款简单易用的典藏软件,真正的电脑利器
    windows查看电脑配置
    【序列化与反序列化】关于序列化与反序列化MessagePack的实践
    【Mysql】清理binlog日志的方法
    华为云项目部署
  • 原文地址:https://blog.csdn.net/weixin_43772166/article/details/130915428