• PyTorch Structural Re-parameterization: RepVGGBlock


    ShuffleNet v2 proposed four major design guidelines for lightweight networks:

    • When the input and output channel counts are equal, MAC (memory access cost) is minimized (a small numeric check follows this list)
    • With FLOPs held constant, group convolution with too many groups increases MAC
    • Fragmented operations (multi-branch structures) are unfriendly to parallel acceleration
    • The memory cost and latency of element-wise operations cannot be ignored
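    The numeric check below is an illustrative aside (not from the original post): for a 1×1 convolution over an h×w feature map, MAC ≈ hw(c1 + c2) + c1c2 (feature-map accesses plus weight accesses), and with the FLOPs hw·c1·c2 held fixed this is smallest when c1 = c2.

    # Numeric check of rule 1 (illustrative sketch): with FLOPs fixed, MAC is smallest when c1 == c2
    h = w = 56
    target_flops = h * w * 128 * 128            # keep h*w*c1*c2 constant

    for c1, c2 in [(32, 512), (64, 256), (128, 128), (256, 64)]:
        assert h * w * c1 * c2 == target_flops
        mac = h * w * (c1 + c2) + c1 * c2       # feature-map accesses + weight accesses
        print(f'c1={c1:4d}, c2={c2:4d} -> MAC={mac}')
    # The balanced setting (128, 128) gives the smallest MAC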

    In recent years, convolutional neural network architectures have become increasingly complex; thanks to their good convergence behaviour, multi-branch structures have grown more and more popular

    However, multi-branch structures cannot make effective use of parallel acceleration, and they also increase MAC

    To let a plain structure reach accuracy comparable to a multi-branch one, RepVGG uses a multi-branch structure (3×3 conv + 1×1 conv + identity mapping) during training to exploit its good convergence; at inference and deployment time, re-parameterization converts the multi-branch structure into a single-path structure to exploit the extreme speed of the plain design

    Re-parameterization

    In the multi-branch structure used during training, every branch contains a BN layer

    A BN layer has four parameters involved in its computation: mean, var, weight and bias; it applies the following transform to the input x:

    BN(x)=weight \cdot \frac{x-mean}{\sqrt{var}}+bias

    Rewriting it in the form BN(x) = w_{bn} \cdot x + b_{bn} gives (in the code below, var is replaced by var + eps for numerical stability):

    w_{bn}=\frac{weight}{\sqrt{var}},\ b_{bn}=bias-\frac{weight\cdot mean}{\sqrt{var}}

    import torch
    from torch import nn


    class BatchNorm(nn.BatchNorm2d):

        def unpack(self):
            mean, weight, bias = self.running_mean, self.weight, self.bias
            std = (self.running_var + self.eps).sqrt()
            eq_weight = weight / std
            eq_bias = bias - weight * mean / std
            return eq_weight, eq_bias


    bn = BatchNorm(8).eval()
    # Initialize with random parameters
    bn.running_mean.data, bn.running_var.data, bn.weight.data, bn.bias.data = torch.rand([4, 8])

    image = torch.rand([1, 8, 1, 1])
    print(bn(image).view(-1))

    # Convert the BN parameters into w, b form
    weight, bias = bn.unpack()
    print(image.view(-1) * weight + bias)

    Because the BN layer already fits a per-channel bias, the convolution layer is used without bias when it is followed by a BN layer; the computation can then be written as:

    Conv(x)=w_{c}*x

    BN(Conv(x))=w_{bn}w_{c}*x+b_{bn}

    Hence a convolution layer followed by a BN layer is equivalent to a single convolution layer with bias
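    A minimal sketch of this fusion, reusing the BatchNorm.unpack helper defined above (the layer sizes are arbitrary choices for illustration):

    import torch
    from torch import nn

    conv = nn.Conv2d(8, 8, 3, padding=1, bias=False)
    bn = BatchNorm(8).eval()
    bn.running_mean.data, bn.running_var.data, bn.weight.data, bn.bias.data = torch.rand([4, 8])

    x = torch.rand([1, 8, 5, 5])
    y_ref = bn(conv(x))

    # Fuse: scale each output-channel filter by w_bn and use b_bn as the bias
    w_bn, b_bn = bn.unpack()
    fused = nn.Conv2d(8, 8, 3, padding=1, bias=True)
    fused.weight.data = conv.weight.data * w_bn.view(-1, 1, 1, 1)
    fused.bias.data = b_bn
    print(torch.allclose(y_ref, fused(x), atol=1e-5))  # expected: True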

    The identity mapping can likewise be expressed as a 1×1 convolution (a quick check follows this list):

    • For nn.Conv2d(c1, c2, kernel_size=1), the weight has shape [c2, c1, 1, 1]; it can be viewed as a [c2, c1] linear layer that applies a channel transform at every pixel (see: how PyTorch computes 2-D multi-channel convolution)
    • When c1 = c2 and this linear layer is the identity matrix, the convolution is equivalent to the identity mapping
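    A quick check of this equivalence (illustrative; the channel count is chosen arbitrarily):

    import torch
    from torch import nn

    c = 8
    conv_id = nn.Conv2d(c, c, kernel_size=1, bias=False)
    # Load the identity matrix as the [c, c, 1, 1] kernel
    conv_id.weight.data = torch.eye(c).view(c, c, 1, 1)

    x = torch.rand([1, c, 5, 5])
    print(torch.allclose(conv_id(x), x))  # True: the 1×1 conv reproduces its input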

    A 1×1 convolution can in turn be written as a 3×3 convolution by zero-padding its kernel, so the computation of the multi-branch structure can be expressed as:

    BN_{3 \times 3}(Conv_{3 \times 3}(x))=w_3*x+b_3

    BN_{1 \times 1}(Conv_{1 \times 1}(x))=w_1*x+b_1

    BN_{id}(Conv_{id}(x))=w_0*x+b_0

    y=(w_3+w_1+w_0)*x+(b_3+b_1+b_0)

    The whole block is therefore equivalent to a new 3×3 convolution (the conclusion also extends to group convolutions and 5×5 convolutions)
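    A short sketch of the 1×1-to-3×3 conversion (illustrative sizes; the full merge of all three branches is implemented in the merge method below):

    import torch
    import torch.nn.functional as F
    from torch import nn

    conv_1x1 = nn.Conv2d(8, 8, kernel_size=1, bias=False)
    conv_3x3 = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=False)
    # Zero-pad the [8, 8, 1, 1] kernel to [8, 8, 3, 3], placing the 1×1 weight at the centre
    conv_3x3.weight.data = F.pad(conv_1x1.weight.data, [1, 1, 1, 1])

    x = torch.rand([1, 8, 5, 5])
    print(torch.allclose(conv_1x1(x), conv_3x3(x), atol=1e-6))  # expected: True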

    In a speed test on an NVIDIA 1080Ti, feeding a [32, 2048, 56, 56] input through convolutions that keep the channel count and spatial size unchanged, the 3×3 convolution delivers the highest floating-point throughput per second
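    The sketch below is a rough throughput probe rather than a reproduction of that measurement: it times 1×1, 3×3 and 5×5 convolutions that keep channels and spatial size unchanged and reports an approximate MACs-per-second figure. The tensor size is reduced from the [32, 2048, 56, 56] setting as an assumption so it also runs on smaller GPUs or the CPU; absolute numbers depend on the hardware and cuDNN version.

    import time

    import torch
    from torch import nn

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.rand(8, 256, 56, 56, device=device)

    for k in (1, 3, 5):
        conv = nn.Conv2d(256, 256, k, stride=1, padding=k // 2, bias=False).to(device)
        macs = x.shape[0] * 256 * 256 * k * k * 56 * 56  # multiply-accumulates per forward pass
        with torch.no_grad():
            for _ in range(3):            # warm-up
                conv(x)
            if device == 'cuda':
                torch.cuda.synchronize()
            start = time.time()
            for _ in range(10):
                conv(x)
            if device == 'cuda':
                torch.cuda.synchronize()
        cost = (time.time() - start) / 10
        print(f'{k}x{k} conv: {macs / cost / 1e9:.1f} GMACs/s')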

    Structure Reproduction

    Reference code: https://github.com/DingXiaoH/RepVGG

    I refactored the source code from the paper to improve its readability and ease of use (the custom L2-norm computation was removed so that the block can be ported into a YOLO project)

    I also implemented the re-parameterization routine as a static method of the class, so that a whole model assembled from these blocks can be re-parameterized in one call

    from collections import OrderedDict

    import torch
    import torch.nn.functional as F
    from torch import nn


    class BatchNorm(nn.BatchNorm2d):

        def unpack(self):
            mean, weight, bias = self.running_mean, self.weight, self.bias
            std = (self.running_var + self.eps).sqrt()
            eq_weight = weight / std
            eq_bias = bias - weight * mean / std
            return eq_weight, eq_bias


    class RepVGGBlock(nn.Module):

        def __init__(self, c1, c2, k=3, s=1, g=1, deploy=False):
            super(RepVGGBlock, self).__init__()
            self.deploy = deploy
            # Check the kernel size
            assert k & 1, 'The convolution kernel size must be odd'
            # Main-branch convolution parameters
            self.conv_main_config = dict(
                in_channels=c1, out_channels=c2, kernel_size=k,
                stride=s, padding=k // 2, groups=g
            )
            if deploy:
                self.conv_main = nn.Conv2d(**self.conv_main_config, bias=True)
                # In deploy mode there are no extra branches
                self.conv_1x1, self.identity = None, None
            else:
                # Main branch
                self.conv_main = nn.Sequential(OrderedDict(
                    conv=nn.Conv2d(**self.conv_main_config, bias=False),
                    bn=BatchNorm(c2)
                ))
                # 1×1 convolution branch
                self.conv_1x1 = nn.Sequential(OrderedDict(
                    conv=nn.Conv2d(c1, c2, 1, s, padding=0, groups=g, bias=False),
                    bn=BatchNorm(c2)
                )) if k != 1 else None
                # Identity mapping branch
                self.identity = BatchNorm(c2) if c1 == c2 and s == 1 else None

        def forward(self, x, act=F.silu):
            y = self.conv_main(x)
            if self.conv_1x1:
                y += self.conv_1x1(x)
            if self.identity:
                y += self.identity(x)
            # Apply the activation function
            y = act(y) if act else y
            return y

        @staticmethod
        def merge(model: nn.Module):
            # Walk all submodules of the model and merge every RepVGGBlock
            for m in model.modules():
                if isinstance(m, RepVGGBlock) and not m.deploy:
                    # Information of the main branch
                    kernel = m.conv_main.conv.weight
                    (c2, c1_per_group, k, _), g = kernel.shape, m.conv_main.conv.groups
                    center_pos = k // 2
                    # Fuse the main branch
                    bn_weight, bn_bias = m.conv_main.bn.unpack()
                    kernel_weight, kernel_bias = kernel * bn_weight.view(-1, 1, 1, 1), bn_bias
                    # Fuse the 1×1 convolution branch
                    if m.conv_1x1:
                        kernel_1x1 = m.conv_1x1.conv.weight[..., 0, 0]
                        bn_weight, bn_bias = m.conv_1x1.bn.unpack()
                        kernel_weight[..., center_pos, center_pos] += kernel_1x1 * bn_weight.view(-1, 1)
                        kernel_bias += bn_bias
                    # Fuse the identity mapping branch
                    if m.identity:
                        kernel_id = torch.cat([torch.eye(c1_per_group)] * g, dim=0).to(kernel.device)
                        bn_weight, bn_bias = m.identity.unpack()
                        kernel_weight[..., center_pos, center_pos] += kernel_id * bn_weight.view(-1, 1)
                        kernel_bias += bn_bias
                    # Create the merged convolution
                    m.conv_main = nn.Conv2d(**m.conv_main_config, bias=True)
                    m.conv_main.weight.data, m.conv_main.bias.data = kernel_weight, kernel_bias
                    # Remove the merged branches
                    m.deploy = True
                    delattr(m, 'conv_1x1')
                    delattr(m, 'identity')
                    m.conv_1x1, m.identity = None, None

    Then assemble a model from these blocks to verify:

    • whether the merge function changes the network structure
    • whether the model produces the same outputs before and after re-parameterization
    • whether the inference speed improves after re-parameterization
    if __name__ == '__main__':

        class RepVGG(nn.Module):

            def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, deploy=False):
                super(RepVGG, self).__init__()
                assert len(width_multiplier) == 4
                self.deploy = deploy
                # Number of input channels
                self.in_planes = min(64, int(64 * width_multiplier[0]))
                self.stage0 = RepVGGBlock(3, self.in_planes, k=3, s=2, deploy=self.deploy)
                # The backbone consists of four stages, each a cascade of RepVGGBlocks
                self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=2)
                self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2)
                self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2)
                self.stage4 = self._make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride=2)
                self.gap = nn.AdaptiveAvgPool2d(output_size=1)
                self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)

            def _make_stage(self, planes, num_blocks, stride):
                strides = [stride] + [1] * (num_blocks - 1)
                blocks = []
                for stride in strides:
                    blocks.append(RepVGGBlock(self.in_planes, planes, k=3, s=stride, deploy=self.deploy))
                    self.in_planes = planes
                return nn.Sequential(*blocks)

            def forward(self, x):
                out = self.stage0(x)
                out = self.stage1(out)
                out = self.stage2(out)
                out = self.stage3(out)
                out = self.stage4(out)
                out = self.gap(out)
                out = out.view(out.size(0), -1)
                out = self.linear(out)
                return out


        vgg = RepVGG(num_blocks=[1, 1, 1, 1], num_classes=20,
                     width_multiplier=[1, 1, 1, 1]).eval()
        print(vgg)

        # Initialize every BatchNorm with random parameters
        for m in vgg.modules():
            if isinstance(m, BatchNorm):
                m.running_mean.data, m.running_var.data, \
                    m.weight.data, m.bias.data = torch.rand([4, m.num_features])

        image = torch.rand([1, 3, 224, 224])


        class Timer:
            prefix = 'Cost: '

            def __init__(self, fun, *args, **kwargs):
                import time
                start = time.time()
                fun(*args, **kwargs)
                cost = (time.time() - start) * 1e3
                print(self.prefix + f'{cost:.0f} ms')


        # Test the VGG with the training-time structure
        print(vgg(image))
        Timer(vgg, image)

        # Call the static method of RepVGGBlock to merge its branches
        RepVGGBlock.merge(vgg)
        print(vgg)

        # Test the VGG with the inference-time structure
        print(vgg(image))
        Timer(vgg, image)

    RepVGG(
      (stage0): RepVGGBlock(
        (conv_main): Sequential(
          (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): BatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (conv_1x1): Sequential(
          (conv): Conv2d(3, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (bn): BatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (stage1): Sequential(
        (0): RepVGGBlock(
          (conv_main): Sequential(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (bn): BatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (conv_1x1): Sequential(
            (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (bn): BatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
      )
      (stage2): Sequential(
        (0): RepVGGBlock(
          (conv_main): Sequential(
            (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (bn): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (conv_1x1): Sequential(
            (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (bn): BatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
      )
      (stage3): Sequential(
        (0): RepVGGBlock(
          (conv_main): Sequential(
            (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (bn): BatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (conv_1x1): Sequential(
            (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (bn): BatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
      )
      (stage4): Sequential(
        (0): RepVGGBlock(
          (conv_main): Sequential(
            (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (bn): BatchNorm(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (conv_1x1): Sequential(
            (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (bn): BatchNorm(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
      )
      (gap): AdaptiveAvgPool2d(output_size=1)
      (linear): Linear(in_features=512, out_features=20, bias=True)
    )
    tensor([[-0.1108,  0.0824,  0.5547, -0.1671,  0.7442, -0.1164, -0.2825,  0.4088,
              0.1239, -0.3792,  0.1152, -0.4021,  0.4034,  0.2350,  0.2601, -0.1197,
              0.2462, -0.2451,  0.0439, -0.2507]], grad_fn=)
    Cost: 22 ms


    RepVGG(
      (stage0): RepVGGBlock(
        (conv_main): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      )
      (stage1): Sequential(
        (0): RepVGGBlock(
          (conv_main): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
      (stage2): Sequential(
        (0): RepVGGBlock(
          (conv_main): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
      (stage3): Sequential(
        (0): RepVGGBlock(
          (conv_main): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
      (stage4): Sequential(
        (0): RepVGGBlock(
          (conv_main): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        )
      )
      (gap): AdaptiveAvgPool2d(output_size=1)
      (linear): Linear(in_features=512, out_features=20, bias=True)
    )
    tensor([[-0.1108,  0.0824,  0.5547, -0.1671,  0.7442, -0.1164, -0.2825,  0.4088,
              0.1239, -0.3792,  0.1152, -0.4021,  0.4034,  0.2350,  0.2601, -0.1197,
              0.2462, -0.2451,  0.0439, -0.2507]], grad_fn=)
    Cost: 14 ms

  • Original article: https://blog.csdn.net/qq_55745968/article/details/125887670