• 【Pytorch】一文搞懂nn.Conv2d的groups参数的作用


    1. 语言描述

    在Pytorch1.13的官方文档中,关于nn.Conv2d中的groups的作用是这么描述的:
    在这里插入图片描述
    简单来说就是将输入和输出的通道(channel)进行分组,每一组单独进行卷积操作,然后再把结果拼接(concat)起来。

    比如输入大小为 ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5),输出大小为 ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5) g r o u p s = 2 groups=2 groups=2。就是将输入的4个channel分成2个2的channel,输出的8个channel分成2个4的channel,每个输入的2个channel和输出的4个channel组成一组,每组做完卷积后的输出大小为 ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5)。然后把得到的两组输出在channel这个维度上进行concat,得到最后的输出维度为 ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5)

    但其实这么描述理解起来不够直观,下面我举个例子,先从语言上进行详细的解释,然后再进行代码验证。

    符号数值含义
    i n p u t _ c h a n n e l input\_channel input_channel4输入通道数量
    o n p u t _ c h a n n e l onput\_channel onput_channel8输出通道数量,其实就是卷积核的个数,我们将其看作卷积核的个数会更容易理解
    b a t c h _ s i z e batch\_size batch_size1批量大小为1
    H , W H, W H,W5输入输出的feature大小为5x5
    i n p u t _ s h a p e input\_shape input_shape ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5)输入的shape,注意我们这里设置输入的所有元素都为1,即输入是一个全1的tensor
    o u t p u t _ s h a p e output\_shape output_shape ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5)输出的shape
    k e r n e l _ s i z e kernel\_size kernel_size3卷积核的大小为3x3
    p a d d i n g padding padding1填充长度为1,这里我们使用1填充(即周围补一圈1),而不是0填充
    s t r i d e stride stride1步长为1

    我们假设输入tensor的shape为 ( 1 , 4 , 5 , 5 ) (1, 4, 5, 5) (1,4,5,5)输出tensor的shape为: ( 1 , 8 , 5 , 5 ) (1, 8, 5, 5) (1,8,5,5),即我们的卷积核有8个。下面的图由于 b a t c h _ s i z e = 1 batch\_size=1 batch_size=1,所以省略的 b a t c h _ s i z e batch\_size batch_size的维度。
    在这里插入图片描述

    值得注意的是,这里我们手动设置卷积核中元素的值,前4个卷积核的值都设置为1,后4个卷积核的值都设置为2,如下图所示:

    在这里插入图片描述
    这里解释一下为什么 g r o u p s = 1 groups=1 groups=1 k e r n e l _ s i z e = ( 4 , 3 , 3 ) kernel\_size=(4, 3, 3) kernel_size=(4,3,3) g r o u p s = 2 groups=2 groups=2 k e r n e l _ s i z e = ( 2 , 3 , 3 ) kernel\_size=(2, 3, 3) kernel_size=(2,3,3):因为 g r o u p s = 2 groups=2 groups=2时,输入和输出都被分成了两组,输入的shape原来为: ( 4 , 5 , 5 ) (4, 5, 5) (4,5,5),被分成了两个 ( 2 , 5 , 5 ) (2, 5, 5) (2,5,5),所以每个 k e r n e l _ s i z e kernel\_size kernel_size也由 ( 4 , 3 , 3 ) (4, 3, 3) (4,3,3)变为 ( 2 , 3 , 3 ) (2, 3, 3) (2,3,3)

    下面我们来看一下 g r o u p s = 1 groups=1 groups=1 g r o u p s = 2 groups=2 groups=2时计算过程的不同:

    【情况1:groups=1】
    此时就和正常卷积一样:
    在这里插入图片描述

    这里解释一下:output的前4个channel的每个feature map的所有元素都为36,后4个channel的每个feature map的所有元素都为72,这是因为:
    每个输入的 H , W H,W H,W是5x5,加上padding之后是6x6,具体过程如下:
    在这里插入图片描述

    【情况1:groups=2】
    此时应当这么算:
    在这里插入图片描述
    为什么output的前4个channel的每个feature map的所有元素都为18,后4个channel的每个feature map的所有元素都为36呢?看了下面的图应该就能理解这个过程了:
    在这里插入图片描述

    2. 代码验证:

    实验环境:Python3.7,torch1.10.2
    代码:

    import os
    
    import torch
    import torch.nn as nn
    
    
    if __name__ == '__main__':
        input_dim, output_dim = 4, 8
        X = torch.ones(1, input_dim, 5, 5)
    
        # groups = 1
        conv1 = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1, groups=1, bias=False, padding_mode='replicate')
        print(f'groups=1时,卷积核的形状为:{conv1.weight.shape}')
        with torch.no_grad():
            conv1.weight[:4, :, :, :] = torch.ones(4, 4, 3, 3)
            conv1.weight[4:, :, :, :] = torch.ones(4, 4, 3, 3) * 2
            Y1 = conv1(X)
            print(f'结果为:\n{Y1}')
    
        # groups = 2
        conv2 = nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1, groups=2, bias=False, padding_mode='replicate')
        print(f'groups=2时,卷积核的形状为:{conv2.weight.shape}')
        with torch.no_grad():
            conv2.weight[:4, :, :, :] = torch.ones(4, 2, 3, 3)
            conv2.weight[4:, :, :, :] = torch.ones(4, 2, 3, 3) * 2
            Y2 = conv2(X)
            print(f'结果为:\n{Y2}')
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29

    结果:

    groups=1时,卷积核的形状为:torch.Size([8, 4, 3, 3])
    结果为:
    tensor([[[[36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.]],
    
             [[36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.]],
    
             [[36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.]],
    
             [[36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.]],
    
             [[72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.]],
    
             [[72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.]],
    
             [[72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.]],
    
             [[72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.],
              [72., 72., 72., 72., 72.]]]])
    groups=2时,卷积核的形状为:torch.Size([8, 2, 3, 3])
    结果为:
    tensor([[[[18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.]],
    
             [[18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.]],
    
             [[18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.]],
    
             [[18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.],
              [18., 18., 18., 18., 18.]],
    
             [[36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.]],
    
             [[36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.]],
    
             [[36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.]],
    
             [[36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.],
              [36., 36., 36., 36., 36.]]]])
    
    Process finished with exit code 0
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100

    整体流程我手画了个图,我感觉比PPT画的还清楚,可以更好地理解过程在这里插入图片描述

    END:)

    p.s.:没想到写个博客写了一上午,画图太费时间了!本来上午还有别的事情的。。。只能推到下午再做了0.0
    
    • 1
  • 相关阅读:
    《Grid Tagging Scheme for Aspect-oriented Fine-grained Opinion Extraction》论文阅读
    我终于读懂了适配器模式。。。
    哪个产品功能重要?KANO模型帮你
    常带电电路,PCB 布局布线注意
    基础篇01——SQL的基本语法和分类
    论文笔记: 极限多标签学习之 FastXML
    条码二维码读取设备在医疗设备自助服务的重要性
    FlutterAcivity 包已导入 但是仍然爆红
    Vue学习第16天——全局事件总线$bus的理解
    〔001〕Java 基础之环境安装和编写首个程序
  • 原文地址:https://blog.csdn.net/qq_44166630/article/details/127802567