• 【深度学习】【pytorch】对卷积层置零卷积核进行真实剪枝


    最近需要对深度学习模型进行部署,因此需要对模型进行压缩,博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论


    前言

    深度学习剪枝(Pruning)是一种用于减少神经网络模型大小、减少计算量和提高推理效率的技术,通过去除神经网络中的冗余连接(权重)或节点(神经元),从而实现模型的稀疏化。
    深度学习剪枝(Pruning)具有以下几个好处:1. 模型压缩和存储节省;2. 计算资源节省;3. 加速推理速度;4. 防止过拟合。
    “假剪枝”(Fake Pruning)是一种剪枝算法的称呼,它在剪枝过程中并不真正删除权重或节点,而是通过一些技巧将它们置零或禁用,以模拟剪枝的效果,不少优秀的论文就采用了"假剪枝"策略,尽管可以在一定程度上提高模型的推理速度,但假剪枝算法没有真正减少模型的大小,博主将通过讲解一个小案例,简洁易懂的说明一种对"假剪枝"卷积层进行真正的剪枝的的方法。


    卷积层剪枝

    可以先将最后的完整代码拷贝到自己的py文件中,然后按照博主的思路学习如何将置零卷积核进行真实剪枝:

    1. 初始化卷积层,并查看卷积层权重
      # 示例使用一个具有3个输入通道和5个输出通道的卷积层
      conv = nn.Conv2d(3, 5, 3)
      print("原始卷积层权重:")
      print(conv.weight.data)
      print(conv.weight.size())
      print("原始卷积层偏置:")
      print(conv.bias.data)
      print(conv.bias.size())
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
    2. 通过随机函数让部分卷积核权重置为0,模拟完成了假剪枝。
      # remove_zero_kernels方法内的代码
      weight = conv_layer.weight.data
      # 卷积核个数
      num_kernels = weight.size(0)
      # 随机对部分卷积置0
      pruned = torch.ones(num_kernels, 1, 1, 1)
      # 选择随着置0的卷积序号
      random_int = random.randint(1, num_kernels-1)
      for i in range(random_int):
          pruned[i, 0, 0, 0] = 0
      conv_layer.weight.data = weight * pruned
      weight = conv_layer.weight.data
      bias = conv_layer.bias.data
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
    3. 保存未被剪枝的卷积核的权重和偏置
      # 计算每个卷积核的L2范数,目的是为了检查卷积核的所有位置是不是都置0了
      norms = torch.norm(weight.view(num_kernels, -1), dim=1)
      zero_kernel_indices = torch.nonzero(norms==0).squeeze()
      print(zero_kernel_indices)
      # 移除L2范数为零的卷积核
      new_weight = torch.stack([weight[i, :, :, :] for i in range(num_kernels) if i not in zero_kernel_indices])
      new_bias = torch.stack([bias[i] for i in range(num_kernels) if i not in zero_kernel_indices])
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
    4. 构建新的卷积层,用来替换此前的卷积层,完成置零卷积核的真实剪枝
      # 构建新的卷积层
      if zero_kernel_indices.numel() > 0:
          # 输入channel
          in_channels = weight.size(1)
          # 输出channel
          out_channels = new_weight.size(0)
          # 卷积核大小
          kernel_size = weight.size(2)
          # 步长
          stride = conv_layer.stride
          padding = conv_layer.padding
          dilation = conv_layer.dilation
          groups = conv_layer.groups
          new_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)
          new_conv_layer.weight.data = new_weight
          new_conv_layer.bias.data = new_bias
      else:
          new_conv_layer = conv_layer
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      • 16
      • 17
      • 18

    完整代码

    import torch
    import torch.nn as nn
    import random
    
    def remove_zero_kernels(conv_layer):
        # 卷积核权重
        weight = conv_layer.weight.data
        # 卷积核个数
        num_kernels = weight.size(0)
        # 随机对部分卷积置0
        pruned = torch.ones(num_kernels, 1, 1, 1)
        # 选择随着置0的卷积序号
        random_int = random.randint(1, num_kernels-1)
        for i in range(random_int):
            pruned[i, 0, 0, 0] = 0
        conv_layer.weight.data = weight * pruned
        weight = conv_layer.weight.data
        bias = conv_layer.bias.data
        # 计算每个卷积核的L2范数,目的是为了检查卷积核的所有位置是不是都置0了
        norms = torch.norm(weight.view(num_kernels, -1), dim=1)
        zero_kernel_indices = torch.nonzero(norms==0).squeeze()
        print(zero_kernel_indices)
        # 移除L2范数为零的卷积核
        new_weight = torch.stack([weight[i, :, :, :] for i in range(num_kernels) if i not in zero_kernel_indices])
        new_bias = torch.stack([bias[i] for i in range(num_kernels) if i not in zero_kernel_indices])
        # 构建新的卷积层
        if zero_kernel_indices.numel() > 0:
            # 输入channel
            in_channels = weight.size(1)
            # 输出channel
            out_channels = new_weight.size(0)
            # 卷积核大小
            kernel_size = weight.size(2)
            # 步长
            stride = conv_layer.stride
            padding = conv_layer.padding
            dilation = conv_layer.dilation
            groups = conv_layer.groups
            new_conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups)
            new_conv_layer.weight.data = new_weight
            new_conv_layer.bias.data = new_bias
        else:
            new_conv_layer = conv_layer
    
        return new_conv_layer
    
    # 示例使用一个具有3个输入通道和5个输出通道的卷积层
    conv = nn.Conv2d(3, 5, 3)
    # print("原始卷积层权重:")
    # print(conv.weight.data)
    # print(conv.weight.size())
    # print("原始卷积层偏置:")
    # print(conv.bias.data)
    # print(conv.bias.size())
    
    # 将置零的卷积核移除
    new_conv = remove_zero_kernels(conv)
    # print("原始卷积层权重:")
    # print(new_conv.weight.data)
    # print(new_conv.weight.size())
    # print("原始卷积层偏置:")
    # print(new_conv.bias.data)
    # print(new_conv.bias.size())
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63

    总结

    博主的思路就是用卷积层中保留的(未被剪枝)权重初始化一个新的卷积层,这样就将假剪枝的置零卷积核真实的除去,有没有研究这方面的读者可以给博主分享其他的方法,共同进步。

  • 相关阅读:
    Flink CDC MySQL同步MySQL错误记录
    紫光同创PG2L100H关键特性评估板,盘古100K开发板,可实现复杂项目的开发
    Springboot面向会员体系的电商平台an5y9计算机毕业设计-课程设计-期末作业-毕设程序代做
    计算机网络:随机访问介质访问控制之CSMA协议
    ts3.接口和对象类型
    【HTML】制作一个简单的三角形动态图形
    定义一个结构体变量(包括年月日)。计算该日在 本年中是第几天?注意闰年问题。
    8、动态SQL
    LQ0016 九进制转十进制【进制】
    快速排序与归并排序的链式实现(golang)
  • 原文地址:https://blog.csdn.net/yangyu0515/article/details/134185435