• PyTorch 二维多通道卷积运算方式


    全连接卷积

    对于单通道卷积的运算,相信大家已经见过不少了

    那么在卷积神经网络中,图像的通道数是怎样实现“从 3 到 8” 这样的跳变呢?

    接下来以尺寸 5 × 5 的卷积核做一下实验:

    1. import torch
    2. c1, c2, k = 3, 8, 5
    3. # 实例化二维卷积, 不使用 padding
    4. conv = torch.nn.Conv2d(in_channels=c1, out_channels=c2,
    5. kernel_size=k, bias=False)
    6. # 二维卷积的 weight: [c2, c1, k, k]
    7. weight = conv.weight.data
    8. print('Weight shape:', weight.shape)

    使用 PyTorch 的 Conv2d 可以发现,weight 参数的 shape 是 [8, 3, 5, 5]

    [3, 5, 5] 的图像(记为 img)经过该卷积运算后得到 [8, 1, 1],以此为例,提出猜想:

    • img[..., r, c] 的 shape 为 [3, ],即代表该像素点 3 个通道的值
    • weight[..., r, c] 的 shape 为 [8, 3],代表该卷积核的感受野中,第 r 行第 c 列的参数
    • 使用矩阵乘法运算 weight[..., r, c] × img[..., r, c],结果是 shape 为 [8, ] 的张量,代表该像素点 8 个通道的值;如果将所有像素点 8 个通道的值相加,即可得到此次卷积结果

    写出验证的函数如下:

    1. # 测试使用的图像, 卷积后: [c1, k, k] -> [c2, 1, 1]
    2. img = torch.rand([1, c1, k, k])
    3. def torch_conv():
    4. return conv(img).view(-1)
    5. def guess_for_FCconv():
    6. ''' 全连接卷积 (标准卷积)'''
    7. result = torch.zeros(c2)
    8. # 对各个像素点进行运算
    9. for r in range(k):
    10. for c in range(k):
    11. # 对应像素点的各通道值: [c1, ]
    12. pixel = img[..., r, c].view(-1)
    13. # 卷积核中对应像素点的参数: [c2, c1]
    14. linear = weight[..., r, c]
    15. # 该像素点对各个通道的贡献: [c2, c1] × [c1, ] -> [c2, ]
    16. result += linear @ pixel
    17. return result
    18. print('PyTorch:', torch_conv())
    19. print('Guess:', guess_for_FCconv())

    可以看到,PyTorch 的运算结果和我的运算结果是一样的,猜想成立

    Weight shape: torch.Size([8, 3, 5, 5])
    PyTorch: tensor([-0.2874, -0.4310,  0.1660, -0.0021,  0.6042, -0.0716,  0.0821,  0.0322],
           grad_fn=<ViewBackward>)
    Guess: tensor([-0.2874, -0.4310,  0.1660, -0.0021,  0.6042, -0.0716,  0.0821,  0.0322])

    提出这个运算方式,是为了帮助大家更好地理解卷积

    在实际的部署中,是不是使用这样的方式运算我就不清楚了 

    深度可分离卷积

    当输入通道数为 4,输出通道数为 8 时,设置卷积核组数为 2

    则 weight 参数的 shape 为 [8, 2, 5, 5],亦可表示成 [8, 4/2, 5, 5]

    1. import torch
    2. c1, c2, k, g = 4, 8, 5, 2
    3. # 实例化二维卷积, 不使用 padding
    4. conv = torch.nn.Conv2d(in_channels=c1, out_channels=c2,
    5. kernel_size=k, groups=g, bias=False)
    6. # 二维卷积的 weight: [c2, c1/g, k, k]
    7. weight = conv.weight.data
    8. print('Weight shape:', weight.shape)
    9. # 测试使用的图像, 卷积后: [c1, k, k] -> [c2, 1, 1]
    10. img = torch.rand([1, c1, k, k])

    输入通道数被进行了分组,那么图像在运算时,在通道上必被分离

    而输出通道数没有进行分组,图像在运算中是否被分离呢?

    [4, 5, 5] 的图像(记为 img)经过该卷积运算后得到 [8, 1, 1],以此为例,提出猜想:

    • 对图像的通道数进行分组:img 可表示成 [2, 2, 5, 5],即 2 张 [2, 5, 5] 的图像
    • 对 weight 的输出通道数进行分组:weight 可表示成 [2, 4, 2, 5, 5],即 2 个 weight 为 [4, 2, 5, 5] 的全连接卷积
    • 分别使用 [4, 2, 5, 5] 的全连接卷积可得到 2 张 [4, 1, 1],拼接在一起后得到 [8, 1, 1]

     写出验证的函数如下:

    1. # 测试使用的图像, 卷积后: [c1, k, k] -> [c2, 1, 1]
    2. img = torch.rand([1, c1, k, k])
    3. def torch_conv():
    4. return conv(img).view(-1)
    5. def guess_for_DWConv():
    6. ''' 深度可分离卷积'''
    7. # 将 c2 个通道表示成 g × c2/g
    8. result = torch.zeros([g, c2 // g])
    9. # 对图像的通道进行分组: [1, c1, k, k] -> [g, c1/g, k, k]
    10. img_ = img.view(g, -1, k, k)
    11. # 分组取出卷积核权值: [c2, c1/g, k, k] -> [g, c2/g, c1/g, k, k]
    12. for i, w in enumerate(weight.view(g, -1, c1//g, k, k)):
    13. # 对各个像素点进行运算
    14. for r in range(k):
    15. for c in range(k):
    16. # 对应像素点的各通道值: [c1/g, ]
    17. pixel = img_[i, :, r, c].view(-1)
    18. # 卷积核中对应像素点的参数: [c2/g, c1/g]
    19. linear = w[..., r, c]
    20. # 该像素点对各个通道的贡献: [c2/g, c1/g] × [c1/g, ] -> [c2/g, ]
    21. result[i] += linear @ pixel
    22. return result.view(-1)
    23. print('PyTorch:', torch_conv())
    24. print('Guess:', guess_for_DWConv())

    显而易见,两个函数的运算结果是一样的,猜想成立

    Weight shape: torch.Size([8, 2, 5, 5])
    PyTorch: tensor([ 0.1674, -0.1527,  0.4059, -0.3422, -0.2362, -0.4508,  0.3286,  0.3232],
           grad_fn=<ViewBackward>)
    Guess: tensor([ 0.1674, -0.1527,  0.4059, -0.3422, -0.2362, -0.4508,  0.3286,  0.3232])

  • 相关阅读:
    Hbase权限访问命令、报错:Grant无权限(acl文件少了)
    统一配置中心Config、Bus组件的使用以及 SpringCloud 微服务工具集总结
    Java 进阶:实例详解 Java 虚拟机字节码指令
    pdf怎么调整大小kb?pdf文件过大这样压缩
    【Java I/O 流】数据输入输出流:DataInputStream 和 DataOutputStream
    杰理之CMD_SET_BT_ADDR【篇】
    如何利用播放器节省20%点播成本
    FL Studio20.9.1水果中文版来啦 Win/Mac中文版FL水果萝卜
    postman archive / postman old versions / postman 历史版本下载
    SqlServer命名规范
  • 原文地址:https://blog.csdn.net/qq_55745968/article/details/125443469