PyTorch Implementation of SCConv (Spatial and Channel Reconstruction Convolution)


    SCConv: Spatial and Channel Reconstruction Convolution

    SCConv, short for Spatial and Channel Reconstruction Convolution, was proposed at CVPR 2023. It is a plug-and-play module that reduces parameters while improving performance, and its core idea is to cut down feature redundancy so as to make the network more efficient. Common model-compression techniques include network pruning, weight quantization, low-rank factorization, and knowledge distillation; although these methods do reduce parameter counts, they often come at the cost of degraded model performance. An alternative route is to build lightweight networks from specially designed modules or operations, which can reduce parameters while preserving performance.

    Original paper: SCConv: Spatial and Channel Reconstruction Convolution for Feature Redundancy

    The proposed SCConv consists of two parts, the Spatial Reconstruction Unit (SRU) and the Channel Reconstruction Unit (CRU). The overall structure of SCConv is shown below.
    [Figure: overall SCConv architecture]
    As the figure shows, SCConv first applies a 1x1 convolution to adjust the input feature map to a suitable number of channels, then processes the feature map with the SRU and CRU modules in turn, and finally restores the channel count with another 1x1 convolution and adds a residual connection (a hedged sketch of this wrapper appears after the code listing below).
    [Figure: SRU module structure]
    [Figure: CRU module structure]

    The code implementation is as follows:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class GroupBatchnorm2d(nn.Module):
        """Group normalization with learnable per-channel affine parameters.

        Despite the name, statistics are computed per sample over channel
        groups (as in GroupNorm), not over the batch dimension.
        """
        def __init__(self, c_num: int, group_num: int = 16, eps: float = 1e-10):
            super().__init__()
            assert c_num >= group_num
            self.group_num = group_num
            self.gamma = nn.Parameter(torch.randn(c_num, 1, 1))
            self.beta = nn.Parameter(torch.zeros(c_num, 1, 1))
            self.eps = eps

        def forward(self, x):
            N, C, H, W = x.size()
            # Normalize each group of channels with its own mean and std.
            x = x.view(N, self.group_num, -1)
            mean = x.mean(dim=2, keepdim=True)
            std = x.std(dim=2, keepdim=True)
            x = (x - mean) / (std + self.eps)
            x = x.view(N, C, H, W)
            return x * self.gamma + self.beta


    class SRU(nn.Module):
        """Spatial Reconstruction Unit: separates informative from redundant
        spatial features, using the group-norm scaling factors as channel
        importance weights."""
        def __init__(self, oup_channels: int, group_num: int = 16,
                     gate_threshold: float = 0.5):
            super().__init__()
            self.gn = GroupBatchnorm2d(oup_channels, group_num=group_num)
            self.gate_threshold = gate_threshold
            self.sigmoid = nn.Sigmoid()

        def forward(self, x):
            gn_x = self.gn(x)
            # Normalized gamma measures how informative each channel is.
            w_gamma = self.gn.gamma / torch.sum(self.gn.gamma)
            reweights = self.sigmoid(gn_x * w_gamma)
            # Gate: split features into informative and less-informative parts.
            info_mask = reweights >= self.gate_threshold
            noninfo_mask = reweights < self.gate_threshold
            x_1 = info_mask * x
            x_2 = noninfo_mask * x
            return self.reconstruct(x_1, x_2)

        def reconstruct(self, x_1, x_2):
            # Cross-reconstruction: swap and add halves of the two parts so
            # redundant features are compensated by informative ones.
            x_11, x_12 = torch.split(x_1, x_1.size(1) // 2, dim=1)
            x_21, x_22 = torch.split(x_2, x_2.size(1) // 2, dim=1)
            return torch.cat([x_11 + x_22, x_12 + x_21], dim=1)


    class CRU(nn.Module):
        """Channel Reconstruction Unit: split-transform-fuse along channels.

        alpha: fraction of channels sent to the upper (rich) branch, 0 < alpha < 1.
        """
        def __init__(self, op_channel: int, alpha: float = 1 / 2,
                     squeeze_radio: int = 2, group_size: int = 2,
                     group_kernel_size: int = 3):
            super().__init__()
            self.up_channel = up_channel = int(alpha * op_channel)
            self.low_channel = low_channel = op_channel - up_channel
            # 1x1 convolutions compress both branches.
            self.squeeze1 = nn.Conv2d(up_channel, up_channel // squeeze_radio,
                                      kernel_size=1, bias=False)
            self.squeeze2 = nn.Conv2d(low_channel, low_channel // squeeze_radio,
                                      kernel_size=1, bias=False)
            # Upper branch: group-wise convolution plus point-wise convolution.
            self.GWC = nn.Conv2d(up_channel // squeeze_radio, op_channel,
                                 kernel_size=group_kernel_size, stride=1,
                                 padding=group_kernel_size // 2, groups=group_size)
            self.PWC1 = nn.Conv2d(up_channel // squeeze_radio, op_channel,
                                  kernel_size=1, bias=False)
            # Lower branch: point-wise convolution, reusing the squeezed input.
            self.PWC2 = nn.Conv2d(low_channel // squeeze_radio,
                                  op_channel - low_channel // squeeze_radio,
                                  kernel_size=1, bias=False)
            self.advavg = nn.AdaptiveAvgPool2d(1)

        def forward(self, x):
            # Split
            up, low = torch.split(x, [self.up_channel, self.low_channel], dim=1)
            up, low = self.squeeze1(up), self.squeeze2(low)
            # Transform
            Y1 = self.GWC(up) + self.PWC1(up)
            Y2 = torch.cat([self.PWC2(low), low], dim=1)
            # Fuse: channel attention via global pooling and softmax.
            out = torch.cat([Y1, Y2], dim=1)
            out = F.softmax(self.advavg(out), dim=1) * out
            out1, out2 = torch.split(out, out.size(1) // 2, dim=1)
            return out1 + out2


    class ScConv(nn.Module):
        """SCConv: SRU followed by CRU."""
        def __init__(self, op_channel: int, group_num: int = 16,
                     gate_threshold: float = 0.5, alpha: float = 1 / 2,
                     squeeze_radio: int = 2, group_size: int = 2,
                     group_kernel_size: int = 3):
            super().__init__()
            self.SRU = SRU(op_channel, group_num=group_num,
                           gate_threshold=gate_threshold)
            self.CRU = CRU(op_channel, alpha=alpha, squeeze_radio=squeeze_radio,
                           group_size=group_size,
                           group_kernel_size=group_kernel_size)

        def forward(self, x):
            x = self.SRU(x)
            x = self.CRU(x)
            return x


    if __name__ == '__main__':
        x = torch.randn(1, 32, 16, 16)
        model = ScConv(32)
        print(model(x).shape)  # torch.Size([1, 32, 16, 16])
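
    The ScConv module above contains only the SRU → CRU pipeline; the 1x1 channel-adjusting convolutions and the residual connection described in the structure overview are not part of the listing. Below is a minimal sketch of how such a wrapper block could look, assuming ScConv from the listing above is in scope; the class name SCConvBlock and the reduction ratio are illustrative assumptions, not taken from the original code.

    import torch
    import torch.nn as nn


    class SCConvBlock(nn.Module):
        """Hypothetical wrapper: 1x1 reduce -> ScConv -> 1x1 expand + residual."""
        def __init__(self, channels: int, reduction: int = 2):
            super().__init__()
            mid = channels // reduction  # assumed intermediate width
            self.reduce = nn.Conv2d(channels, mid, kernel_size=1, bias=False)
            self.scconv = ScConv(mid)    # SRU + CRU from the listing above
            self.expand = nn.Conv2d(mid, channels, kernel_size=1, bias=False)

        def forward(self, x):
            out = self.reduce(x)
            out = self.scconv(out)
            out = self.expand(out)
            return out + x               # residual connection


    if __name__ == '__main__':
        x = torch.randn(1, 64, 16, 16)
        print(SCConvBlock(64)(x).shape)  # torch.Size([1, 64, 16, 16])

    In practice the paper positions SCConv as a drop-in replacement for standard convolutions inside existing bottleneck blocks, so the exact wrapper details (reduction ratio, placement of the residual) can be adapted to the host architecture.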
    
    Original article: https://blog.csdn.net/DM_zx/article/details/132701071