• Basic model pruning methods


    Source: https://www.bilibili.com/video/BV147411W7am/?spm_id_from=333.788.recommend_more_video.2&vd_source=3969f30b089463e19db0cc5e8fe4583a

    1. What pruning means

    Pruning removes unimportant parameters so that computation gets faster and the model gets smaller (the mask-based pruning covered in this article does not actually shrink the stored model, since the zeroed weights are still kept).

    2. Pruning fully connected layers

    Pruning a fully connected layer means setting some of its weights to 0 through a mask, so that fewer effective parameters take part in the computation.

    The process of computing the mask matrix:
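    The original post illustrates this step with a figure that is not reproduced here. As a stand-in, here is a minimal sketch with made-up numbers (the matrix values are purely illustrative):

    import torch

    # hypothetical 2x3 weight matrix (values chosen only for illustration)
    w = torch.tensor([[0.8, -0.05, 0.3],
                      [-0.6, 0.02, -0.1]])

    # threshold = 50th percentile of the absolute values
    threshold = w.abs().quantile(0.5)

    # mask is 1 where |w| is above the threshold, 0 elsewhere
    mask = (w.abs() > threshold).float()
    print(mask)       # tensor([[1., 0., 1.], [1., 0., 0.]])
    print(w * mask)   # the pruned weight matrix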

    What the code then needs to do:

    (1) add an attribute to each layer that stores its mask;

    (2) write a function that computes the masks.

    3. Pruning convolutional layers

    Suppose a layer has 4 convolution kernels (filters). Compute the L2 norm of each filter, and set the entire mask of the filter with the smallest norm to 0 (the gray filter in the original figure).
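    The figure is likewise not reproduced here. A minimal sketch of the same idea, assuming a made-up 4-filter conv weight of shape (4, 3, 3, 3):

    import torch

    w = torch.randn(4, 3, 3, 3)                      # 4 filters, each 3x3x3

    norms = w.flatten(start_dim=1).norm(p=2, dim=1)  # L2 norm of each filter, shape (4,)

    mask = torch.ones_like(w)
    mask[norms.argmin()] = 0.0                       # zero out the entire weakest filter (the "gray" one)

    pruned_w = w * mask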

    4. Code

    GitHub - mepeichun/Efficient-Neural-Network-Bilibili: companion code for the Efficient-Neural-Network tutorial series on Bilibili

    5. Pruning fully connected layers

    (1) Pruning approach

    Assume a pruning ratio of 50%.

    For each linear layer, take the 50th percentile of the absolute values of its weights as the threshold, then build a mask: positions whose absolute weight is above the threshold are set to 1, and positions at or below the threshold are set to 0.

    Finally, weight * mask gives the new weight.
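    The code below builds these masks by hand. Purely as a side note (not used in this article's code), PyTorch also ships a built-in utility that performs the same magnitude-based unstructured pruning; a minimal sketch:

    import torch.nn as nn
    import torch.nn.utils.prune as prune

    layer = nn.Linear(200, 200)

    # zero the 50% of weights with the smallest absolute value;
    # the layer gains weight_orig / weight_mask, and layer.weight becomes their product
    prune.l1_unstructured(layer, name="weight", amount=0.5)

    # fold the mask into the weight tensor permanently
    prune.remove(layer, "weight")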

    (2) Pruning code

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torchvision import datasets, transforms
    import torch.utils.data
    import numpy as np
    import math
    from copy import deepcopy


    def to_var(x, requires_grad=False):
        if torch.cuda.is_available():
            x = x.cuda()
        return x.clone().detach().requires_grad_(requires_grad)


    class MaskedLinear(nn.Linear):
        def __init__(self, in_features, out_features, bias=True):
            super(MaskedLinear, self).__init__(in_features, out_features, bias)
            self.mask_flag = False
            self.mask = None

        def set_mask(self, mask):
            # apply the mask to the weights immediately
            self.mask = to_var(mask, requires_grad=False)
            self.weight.data = self.weight.data * self.mask.data
            self.mask_flag = True

        def get_mask(self):
            print(self.mask_flag)
            return self.mask

        def forward(self, x):
            # the commented-out masked forward below duplicates the
            # weight * mask already applied in set_mask
            # if self.mask_flag:
            #     weight = self.weight * self.mask
            #     return F.linear(x, weight, self.bias)
            # else:
            #     return F.linear(x, self.weight, self.bias)
            return F.linear(x, self.weight, self.bias)


    class MLP(nn.Module):
        def __init__(self):
            super(MLP, self).__init__()
            self.linear1 = MaskedLinear(28*28, 200)
            self.relu1 = nn.ReLU(inplace=True)
            self.linear2 = MaskedLinear(200, 200)
            self.relu2 = nn.ReLU(inplace=True)
            self.linear3 = MaskedLinear(200, 10)

        def forward(self, x):
            out = x.view(x.size(0), -1)
            out = self.relu1(self.linear1(out))
            out = self.relu2(self.linear2(out))
            out = self.linear3(out)
            return out

        def set_masks(self, masks):
            self.linear1.set_mask(masks[0])
            self.linear2.set_mask(masks[1])
            self.linear3.set_mask(masks[2])


    def train(model, device, train_loader, optimizer, epoch):
        model.train()
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            total += len(data)
            progress = math.ceil(batch_idx / len(train_loader) * 50)
            print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
                  (epoch, total, len(train_loader.dataset),
                   '-' * progress + '>', progress * 2), end='')


    def test(model, device, test_loader):
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += F.cross_entropy(output, target, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))
        return test_loss, correct / len(test_loader.dataset)


    def weight_prune(model, pruning_perc):
        # one percentile threshold per weight matrix
        threshold_list = []
        for p in model.parameters():
            if len(p.data.size()) != 1:  # skip bias vectors (1-D parameters)
                weight = p.cpu().data.abs().numpy().flatten()
                threshold = np.percentile(weight, pruning_perc)
                threshold_list.append(threshold)
        # generate masks
        masks = []
        idx = 0
        for p in model.parameters():
            if len(p.data.size()) != 1:
                pruned_inds = p.data.abs() > threshold_list[idx]
                masks.append(pruned_inds.float())
                idx += 1
        return masks


    def main():
        epochs = 2
        batch_size = 64
        torch.manual_seed(0)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('D:/ai_data/mnist_dataset', train=True, download=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('D:/ai_data/mnist_dataset', train=False, download=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=1000, shuffle=True)
        model = MLP().to(device)
        optimizer = torch.optim.Adadelta(model.parameters())
        for epoch in range(1, epochs + 1):
            train(model, device, train_loader, optimizer, epoch)
            _, acc = test(model, device, test_loader)
        print("\n=====Pruning 60%=======\n")
        pruned_model = deepcopy(model)
        mask = weight_prune(pruned_model, 60)
        pruned_model.set_masks(mask)
        test(pruned_model, device, test_loader)
        return model, pruned_model


    model, pruned_model = main()
    torch.save(model.state_dict(), ".model.pth")
    torch.save(pruned_model.state_dict(), ".pruned_model.pth")

    from matplotlib import pyplot as plt


    def plot_weights(model):
        # plot a histogram of the nonzero weights of every layer that has weights
        modules = [module for module in model.modules()]
        num_sub_plot = 0
        for i, layer in enumerate(modules):
            if hasattr(layer, 'weight'):
                plt.subplot(131 + num_sub_plot)
                w = layer.weight.data
                w_one_dim = w.cpu().numpy().flatten()
                plt.hist(w_one_dim[w_one_dim != 0], bins=50)
                num_sub_plot += 1
        plt.show()


    model = MLP()
    pruned_model = MLP()
    model.load_state_dict(torch.load('.model.pth'))
    pruned_model.load_state_dict(torch.load('.pruned_model.pth'))
    plot_weights(model)
    plot_weights(pruned_model)

    (3) Accuracy before and after pruning

    Train epoch 1: 60000/60000, [-------------------------------------------------->] 100%

    Test: average loss: 0.1391, accuracy: 9562/10000 (96%)

    Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%

    Test: average loss: 0.0870, accuracy: 9741/10000 (97%)

    =====Pruning 60%=======

    Test: average loss: 0.0977, accuracy: 9719/10000 (97%)

    These numbers show that accuracy barely drops after pruning.

    (4) Weight distribution before and after pruning

    Distribution before pruning: (weight histogram figure in the original post)

    Distribution after pruning: (weight histogram figure in the original post)

    6. Pruning convolutional layers

    (1) Pruning approach

    Assume a pruning ratio of 50%.

    • For each convolutional layer, compute a score for every filter: the sum of the squared weights of the filter (its squared L2 norm), scaled by the number of weights in the filter.
    • Normalize these scores within the layer, then take the smallest nonzero score and the corresponding filter (channel) index.
    • Compare the per-layer minima across all conv layers, and set the mask of the filter with the globally smallest score to 0.
    • Compute the fraction of parameters that are now zero, and repeat the three steps above until this fraction reaches the target pruning ratio.

    Multiplying each layer's weight by its mask then gives the new weight.
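    Below is a minimal sketch of a single pruning step for one conv layer, mirroring the scoring used by prune_one_filter in the listing that follows (the weight shape here is made up; the real code loops over all conv layers and repeats until the target ratio is reached):

    import numpy as np

    w = np.random.randn(8, 4, 3, 3).astype('float32')   # hypothetical conv weight: 8 filters of shape 4x3x3
    mask = np.ones_like(w)

    # scaled squared L2 norm per filter: sum of squares / number of weights in the filter
    scores = np.square(w).sum(axis=(1, 2, 3)) / (w.shape[1] * w.shape[2] * w.shape[3])

    # normalize so that scores are comparable across layers
    scores = scores / np.sqrt(np.square(scores).sum())

    # zero out the entire filter with the smallest score
    mask[np.argmin(scores)] = 0.0
    w = w * mask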

    (2) Pruning code

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torchvision import datasets, transforms
    import torch.utils.data
    import numpy as np
    import math


    def to_var(x, requires_grad=False):
        if torch.cuda.is_available():
            x = x.cuda()
        return x.clone().detach().requires_grad_(requires_grad)


    class MaskedConv2d(nn.Conv2d):
        def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                     padding=0, dilation=1, groups=1, bias=True):
            super(MaskedConv2d, self).__init__(in_channels, out_channels, kernel_size,
                                               stride, padding, dilation, groups, bias)
            self.mask_flag = False

        def set_mask(self, mask):
            # apply the mask to the weights immediately
            self.mask = to_var(mask, requires_grad=False)
            self.weight.data = self.weight.data * self.mask.data
            self.mask_flag = True

        def get_mask(self):
            print(self.mask_flag)
            return self.mask

        def forward(self, x):
            # the commented-out branch below duplicates the
            # weight * mask already applied in set_mask
            # if self.mask_flag:
            #     weight = self.weight * self.mask
            #     return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
            # else:
            #     return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
            return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


    class ConvNet(nn.Module):
        def __init__(self):
            super(ConvNet, self).__init__()
            self.conv1 = MaskedConv2d(1, 32, kernel_size=3, padding=1, stride=1)
            self.relu1 = nn.ReLU(inplace=True)
            self.maxpool1 = nn.MaxPool2d(2)
            self.conv2 = MaskedConv2d(32, 64, kernel_size=3, padding=1, stride=1)
            self.relu2 = nn.ReLU(inplace=True)
            self.maxpool2 = nn.MaxPool2d(2)
            self.conv3 = MaskedConv2d(64, 64, kernel_size=3, padding=1, stride=1)
            self.relu3 = nn.ReLU(inplace=True)
            self.linear1 = nn.Linear(7*7*64, 10)

        def forward(self, x):
            out = self.maxpool1(self.relu1(self.conv1(x)))
            out = self.maxpool2(self.relu2(self.conv2(out)))
            out = self.relu3(self.conv3(out))
            out = out.view(out.size(0), -1)
            out = self.linear1(out)
            return out

        def set_masks(self, masks):
            self.conv1.set_mask(torch.from_numpy(masks[0]))
            self.conv2.set_mask(torch.from_numpy(masks[1]))
            self.conv3.set_mask(torch.from_numpy(masks[2]))


    def train(model, device, train_loader, optimizer, epoch):
        model.train()
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            total += len(data)
            progress = math.ceil(batch_idx / len(train_loader) * 50)
            print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
                  (epoch, total, len(train_loader.dataset),
                   '-' * progress + '>', progress * 2), end='')


    def test(model, device, test_loader):
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += F.cross_entropy(output, target, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))
        return test_loss, correct / len(test_loader.dataset)


    def prune_rate(model, verbose=False):
        """
        Compute the fraction of pruned (zero-valued) parameters in the model.
        :param model:
        :param verbose:
        :return:
        """
        total_nb_param = 0
        nb_zero_param = 0
        layer_id = 0
        for parameter in model.parameters():
            param_this_layer = 1
            for dim in parameter.data.size():
                param_this_layer *= dim
            total_nb_param += param_this_layer
            # only pruning linear and conv layers
            if len(parameter.data.size()) != 1:
                layer_id += 1
                zero_param_this_layer = np.count_nonzero(parameter.cpu().data.numpy() == 0)
                nb_zero_param += zero_param_this_layer
                if verbose:
                    print("Layer {} | {} layer | {:.2f}% parameters pruned".format(
                        layer_id,
                        'Conv' if len(parameter.data.size()) == 4 else 'Linear',
                        100. * zero_param_this_layer / param_this_layer,
                    ))
        pruning_perc = 100. * nb_zero_param / total_nb_param
        if verbose:
            print("Final pruning rate: {:.2f}%".format(pruning_perc))
        return pruning_perc


    def arg_nonzero_min(a):
        """
        Return the smallest nonzero value in a list and its index.
        :param a:
        :return:
        """
        if not a:
            return
        min_ix, min_v = None, None
        # check whether all values are zero
        for i, e in enumerate(a):
            if e != 0:
                min_ix = i
                min_v = e
                break
        if min_ix is None:
            print('Warning: all zero')
            return np.inf, np.inf
        # search for the smallest nonzero value
        for i, e in enumerate(a):
            if e < min_v and e != 0:
                min_v = e
                min_ix = i
        return min_v, min_ix


    def prune_one_filter(model, masks):
        """
        Prune the single least important filter, measured by the scaled L2 norm of its kernel weights.
        :param model:
        :param masks:
        :return:
        """
        NO_MASKS = False
        # construct masks if there are none yet
        if not masks:
            masks = []
            NO_MASKS = True
        values = []
        for p in model.parameters():
            if len(p.data.size()) == 4:
                p_np = p.data.cpu().numpy()
                # construct masks if there are none
                if NO_MASKS:
                    masks.append(np.ones(p_np.shape).astype('float32'))
                # scaled l2 norm for each filter of this layer
                value_this_layer = np.square(p_np).sum(axis=1).sum(axis=1).sum(axis=1) / \
                    (p_np.shape[1] * p_np.shape[2] * p_np.shape[3])
                # normalization (important)
                value_this_layer = value_this_layer / np.sqrt(np.square(value_this_layer).sum())
                min_value, min_ind = arg_nonzero_min(list(value_this_layer))
                values.append([min_value, min_ind])
        assert len(masks) == len(values), "something wrong here"
        values = np.array(values)  # one [min_value, min_ind] row per conv layer
        # zero the mask of the filter to prune
        to_prune_layer_ind = np.argmin(values[:, 0])
        to_prune_filter_ind = int(values[to_prune_layer_ind, 1])
        masks[to_prune_layer_ind][to_prune_filter_ind] = 0.
        return masks


    def filter_prune(model, pruning_perc):
        """
        Main pruning loop: keep pruning one filter at a time until the target ratio is reached.
        :param model:
        :param pruning_perc:
        :return:
        """
        masks = []
        current_pruning_perc = 0
        while current_pruning_perc < pruning_perc:
            masks = prune_one_filter(model, masks)
            model.set_masks(masks)
            current_pruning_perc = prune_rate(model, verbose=False)
            print('{:.2f} pruned'.format(current_pruning_perc))
        return masks


    def main():
        epochs = 2
        batch_size = 64
        torch.manual_seed(0)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('D:/ai_data/mnist_dataset', train=True, download=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('D:/ai_data/mnist_dataset', train=False, download=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307,), (0.3081,))
                           ])),
            batch_size=1000, shuffle=True)
        model = ConvNet().to(device)
        optimizer = torch.optim.Adadelta(model.parameters())
        for epoch in range(1, epochs + 1):
            train(model, device, train_loader, optimizer, epoch)
            _, acc = test(model, device, test_loader)
        print('\npruning 50%')
        mask = filter_prune(model, 50)
        model.set_masks(mask)
        _, acc = test(model, device, test_loader)
        # finetune
        print('\nfinetune')
        train(model, device, train_loader, optimizer, epoch)
        _, acc = test(model, device, test_loader)


    main()

    (3) Accuracy and pruning-rate log:

    Train epoch 1: 60000/60000, [-------------------------------------------------->] 100%
    Test: average loss: 0.0505, accuracy: 9833/10000 (98%)
    Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%
    Test: average loss: 0.0311, accuracy: 9893/10000 (99%)
    pruning 50%
    0.66 pruned
    1.32 pruned
    1.65 pruned
    1.98 pruned
    2.31 pruned
    2.64 pruned
    2.98 pruned
    3.64 pruned
    3.97 pruned
    4.63 pruned
    4.64 pruned
    4.65 pruned
    4.98 pruned
    5.31 pruned
    5.32 pruned
    5.65 pruned
    6.31 pruned
    6.97 pruned
    7.30 pruned
    7.63 pruned
    8.30 pruned
    8.31 pruned
    8.97 pruned
    9.30 pruned
    9.96 pruned
    10.29 pruned
    10.95 pruned
    11.61 pruned
    11.94 pruned
    12.60 pruned
    13.27 pruned
    13.93 pruned
    14.26 pruned
    14.92 pruned
    15.25 pruned
    15.26 pruned
    15.59 pruned
    16.25 pruned
    16.91 pruned
    17.57 pruned
    17.90 pruned
    18.23 pruned
    18.90 pruned
    19.56 pruned
    19.89 pruned
    20.55 pruned
    20.88 pruned
    21.54 pruned
    21.87 pruned
    21.88 pruned
    22.54 pruned
    22.87 pruned
    23.53 pruned
    24.20 pruned
    24.21 pruned
    24.87 pruned
    25.20 pruned
    25.86 pruned
    26.19 pruned
    26.20 pruned
    26.86 pruned
    27.19 pruned
    27.52 pruned
    28.18 pruned
    28.51 pruned
    29.18 pruned
    29.51 pruned
    29.52 pruned
    29.85 pruned
    29.86 pruned
    30.52 pruned
    30.85 pruned
    31.51 pruned
    32.17 pruned
    32.83 pruned
    33.16 pruned
    33.82 pruned
    34.16 pruned
    34.82 pruned
    35.15 pruned
    35.48 pruned
    36.14 pruned
    36.47 pruned
    37.13 pruned
    37.79 pruned
    37.80 pruned
    38.13 pruned
    38.79 pruned
    38.80 pruned
    39.13 pruned
    39.15 pruned
    39.81 pruned
    40.14 pruned
    40.47 pruned
    40.48 pruned
    41.14 pruned
    41.47 pruned
    41.80 pruned
    41.81 pruned
    42.47 pruned
    43.13 pruned
    43.46 pruned
    43.79 pruned
    44.46 pruned
    44.79 pruned
    44.80 pruned
    45.46 pruned
    45.79 pruned
    45.80 pruned
    46.46 pruned
    46.79 pruned
    47.12 pruned
    47.78 pruned
    47.79 pruned
    47.80 pruned
    48.13 pruned
    48.79 pruned
    49.13 pruned
    49.79 pruned
    49.80 pruned
    50.46 pruned
    Test: average loss: 1.6824, accuracy: 6513/10000 (65%)
    finetune
    Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%
    Test: average loss: 0.0324, accuracy: 9889/10000 (99%)

    As the log shows, test accuracy right after pruning is only 65%, which is very low; after one more epoch of fine-tuning the remaining weights, accuracy immediately returns to nearly its pre-pruning level.

  • Original article: https://blog.csdn.net/benben044/article/details/127848606