• The Network Slimming Pruning Method


    This article references: the Bilibili video "5-剪枝后模型参数赋值" (assigning parameters to the pruned model)

     https://github.com/foolwood/pytorch-slimming

    I. Model Pruning Theory

    Paper: Learning Efficient Convolutional Networks through Network Slimming

    (1) Convolution produces many feature maps (channel = 64, 128, 256, ...). Not all of them are equally important, so we quantify each feature map's importance.

    (2) During training, extra strategies are added so that the weight parameters develop a clear separation in magnitude, which lets us filter out the important feature maps.

    The values in the channel scaling factors are scores for the feature maps: intuitively, feature maps with large scores should be kept, while those with small scores can be removed.

    II. Measuring Feature-Map Importance

    Network Slimming uses the scaling factor γ from the BN (batch normalization) layers.

    The theoretical basis of BN is the per-channel normalization

    $$\hat{x} = \frac{x - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2} + \epsilon}}$$

    which maps the data to a normal distribution with mean 0 and variance 1.

    On the whole this looks like a plain normalization, but BN additionally introduces two trainable parameters: γ and β.
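    As a quick aside (a minimal PyTorch sketch, not from the original post): the γ and β of an nn.BatchNorm2d layer are exposed as its weight and bias parameters, one pair per channel.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)                 # one (gamma, beta) pair per channel
print(bn.weight.shape, bn.bias.shape)  # torch.Size([4]) torch.Size([4])

x = torch.randn(8, 4, 16, 16)
y = bn(x)  # normalize each channel, then scale by gamma and shift by beta
```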

    The essence of BatchNorm:

    (1) BN pulls the increasingly shifted activation distribution back toward the center,

    (2) re-normalizing it to a standard normal distribution with mean 0 and variance 1.

    (3) This keeps the activation function numerically sensitive, so training is faster.

    (4) The resulting problem: after BN, the values are forced into the linear region of the nonlinear function.

    Explanation of point (3):

    For a saturating activation function, the two tails are saturated and insensitive, while the region near 0 is unsaturated and sensitive.

    Explanation of point (4):

    BN forcibly squeezes the data into the middle, near-linear region of the activation (e.g. the central part of the sigmoid, F = sigmoid), where F(x) only performs an affine transformation. The composition of multiple affine transformations is still affine, so adding any number of such hidden layers would be equivalent to a single-layer network.
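    A one-line check of that claim: composing two affine maps

    $$f_1(x) = W_1 x + b_1, \qquad f_2(x) = W_2 x + b_2$$

    gives

    $$f_2(f_1(x)) = W_2 W_1 x + (W_2 b_1 + b_2),$$

    which is again affine, so stacking any number of such layers adds no expressive power over a single layer.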

    So BN must preserve some nonlinearity, which it does by transforming the normalized result once more.

    Two parameters are then added and the network is retrained with:

    $$y = \gamma \hat{x} + \beta$$

    These two parameters are obtained during network training; they are not given as hyperparameters.

    This formula is effectively the inverse transform of BN's normalization: it stretches and reshapes the normalized distribution, partially restoring its original form.

    The larger γ is, the more strongly that feature map is rescaled, which indicates that the feature map is more important.

    III. Making the Importance Polarization More Pronounced

    Use L1 regularization to sparsify the parameters.

    The derivative of the L1 term is sign(θ): the penalty gradient has constant magnitude 1, so weights move toward 0 at a steady pace and many end up exactly 0.

    The derivative of the L2 term is θ: the step shrinks as θ approaches 0, so weights become small and smooth but rarely reach exactly 0.
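    Written out as update rules on the penalty term alone (η is the learning rate, λ the regularization strength):

    $$\text{L1:}\ \ \theta \leftarrow \theta - \eta\lambda\,\operatorname{sign}(\theta) \qquad\qquad \text{L2:}\ \ \theta \leftarrow \theta - \eta\lambda\,\theta$$

    The L1 step has constant size regardless of |θ|, driving small weights all the way to 0; the L2 step shrinks with θ, so weights approach 0 without reaching it.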

    The core of the paper: add an L1 penalty on all BN scaling factors to the usual training loss,

    $$L = \sum_{(x,\,y)} l\big(f(x, W),\, y\big) + \lambda \sum_{\gamma \in \Gamma} |\gamma|$$

    where the first term is the normal task loss and Γ is the set of all BN scaling factors.

    IV. Pruning Pipeline

    The VGG model architecture used:

```python
import math

import torch
import torch.nn as nn
from torch.autograd import Variable


class vgg(nn.Module):
    def __init__(self, dataset='cifar10', init_weights=True, cfg=None):
        super(vgg, self).__init__()
        # cfg lists the conv output channels; 'M' marks a 2x2 max-pool.
        if cfg is None:
            cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
                   512, 512, 512, 512, 'M', 512, 512, 512, 512]
        self.feature = self.make_layers(cfg, True)
        if dataset == 'cifar10':
            num_classes = 10
        elif dataset == 'cifar100':
            num_classes = 100
        self.classifier = nn.Linear(cfg[-1], num_classes)
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.feature(x)
        x = nn.AvgPool2d(2)(x)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return y

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # BN scaling factors start at 0.5, as in the reference repo.
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


if __name__ == '__main__':
    net = vgg()
    x = Variable(torch.FloatTensor(16, 3, 40, 40))
    y = net(x)
    print(y.data.shape)
```

    1. Training the original model:

    (1) L1 sparsity regularization on BN: subgradient descent is used to adjust the gradients of the BN-layer weights (the updateBN function below).

    (2) After training, what matters is saving the original model's parameters.

```python
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from tqdm import tqdm

from vgg import vgg

learning_rate = 0.1
momentum = 0.9
weight_decay = 1e-4
epochs = 3
log_interval = 100
batch_size = 100
sparsity_regularization = True
scale_sparse_rate = 0.0001
checkpoint_model_path = 'checkpoint.pth.tar'
best_model_path = 'model_best.pth.tar'

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('D:\\ai_data\\cifar10', train=True, download=True,
                     transform=transforms.Compose([
                         transforms.Pad(4),
                         transforms.RandomCrop(32),
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                     ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('D:\\ai_data\\cifar10', train=False,
                     transform=transforms.Compose([
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                     ])),
    batch_size=batch_size, shuffle=True)

model = vgg()
model.cuda()
optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                      momentum=momentum, weight_decay=weight_decay)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        if sparsity_regularization:
            updateBN()  # add the L1 subgradient before the optimizer step
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in tqdm(test_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        test_loss += F.cross_entropy(output, target, reduction='sum').item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))


def save_checkpoint(state, is_best, filename=checkpoint_model_path):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_model_path)


def updateBN():
    # L1 sparsity on the BN scaling factors via subgradient descent:
    # add scale_sparse_rate * sign(gamma) to each BN weight gradient.
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(scale_sparse_rate * torch.sign(m.weight.data))


best_prec = 0
for epoch in range(epochs):
    train(epoch)
    prec = test()
    is_best = prec > best_prec
    best_prec = max(prec, best_prec)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict()
    }, is_best)
```

    2. Model pruning

    (1) Pruning consists of two parts: first compute the masks, then adjust each layer's shape according to the masks.

    (2) BN channel count: the layer order is Conv → BN → ReLU → MaxPool → Linear, so a BN layer's input dimension equals the output channel count of the preceding Conv.

    (3) Total BN channel count: sum the channel counts over all BN layers.

    (4) Pruning percentile: take the given percentile of all scaling factors to get a concrete float threshold; channels whose value exceeds it get mask 1, the others get mask 0 (a small numeric sketch follows this list).

    (5) Changing the weights: each BN layer keeps only the channels whose mask is 1. This changes the BN shape, so the upstream and downstream Conv and Linear layers must be adjusted accordingly; MaxPool and ReLU are unaffected by channel count.

    (6) A Conv layer's weight has shape [out_channels, in_channels, kernel_size1, kernel_size2], so it must be adjusted twice: first along in_channels, then along out_channels (a small indexing sketch follows the shape list below). The first Conv's input is the 3 RGB channels.

    Suppose the computed list of retained channel counts is:

    [48, 60, 115, 118, 175, 163, 141, 130, 259, 267, 258, 249, 225, 212, 234, 97]

    Then the Conv input/output channels become:

    In shape: 3 Out shape:48

    In shape: 48 Out shape:60

    In shape: 60 Out shape:115

    In shape: 115 Out shape:118

    ……

    In shape: 234 Out shape:97
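    As mentioned in step (6), here is a minimal sketch of that two-step slice on a dummy Conv weight (the tensors and index values are made up for illustration):

```python
import torch

w = torch.randn(64, 32, 3, 3)          # [out_channels, in_channels, k, k]
in_keep = torch.tensor([0, 3, 7])      # channels kept in the previous layer
out_keep = torch.tensor([1, 2, 5, 8])  # channels kept in this layer

w = w[:, in_keep, :, :]                # first shrink the input side
w = w[out_keep, :, :, :]               # then shrink the output side
print(w.shape)                         # torch.Size([4, 3, 3, 3])
```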

    (7) When saving the model, both the useful parameter values and the pruned model's structural configuration (cfg) are saved, making it easy to rebuild the new model structure for later retraining.
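    And, as referenced in step (4), a small numeric sketch of the percentile threshold (toy γ values, pruning 50%):

```python
import torch

gammas = torch.tensor([0.9, 0.05, 0.7, 0.2, 0.8, 0.1]).abs()
y, _ = torch.sort(gammas)          # [0.05, 0.10, 0.20, 0.70, 0.80, 0.90]
thre = y[int(len(gammas) * 0.5)]   # 50th-percentile value: 0.70
mask = gammas.gt(thre).float()     # tensor([1., 0., 0., 0., 1., 0.])
print(int(mask.sum()), 'of', len(gammas), 'channels kept')
```

    Note that the strict greater-than comparison drops the threshold value itself, so slightly more than the nominal percentage of channels can be pruned.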

```python
import os

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from tqdm import tqdm

from vgg import vgg

percent = 0.5
batch_size = 100
raw_model_path = 'model_best.pth.tar'
save_model_path = 'prune_model.pth.tar'

model = vgg()
model.cuda()
if os.path.isfile(raw_model_path):
    print("==> loading checkpoint '{}'".format(raw_model_path))
    checkpoint = torch.load(raw_model_path)
    start_epoch = checkpoint['epoch']
    best_prec = checkpoint['best_prec']
    model.load_state_dict(checkpoint['state_dict'])
    print("==> loaded checkpoint '{}' (epoch {}) Prec: {:f}".format(
        raw_model_path, start_epoch, best_prec))
print(model)

# Collect the |gamma| of every BN layer into one vector.
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]
bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:index + size] = m.weight.data.abs().clone()
        index += size

# The pruning threshold is the `percent` percentile of all |gamma|.
y, i = torch.sort(bn)
thre_index = int(total * percent)
thre = y[thre_index]

pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.clone()
        mask = weight_copy.abs().gt(thre).float().cuda()
        pruned += mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(
            k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')
pruned_ratio = pruned / total
print('pruned_ratio: {}, Pre-processing Successful!'.format(pruned_ratio))


# Simple test after the pre-processing prune (BN scales simply set to zero).
def test():
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('D:\\ai_data\\cifar10', train=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                         ])),
        batch_size=batch_size, shuffle=True)
    model.eval()
    correct = 0
    for data, target in tqdm(test_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))


test()

# Make the real prune: copy the surviving weights into a smaller model.
print(cfg)
new_model = vgg(cfg=cfg)
new_model.cuda()
layer_id_in_cfg = 0         # index into cfg_mask
start_mask = torch.ones(3)  # the first Conv takes the 3 RGB channels
end_mask = cfg_mask[layer_id_in_cfg]
for [m0, m1] in zip(model.modules(), new_model.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        m1.weight.data = m0.weight.data[idx1].clone()
        m1.bias.data = m0.bias.data[idx1].clone()
        m1.running_mean = m0.running_mean[idx1].clone()
        m1.running_var = m0.running_var[idx1].clone()
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()
        if layer_id_in_cfg < len(cfg_mask):  # do not change in the final FC
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print('In shape: {:d} Out shape:{:d}'.format(idx0.shape[0], idx1.shape[0]))
        w = m0.weight.data[:, idx0, :, :].clone()  # shrink in_channels first
        w = w[idx1, :, :, :].clone()               # then shrink out_channels
        m1.weight.data = w.clone()
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        m1.weight.data = m0.weight.data[:, idx0].clone()

torch.save({'cfg': cfg, 'state_dict': new_model.state_dict()}, save_model_path)
print(new_model)
model = new_model
test()
```

    3. Retraining

    The parameters saved after pruning are effectively a mid-training checkpoint. Using the new model structure, training resumes from this checkpoint until the metrics are satisfactory.

```python
import shutil

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from tqdm import tqdm

from vgg import vgg

learning_rate = 0.1
momentum = 0.9
weight_decay = 1e-4
epochs = 3
log_interval = 100
batch_size = 100
prune_model_path = 'prune_model.pth.tar'
prune_checkpoint_path = 'pruned_checkpoint.pth.tar'
prune_best_model_path = 'pruned_model_best.pth.tar'

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('D:\\ai_data\\cifar10', train=True, download=True,
                     transform=transforms.Compose([
                         transforms.Pad(4),
                         transforms.RandomCrop(32),
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                     ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('D:\\ai_data\\cifar10', train=False,
                     transform=transforms.Compose([
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                     ])),
    batch_size=batch_size, shuffle=True)

# Rebuild the pruned architecture from the saved cfg, then load its weights.
checkpoint = torch.load(prune_model_path)
model = vgg(cfg=checkpoint['cfg'])
model.cuda()
model.load_state_dict(checkpoint['state_dict'])
optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                      momentum=momentum, weight_decay=weight_decay)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in tqdm(test_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        test_loss += F.cross_entropy(output, target, reduction='sum').item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))


def save_checkpoint(state, is_best, filename=prune_checkpoint_path):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, prune_best_model_path)


best_prec = 0
for epoch in range(epochs):
    train(epoch)
    prec = test()
    is_best = prec > best_prec
    best_prec = max(prec, best_prec)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict()
    }, is_best)
```

    4. Comparing the original and pruned models:

    Each was trained on CIFAR-10 with the VGG model for 3 epochs.

    The original model is 156 MB, with accuracy around 70%.

    The pruned model is 36 MB, with accuracy around 76%.

    Note: it is best to prune only after the original model has reached its peak; only then is the before/after accuracy comparison meaningful.

     

  • Original article: https://blog.csdn.net/benben044/article/details/127886261