Model Compression (1): Channel Pruning via BN Layers


Paper: https://arxiv.org/pdf/1708.06519.pdf (Learning Efficient Convolutional Networks through Network Slimming, ICCV 2017)

Each scaling factor γ in a BN layer is associated with one channel of the preceding convolutional layer. During training, sparsity regularization is applied to these scaling factors so that unimportant channels are identified automatically. Channels with small scaling factors are then pruned, yielding a compact model that is subsequently fine-tuned to reach accuracy comparable to (or even higher than) the normally trained full network.

How the BN layer enables pruning:

After normalization, each channel's activations are standardized and then scaled and shifted by the learnable parameters γ and β. When γ and β of a channel are driven toward 0, the BN output of that channel is (after thresholding) 0, so the following convolutional layer receives zero input on that channel and the channel can be removed.
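For reference, this is the standard batch-norm transform applied per channel (the usual definition, not specific to this post):

\hat{x} = \frac{x - \mu_{B}}{\sqrt{\sigma_{B}^{2} + \epsilon}}, \qquad y = \gamma\,\hat{x} + \beta

If a channel's γ (and β) go to 0, its output y is 0 regardless of the input, which is why |γ| serves as a natural channel-importance score.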

Pruning workflow: train with channel-level sparsity regularization on the BN scaling factors → prune the channels with small scaling factors → fine-tune the compact model (the paper iterates this loop for a multi-pass scheme).

Pruning principle:

Add a sparsity term on the BN layers and train the network so that the BN scaling factors become sparse. Collect and sort the absolute values of all BN weights in the sparsity-trained model, and take the value at the desired rank as the threshold thres (i.e. the cut-off that keeps the intended number of channels). Then traverse the BN layers and build a mask per layer: entries with weight > thres are 1, entries with weight <= thres are 0; channels whose mask entry is 0 are pruned.
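A minimal sketch of that thresholding rule (the names bn_masks, prune_percent and the model argument are illustrative; section 3 below walks through the full version):

import torch
import torch.nn as nn

def bn_masks(model, prune_percent=0.7):
    # Gather |gamma| from every BN layer into one flat tensor
    gammas = torch.cat([m.weight.data.abs().clone()
                        for m in model.modules()
                        if isinstance(m, nn.BatchNorm2d)])
    # Threshold at the prune_percent quantile of all scaling factors
    thresh = torch.sort(gammas)[0][int(gammas.numel() * prune_percent)]
    # Per-layer mask: 1 = keep channel, 0 = prune channel
    return [m.weight.data.abs().gt(thresh).float()
            for m in model.modules()
            if isinstance(m, nn.BatchNorm2d)]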

The following walks through a simple network-pruning example.

1. Define a custom network

The network is defined as follows:

import torch
import torch.nn as nn
import numpy as np


class net(nn.Module):
    def __init__(self, cfg=None):
        super(net, self).__init__()
        if cfg:
            self.features = self.make_layer(cfg)
            self.linear = nn.Linear(cfg[2], 2)
        else:
            layers = []
            layers += [nn.Conv2d(3, 64, 7, 2, 1, bias=False),
                       nn.BatchNorm2d(64),
                       nn.ReLU(inplace=True)]
            layers += [nn.Conv2d(64, 128, 3, 2, 1, bias=False),
                       nn.BatchNorm2d(128),
                       nn.ReLU(inplace=True)]
            layers += [nn.Conv2d(128, 256, 3, 2, 1, bias=False),
                       nn.BatchNorm2d(256),
                       nn.ReLU(inplace=True)]
            layers += [nn.AvgPool2d(2)]
            self.features = nn.Sequential(*layers)
            self.linear = nn.Linear(256, 2)

    def make_layer(self, cfg):
        # Same topology as above, but channel widths come from cfg
        layers = []
        layers += [nn.Conv2d(3, cfg[0], 7, 2, 1, bias=False),
                   nn.BatchNorm2d(cfg[0]),
                   nn.ReLU(inplace=True)]
        layers += [nn.Conv2d(cfg[0], cfg[1], 3, 2, 1, bias=False),
                   nn.BatchNorm2d(cfg[1]),
                   nn.ReLU(inplace=True)]
        layers += [nn.Conv2d(cfg[1], cfg[2], 3, 2, 1, bias=False),
                   nn.BatchNorm2d(cfg[2]),
                   nn.ReLU(inplace=True)]
        layers += [nn.AvgPool2d(2)]
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x
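The same class therefore builds either the original or a compact topology. A hypothetical usage snippet (the cfg list here is the one produced by the pruning step in section 3):

model = net()                        # full network: 64/128/256 channels
pruned = net(cfg=[29, 56, 75, 'A'])  # compact network built from a pruning cfg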

Network parameter summary (from torchsummary with input size (3, 20, 20)):

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1              [1, 64, 8, 8]           9,408
       BatchNorm2d-2              [1, 64, 8, 8]             128
              ReLU-3              [1, 64, 8, 8]               0
            Conv2d-4             [1, 128, 4, 4]          73,728
       BatchNorm2d-5             [1, 128, 4, 4]             256
              ReLU-6             [1, 128, 4, 4]               0
            Conv2d-7             [1, 256, 2, 2]         294,912
       BatchNorm2d-8             [1, 256, 2, 2]             512
              ReLU-9             [1, 256, 2, 2]               0
        AvgPool2d-10             [1, 256, 1, 1]               0
           Linear-11                     [1, 2]             514
================================================================
Total params: 379,458
Trainable params: 379,458
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.17
Params size (MB): 1.45
Estimated Total Size (MB): 1.62

2. Sparse training

During training, add an L1 sparsity penalty to the BN layer weights (the scaling factors γ).
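The training objective from the Network Slimming paper adds an L1 term over all scaling factors:

L = \sum_{(x,y)} l\big(f(x, W), y\big) + \lambda \sum_{\gamma \in \Gamma} |\gamma|

The (sub)gradient of the penalty with respect to each γ is λ·sign(γ), which is exactly the term updateBN below adds to the BN weight gradients after loss.backward() (with s playing the role of λ).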

def updateBN(model, s=0.0001):
    # L1 sparsity on BN scaling factors: add s * sign(gamma) to the gradient
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(s * torch.sign(m.weight.data))


if __name__ == "__main__":
    model = net()
    # from torchsummary import summary
    # print(summary(model, (3, 20, 20), 1))
    optimer = torch.optim.Adam(model.parameters())
    loss_fn = torch.nn.CrossEntropyLoss()
    for e in range(100):
        # Random data stands in for a real dataset in this toy example
        x = torch.rand((1, 3, 20, 20))
        y = torch.tensor(np.random.randint(0, 2, (1))).long()
        out = model(x)
        loss = loss_fn(out, y)
        optimer.zero_grad()
        loss.backward()
        # Sparsify BN weights before the optimizer step
        updateBN(model)
        optimer.step()
    torch.save(model.state_dict(), "net.pth")

3. Pruning

Load the sparsity-trained model and prune it.

import net
import torch
import torch.nn as nn
import numpy as np

model = net.net()
# Load the sparsity-trained model
model.load_state_dict(torch.load("net.pth"))

# Count the total number of BN scaling factors
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]   # 64 / 128 / 256 per layer
print("Total number of BN weights:", total)

# Copy |gamma| of every BN layer into one flat tensor
bn_data = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn_data[index:(index + size)] = m.weight.data.abs().clone()
        index += size

# Sort all BN weights and pick the pruning threshold
sorted_weights, _ = torch.sort(bn_data)
percent = 0.7                           # fraction of BN channels to prune (keep ~30%)
thresh_index = int(total * percent)
thresh = sorted_weights[thresh_index]   # value at that rank is the cut-off

# Build per-layer masks
pruned_num = 0   # number of pruned channels
cfg = []         # remaining channel count per layer
cfg_mask = []    # per-layer masks: 0 = pruned channel, 1 = kept channel
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        mask = weight_copy.gt(thresh).float()   # threshold the weights
        pruned_num += mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)   # zero the BN weights of pruned channels
        m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask)))        # channels kept in this layer
        cfg_mask.append(mask.clone())
        print("layer index:{:d}\t total channel:{:d}\t remaining channel:{:d}".format(
            k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.AvgPool2d):
        cfg.append("A")

pruned_ratio = pruned_num / total
print("Pruned channel ratio:", pruned_ratio)
print(cfg)

# Build the compact model and copy the surviving weights into it
newmodel = net.net(cfg)
layer_id_in_cfg = 0
start_mask = torch.ones(3)               # input channels of the first conv
end_mask = cfg_mask[layer_id_in_cfg]     # mask of the first BN layer
for (m0, m1) in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        # Indices of the channels that survived pruning
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.numpy())))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        # Copy the kept parameters from the old model into the new one
        m1.weight.data = m0.weight.data[idx1.tolist()].clone()
        m1.bias.data = m0.bias.data[idx1.tolist()].clone()
        m1.running_mean = m0.running_mean[idx1.tolist()].clone()
        m1.running_var = m0.running_var[idx1.tolist()].clone()
        layer_id_in_cfg += 1             # move on to the next mask
        start_mask = end_mask.clone()
        if layer_id_in_cfg < len(cfg_mask):
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.numpy())))  # kept input channels
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.numpy())))    # kept output channels
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
        w1 = w1[idx1.tolist(), :, :, :].clone()
        m1.weight.data = w1.clone()
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.numpy())))  # kept input features
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()

torch.save(newmodel.state_dict(), "prune_net.pth")
print(newmodel)
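A quick sanity check, not in the original post, is to continue the script above with one forward pass through each model and confirm both produce an output of the same shape:

x = torch.rand((1, 3, 20, 20))
print(model(x).shape, newmodel(x).shape)   # both should be torch.Size([1, 2])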

Structure of the new model:

Total number of BN weights: 448
layer index:3    total channel:64        remaining channel:29
layer index:6    total channel:128       remaining channel:56
layer index:9    total channel:256       remaining channel:75
Pruned channel ratio: tensor(0.6429)
[29, 56, 75, 'A']
net(
  (features): Sequential(
    (0): Conv2d(3, 29, kernel_size=(7, 7), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(29, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(29, 56, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(56, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(56, 75, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(75, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (linear): Linear(in_features=75, out_features=2, bias=True)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1              [1, 29, 8, 8]           4,263
       BatchNorm2d-2              [1, 29, 8, 8]              58
              ReLU-3              [1, 29, 8, 8]               0
            Conv2d-4              [1, 56, 4, 4]          14,616
       BatchNorm2d-5              [1, 56, 4, 4]             112
              ReLU-6              [1, 56, 4, 4]               0
            Conv2d-7              [1, 75, 2, 2]          37,800
       BatchNorm2d-8              [1, 75, 2, 2]             150
              ReLU-9              [1, 75, 2, 2]               0
        AvgPool2d-10              [1, 75, 1, 1]               0
           Linear-11                     [1, 2]             152
================================================================
Total params: 57,151
Trainable params: 57,151
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.07
Params size (MB): 0.22
Estimated Total Size (MB): 0.29
----------------------------------------------------------------

The parameter size shrinks from 1.45 MB to about 0.22 MB (≈230 KB), a compression of roughly 85%.
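The figure follows directly from the two torchsummary tables above:

1 - \frac{57{,}151}{379{,}458} \;\approx\; 1 - \frac{0.22\ \text{MB}}{1.45\ \text{MB}} \;\approx\; 0.85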

4. Fine-tuning

# Continues from the pruning script above (newmodel is in scope)
newmodel.load_state_dict(torch.load("prune_net.pth"))
optimer = torch.optim.Adam(newmodel.parameters())   # optimize the pruned model
loss_fn = torch.nn.CrossEntropyLoss()
for e in range(100):
    x = torch.rand((1, 3, 20, 20))
    y = torch.tensor(np.random.randint(0, 2, (1))).long()
    out = newmodel(x)
    loss = loss_fn(out, y)
    optimer.zero_grad()
    loss.backward()
    optimer.step()
torch.save(newmodel.state_dict(), "prune_net.pth")

The above procedure is intended only as a reference.

References:
GitHub - foolwood/pytorch-slimming: Learning Efficient Convolutional Networks through Network Slimming, ICCV 2017. https://github.com/foolwood/pytorch-slimming

Network Slimming — an effective channel pruning method (Channel Pruning), Law-Yao's blog on CSDN.

Original article: https://blog.csdn.net/m0_37264397/article/details/126157647