• 各种注意力机制,Attention、MLP、ReP等系列的PyTorch实现,含核心代码


    不知道CV方向的同学在读论文的时候有没有发现这样一个问题:论文的核心思想很简单,但当你找这篇论文的核心代码时发现,作者提供的源码模块会嵌入到分类、检测、分割等任务框架中,这时候如果你对某一特定框架不熟悉,尽管核心代码只有十几行,依然会发现很难找出。

    今天我就帮大家解决一部分这个问题,还记得上次分享的attention论文合集吗?没印象的同学点这里。

    这次总结了这30篇attention论文中的核心代码分享,还有一部分其他系列的论文,比如ReP、卷积级数等,核心代码与原文都整理了。

    由于篇幅和时间原因,暂时只分享了一部分,需要全部论文以及完整核心代码的同学看文末

    Attention论文

    1、Axial Attention in Multidimensional Transformers
    核心代码
    1. from model.attention.Axial_attention import AxialImageTransformer
    2. import torch
    3. if __name__ == '__main__':
    4.     input=torch.randn(312877)
    5.     model = AxialImageTransformer(
    6.         dim = 128,
    7.         depth = 12,
    8.         reversible = True
    9.     )
    10.     outputs = model(input)
    11.     print(outputs.shape)
    2、CCNet: Criss-Cross Attention for Semantic Segmentation
    核心代码
    1. from model.attention.CrissCrossAttention import CrissCrossAttention
    2. import torch
    3. if __name__ == '__main__':
    4.     input=torch.randn(36477)
    5.     model = CrissCrossAttention(64)
    6.     outputs = model(input)
    7.     print(outputs.shape)
    3、Aggregating Global Features into Local Vision Transformer
    核心代码
    1. from model.attention.MOATransformer import MOATransformer
    2. import torch
    3. if __name__ == '__main__':
    4.     input=torch.randn(1,3,224,224)
    5.     model = MOATransformer(
    6.         img_size=224,
    7.         patch_size=4,
    8.         in_chans=3,
    9.         num_classes=1000,
    10.         embed_dim=96,
    11.         depths=[226],
    12.         num_heads=[3612],
    13.         window_size=14,
    14.         mlp_ratio=4.,
    15.         qkv_bias=True,
    16.         qk_scale=None,
    17.         drop_rate=0.0,
    18.         drop_path_rate=0.1,
    19.         ape=False,
    20.         patch_norm=True,
    21.         use_checkpoint=False
    22.     )
    23.     output=model(input)
    24.     print(output.shape)
    4、CROSSFORMER: A VERSATILE VISION TRANSFORMER HINGING ON CROSS-SCALE ATTENTION
    核心代码
    1. from model.attention.Crossformer import CrossFormer
    2. import torch
    3. if __name__ == '__main__':
    4.     input=torch.randn(1,3,224,224)
    5.     model = CrossFormer(img_size=224,
    6.         patch_size=[481632],
    7.         in_chans= 3,
    8.         num_classes=1000,
    9.         embed_dim=48,
    10.         depths=[2262],
    11.         num_heads=[361224],
    12.         group_size=[7777],
    13.         mlp_ratio=4.,
    14.         qkv_bias=True,
    15.         qk_scale=None,
    16.         drop_rate=0.0,
    17.         drop_path_rate=0.1,
    18.         ape=False,
    19.         patch_norm=True,
    20.         use_checkpoint=False,
    21.         merge_size=[[24], [2,4], [24]]
    22.     )
    23.     output=model(input)
    24.     print(output.shape)
    5、Vision Transformer with Deformable Attention
    核心代码
    1. from model.attention.DAT import DAT
    2. import torch
    3. if __name__ == '__main__':
    4.     input=torch.randn(1,3,224,224)
    5.     model = DAT(
    6.         img_size=224,
    7.         patch_size=4,
    8.         num_classes=1000,
    9.         expansion=4,
    10.         dim_stem=96,
    11.         dims=[96192384768],
    12.         depths=[2262],
    13.         stage_spec=[['L''S'], ['L''S'], ['L''D''L''D''L''D'], ['L''D']],
    14.         heads=[361224],
    15.         window_sizes=[7777] ,
    16.         groups=[-1, -136],
    17.         use_pes=[FalseFalseTrueTrue],
    18.         dwc_pes=[FalseFalseFalseFalse],
    19.         strides=[-1, -111],
    20.         sr_ratios=[-1, -1, -1, -1],
    21.         offset_range_factor=[-1, -122],
    22.         no_offs=[FalseFalseFalseFalse],
    23.         fixed_pes=[FalseFalseFalseFalse],
    24.         use_dwc_mlps=[FalseFalseFalseFalse],
    25.         use_conv_patches=False,
    26.         drop_rate=0.0,
    27.         attn_drop_rate=0.0,
    28.         drop_path_rate=0.2,
    29.     )
    30.     output=model(input)
    31.     print(output[0].shape)
    6、Separable Self-attention for Mobile Vision Transformers
    核心代码
    1. from model.attention.MobileViTv2Attention import MobileViTv2Attention
    2. import torch
    3. from torch import nn
    4. from torch.nn import functional as F
    5. if __name__ == '__main__':
    6.     input=torch.randn(50,49,512)
    7.     sa = MobileViTv2Attention(d_model=512)
    8.     output=sa(input)
    9.     print(output.shape)
    7、On the Integration of Self-Attention and Convolution
    核心代码
    1. from model.attention.ACmix import ACmix
    2. import torch
    3. if __name__ == '__main__':
    4.     input=torch.randn(50,256,7,7)
    5.     acmix = ACmix(in_planes=256, out_planes=256)
    6.     output=acmix(input)
    7.     print(output.shape)
    8、Non-deep Networks
    核心代码
    1. from model.attention.ParNetAttention import *
    2. import torch
    3. from torch import nn
    4. from torch.nn import functional as F
    5. if __name__ == '__main__':
    6.     input=torch.randn(50,512,7,7)
    7.     pna = ParNetAttention(channel=512)
    8.     output=pna(input)
    9.     print(output.shape) #50,512,7,7
    9、UFO-ViT: High Performance Linear Vision Transformer without Softmax
    核心代码
    1. from model.attention.UFOAttention import *
    2. import torch
    3. from torch import nn
    4. from torch.nn import functional as F
    5. if __name__ == '__main__':
    6.     input=torch.randn(50,49,512)
    7.     ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
    8.     output=ufo(input,input,input)
    9.     print(output.shape) #[5049512]
    10、Coordinate Attention for Efficient Mobile Network Design
    核心代码
    1. from model.attention.CoordAttention import CoordAtt
    2. import torch
    3. from torch import nn
    4. from torch.nn import functional as F
    5. inp=torch.rand([2965656])
    6. inp_dim, oup_dim = 9696
    7. reduction=32
    8. coord_attention = CoordAtt(inp_dim, oup_dim, reduction=reduction)
    9. output=coord_attention(inp)
    10. print(output.shape)

    ReP论文

    1、RepVGG: Making VGG-style ConvNets Great Again
    核心代码
    1. from model.rep.repvgg import RepBlock
    2. import torch
    3. input=torch.randn(50,512,49,49)
    4. repblock=RepBlock(512,512)
    5. repblock.eval()
    6. out=repblock(input)
    7. repblock._switch_to_deploy()
    8. out2=repblock(input)
    9. print('difference between vgg and repvgg')
    10. print(((out2-out)**2).sum())
    2、ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks
    核心代码
    1. from model.rep.acnet import ACNet
    2. import torch
    3. from torch import nn
    4. input=torch.randn(50,512,49,49)
    5. acnet=ACNet(512,512)
    6. acnet.eval()
    7. out=acnet(input)
    8. acnet._switch_to_deploy()
    9. out2=acnet(input)
    10. print('difference:')
    11. print(((out2-out)**2).sum())

    卷积级数论文

    1、CondConv: Conditionally Parameterized Convolutions for Efficient Inference
    核心代码
    1. from model.conv.CondConv import *
    2. import torch
    3. from torch import nn
    4. from torch.nn import functional as F
    5. if __name__ == '__main__':
    6.     input=torch.randn(2,32,64,64)
    7.     m=CondConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)
    8.     out=m(input)
    9.     print(out.shape)
    2、Dynamic Convolution: Attention over Convolution Kernels
    核心代码
    1. from model.conv.DynamicConv import *
    2. import torch
    3. from torch import nn
    4. from torch.nn import functional as F
    5. if __name__ == '__main__':
    6.     input=torch.randn(2,32,64,64)
    7.     m=DynamicConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)
    8.     out=m(input)
    9.     print(out.shape) # 2,32,64,64
    3、Involution: Inverting the Inherence of Convolution for Visual Recognition
    核心代码
    1. from model.conv.Involution import Involution
    2. import torch
    3. from torch import nn
    4. from torch.nn import functional as F
    5. input=torch.randn(1,4,64,64)
    6. involution=Involution(kernel_size=3,in_channel=4,stride=2)
    7. out=involution(input)
    8. print(out.shape)

    关注下方《学姐带你玩AI》🚀🚀🚀

    回复“核心代码”获取全部论文+代码合集

    码字不易,欢迎大家点赞评论收藏!

  • 相关阅读:
    9、Neural Sparse Voxel Fields
    Mybatis-plus 自动生成代码
    卷积神经网络相比循环神经网络具有哪些特征
    【微服务部署】04-ForwardedHeaders
    Linux学习-71-GRUB手动安装方法
    sqlmap中文文档
    动态规划问题(六)
    接口测试(jmeter和postman 接口使用)
    若依(ruoyi)之thymeleaf与jsp共存解决方案
    [入门到吐槽系列] Webix 10分钟入门 一 管理后台制作
  • 原文地址:https://blog.csdn.net/weixin_42645636/article/details/133277420