• Gather-Excite Attention


    Paper: Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks

    Introduction

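    This is another attention-mechanism paper by the authors of SENet. Its idea is close to SENet, BAM and CBAM. You do not really need to read the paper: the code and the architecture diagram below are enough to see how it is implemented.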

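    The paper's angle is context exploitation, which in practice is attention along the spatial dimensions. The difference from BAM and CBAM lies in the output of the spatial branch. BAM (and likewise CBAM) uses ordinary convolutions there and outputs a single-channel feature map. When that map is multiplied element-wise with the original feature map, every pixel therefore gets the same weight in all channels. GENet uses depthwise convolutions instead, so its output has as many channels as the input, and the same spatial position can receive a different weight in each channel.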
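
    A small NumPy sketch of the broadcasting difference described above. The two attention maps are random placeholders (hypothetical, not produced by an actual BAM or GE module). Only their shapes matter: (N, 1, H, W) for a shared spatial map versus (N, C, H, W) for a per-channel map.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 4, 5, 5))  # feature map, shape (N, C, H, W)

# BAM/CBAM-style spatial attention: one single-channel map shared by all channels
shared_map = rng.random((1, 1, 5, 5))  # (N, 1, H, W)
y_shared = x * shared_map  # broadcasts over the channel axis

# GE-style attention: depthwise ops keep C channels, so each channel has its own map
per_channel_map = rng.random((1, 4, 5, 5))  # (N, C, H, W)
y_per_channel = x * per_channel_map

# Effective weight applied at spatial position (2, 3), one value per channel
w_shared = np.broadcast_to(shared_map, x.shape)[0, :, 2, 3]
w_per_channel = per_channel_map[0, :, 2, 3]
print(w_shared)       # the same value repeated for every channel
print(w_per_channel)  # generally a different value per channel
```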

    The paper designs several GE variants. In GE-θ⁻, the gather step is average pooling (global average pooling in the global-extent case), so it adds no trainable parameters. In GE-θ, the gather step uses depthwise convolutions. GE-θ⁺ borrows the SE idea: after the GE-θ gather, a 1x1 convolution reduces the number of channels and a second 1x1 convolution restores it.
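
    To make the excite side of these variants concrete, here is a NumPy sketch built on a global, parameter-free gather. Several assumptions apply. The gather is global average pooling in both cases for brevity; GE-θ⁺ proper gathers with the depthwise convolutions of GE-θ. The helper names are hypothetical. Random weights stand in for trained ones. A 1x1 convolution is written as a per-position channel matrix multiply.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gather_global(x):
    # Parameter-free gather: average over H and W -> shape (N, C, 1, 1)
    return x.mean(axis=(2, 3), keepdims=True)


def conv1x1(x, w, b):
    # A 1x1 convolution is a per-position matrix multiply over channels.
    # x: (N, C_in, H, W), w: (C_out, C_in), b: (C_out,)
    return np.einsum("oc,nchw->nohw", w, x) + b[None, :, None, None]


def excite_mlp(g, w1, b1, w2, b2):
    # SE-style excite: 1x1 conv down to C // r channels -> ReLU -> 1x1 conv back to C
    hidden = np.maximum(conv1x1(g, w1, b1), 0.0)
    return conv1x1(hidden, w2, b2)


rng = np.random.default_rng(0)
c, r = 32, 16
x = rng.standard_normal((2, c, 8, 8))
g = gather_global(x)

# GE-θ⁻ style: gather and excite with no learned parameters
y_minus = x * sigmoid(g)

# GE-θ⁺ style excite; random weights stand in for learned ones
w1, b1 = 0.1 * rng.standard_normal((c // r, c)), np.zeros(c // r)
w2, b2 = 0.1 * rng.standard_normal((c, c // r)), np.zeros(c)
y_plus = x * sigmoid(excite_mlp(g, w1, b1, w2, b2))
```

    The only difference between the two outputs is whether the gathered context goes through a learned bottleneck before the sigmoid gate.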

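    The extent ratio e mentioned in the paper plays a role similar to SE's reduction ratio, except that it shrinks the spatial resolution instead of the channel count. In SE, r=16: the first 1x1 convolution reduces the channels to 1/16, and the second restores them. In GE with e=8, the GE-θ⁻ gather is an average pooling with stride 8. The pooling kernel size is a free choice; the timm code below uses 2*e - 1 = 15. The gather in GE-θ⁺ (shared with GE-θ) is a stack of 3 stride-2 dwconv+bn+relu blocks; in the timm code the ReLU is omitted after the last one. For e=16, the former uses stride-16 average pooling and the latter stacks 4 stride-2 dwconv+bn+relu blocks. In general the latter stacks log2(e) blocks.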
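
    The extent arithmetic can be checked with a few lines of plain Python. The helper names are hypothetical. The pooling kernel 2*e - 1, padding e - 1 and log2(e) stride-2 convolutions follow the timm code below. The padding of 1 for each 3x3 stride-2 depthwise conv is an assumption about the default of timm's create_conv2d.

```python
import math


def ge_gather_config(extent):
    """Gather hyperparameters for a local extent ratio (hypothetical helper)."""
    if extent < 2 or extent & (extent - 1):
        raise ValueError("extent must be a power of two >= 2")
    kernel = 2 * extent - 1  # timm: gk = extent * 2 - 1
    return {
        "pool_kernel": kernel,
        "pool_stride": extent,  # timm: gs = extent
        "pool_padding": kernel // 2,  # = extent - 1
        "num_dwconv": int(math.log2(extent)),  # stride-2 depthwise convs (GE-θ / GE-θ⁺)
    }


def out_size(h, kernel, stride, padding):
    # Standard output-size formula shared by pooling and convolution
    return (h + 2 * padding - kernel) // stride + 1


def dwconv_stack_size(h, num_conv):
    # Each 3x3 depthwise conv with stride 2 and (assumed) padding 1 halves h, rounding up
    for _ in range(num_conv):
        h = out_size(h, kernel=3, stride=2, padding=1)
    return h


for e in (8, 16):
    cfg = ge_gather_config(e)
    h_pool = out_size(56, cfg["pool_kernel"], cfg["pool_stride"], cfg["pool_padding"])
    h_conv = dwconv_stack_size(56, cfg["num_dwconv"])
    print(e, cfg, h_pool, h_conv)
```

    Under these assumptions both gather paths shrink a 56x56 map to 7x7 for e = 8 and to 4x4 for e = 16. In general both give ceil(H / e), so the parameter-free and the learned gather produce maps of the same resolution.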

    Implementation

    The code below is the implementation from timm.

    """ Gather-Excite Attention Block
    Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348
    Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet
    I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another
    impl that covers all of the cases.
    NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation
    Hacked together by / Copyright 2021 Ross Wightman
    """
    import math

    from torch import nn as nn
    import torch.nn.functional as F

    # Relative imports: this file lives inside timm's layers package
    from .create_act import create_act_layer, get_act_layer
    from .create_conv2d import create_conv2d
    from .helpers import make_divisible
    from .mlp import ConvMlp


    class GatherExcite(nn.Module):
        """ Gather-Excite Attention Module
        """
        def __init__(
                self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,
                rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False,
                act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):
            super(GatherExcite, self).__init__()
            self.add_maxpool = add_maxpool
            act_layer = get_act_layer(act_layer)
            self.extent = extent
            if extra_params:
                # GE-θ / GE-θ⁺: gather with learned depthwise convolutions
                self.gather = nn.Sequential()
                if extent == 0:
                    # global extent: one depthwise conv whose kernel spans the full feature size
                    assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'
                    self.gather.add_module(
                        'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True))
                    if norm_layer:
                        self.gather.add_module(f'norm1', nn.BatchNorm2d(channels))
                else:
                    # local extent e: log2(e) 3x3 stride-2 depthwise convs, no activation after the last one
                    assert extent % 2 == 0
                    num_conv = int(math.log2(extent))
                    for i in range(num_conv):
                        self.gather.add_module(
                            f'conv{i + 1}',
                            create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True))
                        if norm_layer:
                            self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels))
                        if i != num_conv - 1:
                            self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))
            else:
                # GE-θ⁻: parameter-free average pooling (kernel 2 * e - 1, stride e); unused for global extent
                self.gather = None
                if self.extent == 0:
                    self.gk = 0
                    self.gs = 0
                else:
                    assert extent % 2 == 0
                    self.gk = self.extent * 2 - 1
                    self.gs = self.extent

            # optional SE-style 1x1-conv MLP (reduce by rd_ratio, then restore); Identity when use_mlp=False
            if not rd_channels:
                rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
            self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()
            self.gate = create_act_layer(gate_layer)

        def forward(self, x):
            size = x.shape[-2:]
            if self.gather is not None:
                x_ge = self.gather(x)
            else:
                if self.extent == 0:
                    # global extent
                    x_ge = x.mean(dim=(2, 3), keepdim=True)
                    if self.add_maxpool:
                        # experimental codepath, may remove or change
                        x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
                else:
                    x_ge = F.avg_pool2d(
                        x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)
                    if self.add_maxpool:
                        # experimental codepath, may remove or change
                        x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)
            x_ge = self.mlp(x_ge)
            # resize the gathered map back to the input resolution (nearest-neighbour by default)
            if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
                x_ge = F.interpolate(x_ge, size=size)
            # gate (sigmoid by default) and rescale the input element-wise
            return x * self.gate(x_ge)
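
    Consider the parameter-free local-extent path (extra_params=False, extent=e, use_mlp=False). There, the forward pass above reduces to three steps. First, average-pool with count_include_pad=False. Second, resize back with nearest-neighbour interpolation (the default mode of F.interpolate). Third, gate with a sigmoid. Below is a NumPy sketch of that path. The function names are hypothetical, and it is meant to illustrate the steps, not to reproduce timm bit-for-bit.

```python
import numpy as np


def avg_pool2d_excl_pad(x, k, s, p):
    """Average pooling that ignores zero padding (like count_include_pad=False)."""
    n, c, h, w = x.shape
    xp = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)))
    # 1 for real pixels, 0 for padding: counts the valid elements in each window
    mask = np.pad(np.ones((h, w)), ((p, p), (p, p)))
    oh = (h + 2 * p - k) // s + 1
    ow = (w + 2 * p - k) // s + 1
    out = np.empty((n, c, oh, ow))
    for i in range(oh):
        for j in range(ow):
            rs, cs = slice(i * s, i * s + k), slice(j * s, j * s + k)
            out[:, :, i, j] = xp[:, :, rs, cs].sum(axis=(2, 3)) / mask[rs, cs].sum()
    return out


def nearest_upsample(x, size):
    # Nearest-neighbour resize of (N, C, h, w) to (N, C, H, W)
    h, w = x.shape[-2:]
    H, W = size
    rows = (np.arange(H) * h) // H
    cols = (np.arange(W) * w) // W
    return x[:, :, rows][:, :, :, cols]


def ge_theta_minus(x, extent):
    """Parameter-free GE with a local extent: gather -> resize -> sigmoid gate."""
    k, s, p = 2 * extent - 1, extent, extent - 1  # same choices as the timm code
    g = avg_pool2d_excl_pad(x, k, s, p)
    g = nearest_upsample(g, x.shape[-2:])
    return x * (1.0 / (1.0 + np.exp(-g)))


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 14, 14))
y = ge_theta_minus(x, extent=4)
print(y.shape)  # (2, 8, 14, 14)
```

    Excluding the padded zeros from the average keeps the border cells from being biased toward zero. This matches the intent of count_include_pad=False in the timm call.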

    Architecture diagram

    Experimental results

  • Original article: https://blog.csdn.net/ooooocj/article/details/126077007