• PyTorch implementation of the dynamic convolution module ODConv


    The ODConv dynamic convolution module

    ODConv can be viewed as a continuation of CondConv: it extends CondConv's single dynamic dimension to also cover the spatial, input-channel, and output-channel dimensions, hence the name omni-dimensional dynamic convolution. Using a parallel strategy, ODConv applies a multi-dimensional attention mechanism to learn complementary attentions along four dimensions of the kernel space. As a plug-and-play operation, it can easily be embedded into existing CNNs. Experiments on ImageNet classification and COCO detection confirm its effectiveness: it improves the performance of both large and lightweight models. Notably, thanks to its improved feature-extraction ability, ODConv with a single kernel can match or even outperform existing multi-kernel dynamic convolutions.
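
    Concretely, following the formulation in the paper, the output is an ordinary convolution whose weight is a sum of n candidate kernels, each modulated by the four attentions:

    y = (α_w1 ⊙ α_f1 ⊙ α_c1 ⊙ α_s1 ⊙ W_1 + … + α_wn ⊙ α_fn ⊙ α_cn ⊙ α_sn ⊙ W_n) * x

    Here α_si, α_ci, α_fi and α_wi are the spatial, input-channel, output-channel (filter) and kernel-wise attentions for kernel W_i, ⊙ broadcasts each attention over the matching kernel dimension, and * denotes convolution. In the code below these correspond to spatial_attention, channel_attention, filter_attention and kernel_attention.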

    Paper: Omni-Dimensional Dynamic Convolution (ICLR 2022)

    (Figure: ODConv structure diagram)
    Code implementation:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    # Conv (Conv2d + BN + activation) and autopad ("same" padding) are YOLOv5 helpers from models/common.py
    from models.common import Conv, autopad
    
    class Attention(nn.Module):
        # Produces the four complementary ODConv attentions (input-channel, output-filter,
        # spatial, kernel-wise) from a globally pooled descriptor of the input feature map.
        def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16):
            super(Attention, self).__init__()
            attention_channel = max(int(in_planes * reduction), min_channel)
            self.kernel_size = kernel_size
            self.kernel_num = kernel_num
            self.temperature = 1.0  # divides the attention logits; adjusted via update_temperature()

            self.avgpool = nn.AdaptiveAvgPool2d(1)  # global average pooling: (B, C, H, W) -> (B, C, 1, 1)
            self.fc = Conv(in_planes, attention_channel, act=nn.ReLU(inplace=True))  # shared squeeze layer
    
            self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True)  # input-channel attention head
            self.func_channel = self.get_channel_attention
    
            if in_planes == groups and in_planes == out_planes:  # depth-wise convolution
                self.func_filter = self.skip
            else:
                self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True)
                self.func_filter = self.get_filter_attention
    
            if kernel_size == 1:  # point-wise convolution
                self.func_spatial = self.skip
            else:
                self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True)
                self.func_spatial = self.get_spatial_attention
    
            if kernel_num == 1:  # a single candidate kernel needs no kernel-wise attention
                self.func_kernel = self.skip
            else:
                self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True)
                self.func_kernel = self.get_kernel_attention
    
            self._initialize_weights()
    
        def _initialize_weights(self):
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                if isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
    
        def update_temperature(self, temperature):
            self.temperature = temperature
    
        @staticmethod
        def skip(_):
            return 1.0  # identity attention for disabled branches
    
        def get_channel_attention(self, x):
            channel_attention = torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
            return channel_attention
    
        def get_filter_attention(self, x):
            filter_attention = torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
            return filter_attention
    
        def get_spatial_attention(self, x):
            spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)
            spatial_attention = torch.sigmoid(spatial_attention / self.temperature)
            return spatial_attention
    
        def get_kernel_attention(self, x):
            kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)
            kernel_attention = F.softmax(kernel_attention / self.temperature, dim=1)
            return kernel_attention
    
        def forward(self, x):
            x = self.avgpool(x)  # (B, C, H, W) -> (B, C, 1, 1)
            x = self.fc(x)       # shared squeeze projection
            return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x)
    
    
    class ODConv2d(nn.Module):
        # Drop-in replacement for YOLOv5's Conv (same k/s/p/g/d/act argument convention):
        # dynamic weight aggregation followed by BatchNorm and activation.
        def __init__(self, in_planes, out_planes, k, s=1, p=None, g=1, act=True, d=1,
                     reduction=0.0625, kernel_num=1):
            super(ODConv2d, self).__init__()
            self.in_planes = in_planes
            self.out_planes = out_planes
            self.kernel_size = k
            self.stride = s
            self.padding = autopad(k, p)
            self.dilation = d
            self.groups = g
            self.kernel_num = kernel_num
            self.attention = Attention(in_planes, out_planes, k, groups=g,
                                       reduction=reduction, kernel_num=kernel_num)
            self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//g, k, k),
                                       requires_grad=True)
            self._initialize_weights()
            self.bn = nn.BatchNorm2d(out_planes)
            self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
    
            if self.kernel_size == 1 and self.kernel_num == 1:
                # point-wise conv with a single kernel: spatial/kernel attentions are identity,
                # so the cheaper path without per-sample weight aggregation suffices
                self._forward_impl = self._forward_impl_pw1x
            else:
                self._forward_impl = self._forward_impl_common
    
        def _initialize_weights(self):
            for i in range(self.kernel_num):
                nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')
    
        def update_temperature(self, temperature):
            self.attention.update_temperature(temperature)
    
        def _forward_impl_common(self, x):
            # Multiplying the channel (or filter) attention into the weights is mathematically
            # equivalent to multiplying it into the feature maps; the authors observe that the
            # latter runs faster and uses less GPU memory, so that is what is done here.
            channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
            batch_size, in_planes, height, width = x.size()
            x = x * channel_attention
            x = x.reshape(1, -1, height, width)  # fold the batch into the channel dim for one big grouped conv
            aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)
            aggregate_weight = torch.sum(aggregate_weight, dim=1).view(
                [-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])  # (B*out_planes, in/g, k, k)
            output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
                              dilation=self.dilation, groups=self.groups * batch_size)  # one weight group per sample
            output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
            output = output * filter_attention
            return output
    
        def _forward_impl_pw1x(self, x):
            # Fast path for 1x1 kernels with kernel_num == 1: the spatial and kernel-wise
            # attentions are the identity (skip returns 1.0), so the static weight is used directly.
            channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
            x = x * channel_attention
            output = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding,
                              dilation=self.dilation, groups=self.groups)
            output = output * filter_attention
            return output
    
        def forward(self, x):
            return self.act(self.bn(self._forward_impl(x)))
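
    For reference, a minimal usage sketch (assuming YOLOv5's models/common.py is importable; the sizes and the temperature schedule below are illustrative assumptions, not values from the paper):

    conv = ODConv2d(64, 128, k=3, s=1, kernel_num=4)  # hypothetical channel sizes
    x = torch.randn(2, 64, 32, 32)                    # (batch, channels, height, width)
    print(conv(x).shape)                              # torch.Size([2, 128, 32, 32])

    # Temperature annealing: start soft and cool toward 1.0 so the kernel-wise
    # softmax sharpens as training progresses (illustrative schedule).
    for epoch in range(10):
        conv.update_temperature(max(1.0, 30.0 - 3.0 * epoch))
        # ... run one training epoch ...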
    
  • Source: https://blog.csdn.net/DM_zx/article/details/132857530