• PyTorch implementation of Distribution Shifting Convolution (DSConv)


    The DSConv module

    DSConv (Distribution Shifting Convolution) can easily be substituted into standard neural network architectures, achieving lower memory usage and higher computational speed. DSConv decomposes the traditional convolution kernel into two components: a Variable Quantized Kernel (VQK) and distribution shifts. Storing only integer values in the VQK yields the lower memory usage and higher speed, while applying kernel-based and channel-based distribution shifts preserves the same output as the original convolution. We test DSConv on ImageNet with ResNet50 and ResNet34, as well as AlexNet and MobileNet. By substituting floating-point operations with integer operations, we achieve up to a 14x reduction in memory usage of the convolution kernels and up to a 10x speedup. Furthermore, unlike other quantization approaches, our work permits a degree of retraining on new tasks and datasets.

    Paper: DSConv: Efficient Convolution Operator

    DSConv architecture diagram (figure not reproduced here)
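    To make the decomposition concrete, the sketch below splits a float kernel into an integer VQK plus one floating-point scale per depth block and kernel position. This is a minimal illustration assuming a simple symmetric max-scaling quantizer, not the paper's reference code; decompose_kernel, bits, and block_size are illustrative names.

    import torch

    def decompose_kernel(weight, block_size=32, bits=3):
        # Split a float kernel (O, I, kH, kW) into an int8 VQK and a scale
        # tensor shaped like `alpha` in the module below: (O, n_blocks, kH, kW)
        O, I, kH, kW = weight.shape
        qmax = 2 ** (bits - 1) - 1                       # e.g. 3 bits -> [-3, 3]
        n_blocks = (I + block_size - 1) // block_size
        vqk = torch.zeros(O, I, kH, kW, dtype=torch.int8)
        alpha = torch.zeros(O, n_blocks, kH, kW)
        for b in range(n_blocks):
            blk = weight[:, b * block_size:(b + 1) * block_size]
            # Scale chosen so the largest magnitude in each block maps to qmax
            scale = blk.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
            vqk[:, b * block_size:(b + 1) * block_size] = \
                torch.round(blk / scale).clamp(-qmax, qmax).to(torch.int8)
            alpha[:, b] = scale.squeeze(1)
        return vqk, alpha

    w = torch.randn(64, 128, 3, 3)
    vqk, alpha = decompose_kernel(w)
    print(vqk.dtype, alpha.shape)  # torch.int8 torch.Size([64, 4, 3, 3])
    # Reconstruction: vqk.float() * alpha (expanded per block) approximates w,
    # which is exactly the expansion get_weight_res() performs below.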

    Code implementation:

    import math

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.nn.modules.conv import _ConvNd
    from torch.nn.modules.utils import _pair

    def autopad(k, p=None):
        # 'same'-shape padding helper as used in YOLOv5 (models/common.py):
        # when no padding is given, pad by half the kernel size
        if p is None:
            p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
        return p
    
    class DSConv(_ConvNd):
        def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                     padding=None, dilation=1, groups=1, padding_mode='zeros', bias=False, block_size=32, KDSBias=False, CDS=False):
            padding = _pair(autopad(kernel_size, padding))
            kernel_size = _pair(kernel_size)
            stride = _pair(stride)
            dilation = _pair(dilation)
    
            blck_numb = math.ceil(in_channels / (block_size * groups))  # number of depth blocks per filter
            super(DSConv, self).__init__(
                in_channels, out_channels, kernel_size, stride, padding, dilation,
                False, _pair(0), groups, bias, padding_mode)
    
            # KDS weight from the paper: intweight is a placeholder for the
            # integer-valued VQK (quantized offline, unused in this forward
            # path); alpha is the per-block floating-point scale, initialized
            # to ones so the layer starts out equivalent to a standard conv
            self.intweight = torch.Tensor(out_channels, in_channels, *kernel_size)
            self.alpha = nn.Parameter(torch.ones(out_channels, blck_numb, *kernel_size))
    
            # KDS bias From Paper
            self.KDSBias = KDSBias
            self.CDS = CDS
    
            if KDSBias:
                # Per-block additive shift, zero-initialized (a no-op at init)
                self.KDSb = nn.Parameter(torch.zeros(out_channels, blck_numb, *kernel_size))
            if CDS:
                # Channel-based distribution shifter: per-output-channel scale and bias
                self.CDSw = nn.Parameter(torch.ones(out_channels))
                self.CDSb = nn.Parameter(torch.zeros(out_channels))
    
            self.reset_parameters()
    
        def get_weight_res(self):
            # Expand alpha across the depth blocks and multiply it with the
            # weight, so the shifts can be folded into the convolution kernel
            alpha_res = torch.zeros(self.weight.shape).to(self.alpha.device)
    
            # Include KDSBias
            if self.KDSBias:
                KDSBias_res = torch.zeros(self.weight.shape).to(self.alpha.device)
    
            # Handy definitions:
            nmb_blocks = self.alpha.shape[1]            # number of depth blocks
            total_depth = self.weight.shape[1]          # in_channels // groups
            bs = total_depth // nmb_blocks              # regular block size

            llb = total_depth - (nmb_blocks - 1) * bs   # length of the last block
    
            # Casting the Alpha values as same tensor shape as weight
            for i in range(nmb_blocks):
                length_blk = llb if i==nmb_blocks-1 else bs
    
                shp = self.alpha.shape # Notice this is the same shape for the bias as well
                to_repeat=self.alpha[:, i, ...].view(shp[0],1,shp[2],shp[3]).clone()
                repeated = to_repeat.expand(shp[0], length_blk, shp[2], shp[3]).clone()
                alpha_res[:, i*bs:(i*bs+length_blk), ...] = repeated.clone()
    
                if self.KDSBias:
                    to_repeat = self.KDSb[:, i, ...].view(shp[0], 1, shp[2], shp[3]).clone()
                    repeated = to_repeat.expand(shp[0], length_blk, shp[2], shp[3]).clone()
                    KDSBias_res[:, i*bs:(i*bs+length_blk), ...] = repeated.clone()
    
            # Element-wise multiplication of alpha and weight, plus the KDS bias
            weight_res = torch.mul(alpha_res, self.weight)
            if self.KDSBias:
                weight_res = torch.add(weight_res, KDSBias_res)

            # Channel-based distribution shifter: fold the per-output-channel
            # scale into the effective weight. (The original snippet only
            # printed the expanded tensor's shape; CDSb would act as a
            # per-channel output bias and remains unapplied, as in the original.)
            if self.CDS:
                weight_res = weight_res * self.CDSw.view(-1, 1, 1, 1)

            return weight_res
    
        def forward(self, input):
            # Apply the distribution shifts to get the effective weight. (The
            # originally posted code skipped this call and convolved with the
            # raw weight, leaving the shifters as dead code.)
            weight_res = self.get_weight_res()

            # Returning convolution
            return F.conv2d(input, weight_res, self.bias,
                            self.stride, self.padding, self.dilation,
                            self.groups)
    
    class DSConv2D(Conv):
        # Drop-in replacement for the Conv block in YOLOv5's models/common.py;
        # `Conv` and `autopad` follow that file's conventions
        def __init__(self, inc, ouc, k=1, s=1, p=None, g=1, act=True):
            super().__init__(inc, ouc, k, s, p, g, act)
            # DSConv's sixth positional argument is dilation, so `g` must be
            # passed as a keyword to land on `groups`
            self.conv = DSConv(inc, ouc, k, s, p, groups=g)
    
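    A quick smoke test of the module above, as a minimal usage sketch (shapes chosen arbitrarily). Because alpha starts at ones and KDSb at zeros, the initial output matches that of a freshly initialized standard convolution:

    x = torch.randn(2, 64, 32, 32)
    conv = DSConv(64, 128, kernel_size=3, stride=1, KDSBias=True)
    y = conv(x)
    print(y.shape)  # torch.Size([2, 128, 32, 32]); autopad keeps the spatial size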
  • Source: https://blog.csdn.net/DM_zx/article/details/132857249