Source: https://www.bilibili.com/video/BV147411W7am/?spm_id_from=333.788.recommend_more_video.2&vd_source=3969f30b089463e19db0cc5e8fe4583a
1. What pruning means
Remove the unimportant parameters so that computation becomes faster and the model becomes smaller (note: the mask-based pruning covered in this article does not actually shrink the model, since the zeroed weights are still stored and multiplied).
2. Pruning a fully connected layer
Pruning here means setting some of the weights to 0, so that in principle those multiplications can be skipped and computation becomes faster.
The process of computing the mask matrix:
What needs to be done next (a minimal sketch follows the list):
(1) add a variable to each layer that stores its mask
(2) design a function that computes the mask
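As a rough illustration of steps (1) and (2) (the tensor below is made up for this sketch and is not part of the code in section 5): the mask is simply a 0/1 tensor with the same shape as the weight, and applying it is an element-wise multiplication.
- import torch
-
- # illustrative 3x3 weight matrix of a tiny linear layer
- weight = torch.tensor([[0.50, -0.02, 1.20],
-                        [0.01, -0.90, 0.03],
-                        [0.70, 0.04, -1.10]])
-
- # "compute the mask": keep only the large-magnitude (important) weights
- mask = (weight.abs() > 0.1).float()
-
- # "store the mask and apply it": the small weights become exactly 0
- pruned_weight = weight * mask
- print(pruned_weight)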
3. Pruning a convolutional layer
Suppose a layer has 4 convolution kernels (filters). Compute the L2 norm of each filter; the filter with the smallest norm has its entire mask set to 0 (the gray part in the figure above). A small sketch of this rule follows.
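A minimal sketch of this selection rule (the random tensor and shapes below are purely illustrative):
- import torch
-
- # illustrative layer with 4 filters of shape (in_channels=2, 3, 3)
- kernels = torch.randn(4, 2, 3, 3)
- mask = torch.ones_like(kernels)
-
- # one L2 norm per filter
- norms = kernels.pow(2).sum(dim=(1, 2, 3)).sqrt()
-
- # the filter with the smallest norm has its whole mask zeroed
- weakest = torch.argmin(norms).item()
- mask[weakest] = 0.
- pruned_kernels = kernels * mask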
4. Code
GitHub - mepeichun/Efficient-Neural-Network-Bilibili: companion code for the Efficient-Neural-Network tutorial series on Bilibili (https://github.com/mepeichun/Efficient-Neural-Network-Bilibili)
5. Fully connected layer pruning
(1) Pruning approach
Assume a pruning ratio of 50%.
For every linear layer, take the 50th percentile of the absolute values of its weights as a threshold, then build the mask: positions whose absolute value is greater than the threshold are set to 1, positions less than or equal to it are set to 0 (see the sketch below).
Finally, weight * mask gives the new, pruned weight.
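A minimal sketch of this thresholding step for a single layer (the weight tensor below is illustrative; the full code in (2) applies the same idea to every MaskedLinear layer):
- import numpy as np
- import torch
-
- pruning_perc = 50                      # prune 50% of the weights
- weight = torch.randn(200, 784)         # illustrative linear-layer weight
-
- # threshold = 50th percentile of the absolute weight values
- threshold = np.percentile(weight.abs().numpy().flatten(), pruning_perc)
-
- # mask is 1 where |w| > threshold, 0 otherwise
- mask = (weight.abs() > threshold).float()
- new_weight = weight * mask             # roughly half of the entries are now 0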
(2) Pruning code
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torchvision import datasets, transforms
- import torch.utils.data
- import numpy as np
- import math
- from copy import deepcopy
-
- def to_var(x, requires_grad=False):
-     # move the tensor to GPU if available and detach it from the autograd graph
-     if torch.cuda.is_available():
-         x = x.cuda()
-     return x.clone().detach().requires_grad_(requires_grad)
-
- class MaskedLinear(nn.Linear):
-     def __init__(self, in_features, out_features, bias=True):
-         super(MaskedLinear, self).__init__(in_features, out_features, bias)
-         self.mask_flag = False
-         self.mask = None
-
-     def set_mask(self, mask):
-         # store the mask and zero out the pruned weights in place
-         self.mask = to_var(mask, requires_grad=False)
-         self.weight.data = self.weight.data * self.mask.data
-         self.mask_flag = True
-
-     def get_mask(self):
-         print(self.mask_flag)
-         return self.mask
-
-     def forward(self, x):
-         # The branch below is redundant with set_mask, which already multiplies
-         # self.weight by the mask in place (applying the mask here instead would
-         # also keep pruned weights at zero if the model is trained again later).
-         # if self.mask_flag:
-         #     weight = self.weight * self.mask
-         #     return F.linear(x, weight, self.bias)
-         # else:
-         #     return F.linear(x, self.weight, self.bias)
-         return F.linear(x, self.weight, self.bias)
-
- class MLP(nn.Module):
-     def __init__(self):
-         super(MLP, self).__init__()
-         self.linear1 = MaskedLinear(28*28, 200)
-         self.relu1 = nn.ReLU(inplace=True)
-         self.linear2 = MaskedLinear(200, 200)
-         self.relu2 = nn.ReLU(inplace=True)
-         self.linear3 = MaskedLinear(200, 10)
-
-     def forward(self, x):
-         out = x.view(x.size(0), -1)
-         out = self.relu1(self.linear1(out))
-         out = self.relu2(self.linear2(out))
-         out = self.linear3(out)
-         return out
-
-     def set_masks(self, masks):
-         self.linear1.set_mask(masks[0])
-         self.linear2.set_mask(masks[1])
-         self.linear3.set_mask(masks[2])
-
- def train(model, device, train_loader, optimizer, epoch):
-     model.train()
-     total = 0
-     for batch_idx, (data, target) in enumerate(train_loader):
-         data, target = data.to(device), target.to(device)
-         optimizer.zero_grad()
-         output = model(data)
-         loss = F.cross_entropy(output, target)
-         loss.backward()
-         optimizer.step()
-
-         total += len(data)
-         progress = math.ceil(batch_idx / len(train_loader) * 50)
-         print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
-               (epoch, total, len(train_loader.dataset),
-                '-' * progress + '>', progress * 2), end='')
-
- def test(model, device, test_loader):
-     model.eval()
-     test_loss = 0
-     correct = 0
-     with torch.no_grad():
-         for data, target in test_loader:
-             data, target = data.to(device), target.to(device)
-             output = model(data)
-             test_loss += F.cross_entropy(output, target, reduction='sum').item()
-             pred = output.argmax(dim=1, keepdim=True)
-             correct += pred.eq(target.view_as(pred)).sum().item()
-
-     test_loss /= len(test_loader.dataset)
-     print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
-         test_loss, correct, len(test_loader.dataset),
-         100. * correct / len(test_loader.dataset)))
-     return test_loss, correct / len(test_loader.dataset)
-
- def weight_prune(model, pruning_perc):
-     # per-layer threshold: the pruning_perc percentile of the absolute weight values
-     threshold_list = []
-     for p in model.parameters():
-         if len(p.data.size()) != 1:  # skip 1-D bias tensors, only prune weight matrices
-             weight = p.cpu().data.abs().numpy().flatten()
-             threshold = np.percentile(weight, pruning_perc)
-             threshold_list.append(threshold)
-
-     # generate mask: 1 where |w| is above the layer threshold, 0 otherwise
-     masks = []
-     idx = 0
-     for p in model.parameters():
-         if len(p.data.size()) != 1:
-             pruned_inds = p.data.abs() > threshold_list[idx]
-             masks.append(pruned_inds.float())
-             idx += 1
-     return masks
-
- def main():
-     epochs = 2
-     batch_size = 64
-     torch.manual_seed(0)
-
-     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-     train_loader = torch.utils.data.DataLoader(
-         datasets.MNIST('D:/ai_data/mnist_dataset', train=True, download=False,
-                        transform=transforms.Compose([
-                            transforms.ToTensor(),
-                            transforms.Normalize((0.1307,), (0.3081,))
-                        ])),
-         batch_size=batch_size, shuffle=True)
-     test_loader = torch.utils.data.DataLoader(
-         datasets.MNIST('D:/ai_data/mnist_dataset', train=False, download=False, transform=transforms.Compose([
-             transforms.ToTensor(),
-             transforms.Normalize((0.1307,), (0.3081,))
-         ])),
-         batch_size=1000, shuffle=True)
-
-     model = MLP().to(device)
-     optimizer = torch.optim.Adadelta(model.parameters())
-
-     for epoch in range(1, epochs + 1):
-         train(model, device, train_loader, optimizer, epoch)
-         _, acc = test(model, device, test_loader)
-
-     print("\n=====Pruning 60%=======\n")
-     pruned_model = deepcopy(model)
-     mask = weight_prune(pruned_model, 60)
-     pruned_model.set_masks(mask)
-     test(pruned_model, device, test_loader)
-
-     return model, pruned_model
-
- model, pruned_model = main()
- torch.save(model.state_dict(), ".model.pth")
- torch.save(pruned_model.state_dict(), ".pruned_model.pth")
-
- from matplotlib import pyplot as plt
-
- def plot_weights(model):
-     # histogram of the nonzero weights of every layer that has a weight tensor
-     modules = [module for module in model.modules()]
-     num_sub_plot = 0
-     for i, layer in enumerate(modules):
-         if hasattr(layer, 'weight'):
-             plt.subplot(131 + num_sub_plot)
-             w = layer.weight.data
-             w_one_dim = w.cpu().numpy().flatten()
-             plt.hist(w_one_dim[w_one_dim != 0], bins=50)
-             num_sub_plot += 1
-     plt.show()
-
- model = MLP()
- pruned_model = MLP()
- model.load_state_dict(torch.load('.model.pth'))
- pruned_model.load_state_dict(torch.load('.pruned_model.pth'))
- plot_weights(model)
- plot_weights(pruned_model)
(3) Accuracy before and after pruning
Train epoch 1: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.1391, accuracy: 9562/10000 (96%)
Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.0870, accuracy: 9741/10000 (97%)
=====Pruning 60%=======
Test: average loss: 0.0977, accuracy: 9719/10000 (97%)
These numbers show that accuracy barely drops after pruning 60% of the weights (97.41% → 97.19%).
(4) Weight distributions before and after pruning
Distribution before pruning (histogram of nonzero weights per layer, produced by plot_weights above):
Distribution after pruning:
6. Convolutional layer pruning
(1) Pruning approach
Assume a pruning ratio of 50%. Instead of thresholding individual weights, whole filters are pruned: in each round, compute a scaled L2 norm for every filter in every conv layer, normalize these scores within each layer, and zero the mask of the filter with the globally smallest score; repeat until the overall fraction of zeroed parameters reaches the target (a sketch of the per-filter score follows below).
Multiplying each layer's weight by its mask then gives the new, pruned weight.
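A minimal sketch of the per-filter importance score used by prune_one_filter in the code below (the random tensor is illustrative):
- import numpy as np
-
- # illustrative conv weight of shape (out_channels, in_channels, k, k)
- p_np = np.random.randn(32, 16, 3, 3).astype('float32')
-
- # mean squared value of each filter: one "scaled L2 norm" score per filter
- value_this_layer = np.square(p_np).sum(axis=(1, 2, 3)) / (p_np.shape[1] * p_np.shape[2] * p_np.shape[3])
-
- # normalize within the layer so scores are comparable across layers
- value_this_layer = value_this_layer / np.sqrt(np.square(value_this_layer).sum())
-
- # the filter with the smallest score is the next one to be pruned
- print(value_this_layer.argmin(), value_this_layer.min())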
(2) Pruning code
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torchvision import datasets, transforms
- import torch.utils.data
- import numpy as np
- import math
-
- def to_var(x, requires_grad=False):
-     # move the tensor to GPU if available and detach it from the autograd graph
-     if torch.cuda.is_available():
-         x = x.cuda()
-     return x.clone().detach().requires_grad_(requires_grad)
-
- class MaskedConv2d(nn.Conv2d):
-     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
-         super(MaskedConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
-         self.mask_flag = False
-
-     def set_mask(self, mask):
-         # store the mask and zero out the pruned filters in place
-         self.mask = to_var(mask, requires_grad=False)
-         self.weight.data = self.weight.data * self.mask.data
-         self.mask_flag = True
-
-     def get_mask(self):
-         print(self.mask_flag)
-         return self.mask
-
-     def forward(self, x):
-         # The branch below is redundant with set_mask, which already multiplies
-         # self.weight by the mask in place (applying the mask here instead would
-         # also keep pruned filters at zero during the fine-tuning step later).
-         # if self.mask_flag == True:
-         #     weight = self.weight * self.mask
-         #     return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
-         # else:
-         #     return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
-         return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
-
- class ConvNet(nn.Module):
-     def __init__(self):
-         super(ConvNet, self).__init__()
-
-         self.conv1 = MaskedConv2d(1, 32, kernel_size=3, padding=1, stride=1)
-         self.relu1 = nn.ReLU(inplace=True)
-         self.maxpool1 = nn.MaxPool2d(2)
-
-         self.conv2 = MaskedConv2d(32, 64, kernel_size=3, padding=1, stride=1)
-         self.relu2 = nn.ReLU(inplace=True)
-         self.maxpool2 = nn.MaxPool2d(2)
-
-         self.conv3 = MaskedConv2d(64, 64, kernel_size=3, padding=1, stride=1)
-         self.relu3 = nn.ReLU(inplace=True)
-
-         self.linear1 = nn.Linear(7*7*64, 10)
-
-     def forward(self, x):
-         out = self.maxpool1(self.relu1(self.conv1(x)))
-         out = self.maxpool2(self.relu2(self.conv2(out)))
-         out = self.relu3(self.conv3(out))
-         out = out.view(out.size(0), -1)
-         out = self.linear1(out)
-         return out
-
-     def set_masks(self, masks):
-         self.conv1.set_mask(torch.from_numpy(masks[0]))
-         self.conv2.set_mask(torch.from_numpy(masks[1]))
-         self.conv3.set_mask(torch.from_numpy(masks[2]))
-
- def train(model, device, train_loader, optimizer, epoch):
-     model.train()
-     total = 0
-     for batch_idx, (data, target) in enumerate(train_loader):
-         data, target = data.to(device), target.to(device)
-         optimizer.zero_grad()
-         output = model(data)
-         loss = F.cross_entropy(output, target)
-         loss.backward()
-         optimizer.step()
-
-         total += len(data)
-         progress = math.ceil(batch_idx / len(train_loader) * 50)
-         print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
-               (epoch, total, len(train_loader.dataset),
-                '-' * progress + '>', progress * 2), end='')
-
- def test(model, device, test_loader):
-     model.eval()
-     test_loss = 0
-     correct = 0
-     with torch.no_grad():
-         for data, target in test_loader:
-             data, target = data.to(device), target.to(device)
-             output = model(data)
-             test_loss += F.cross_entropy(output, target, reduction='sum').item()
-             pred = output.argmax(dim=1, keepdim=True)
-             correct += pred.eq(target.view_as(pred)).sum().item()
-
-     test_loss /= len(test_loader.dataset)
-     print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
-         test_loss, correct, len(test_loader.dataset),
-         100. * correct / len(test_loader.dataset)))
-     return test_loss, correct / len(test_loader.dataset)
-
- def prune_rate(model, verbose=False):
-     """
-     Compute the overall pruning rate (percentage of zero-valued parameters) of the model.
-     :param model: the (possibly pruned) network
-     :param verbose: if True, also print the per-layer pruning rate
-     :return: pruning percentage over all parameters
-     """
-     total_nb_param = 0
-     nb_zero_param = 0
-     layer_id = 0
-
-     for parameter in model.parameters():
-         param_this_layer = 1
-         for dim in parameter.data.size():
-             param_this_layer *= dim
-         total_nb_param += param_this_layer
-
-         # only pruning linear and conv layers (1-D bias tensors are skipped)
-         if len(parameter.data.size()) != 1:
-             layer_id += 1
-             zero_param_this_layer = np.count_nonzero(parameter.cpu().data.numpy() == 0)
-             nb_zero_param += zero_param_this_layer
-
-             if verbose:
-                 print("Layer {} | {} layer | {:.2f}% parameters pruned"
-                       .format(
-                           layer_id,
-                           'Conv' if len(parameter.data.size()) == 4 else 'Linear',
-                           100. * zero_param_this_layer / param_this_layer,
-                       ))
-     pruning_perc = 100. * nb_zero_param / total_nb_param
-     if verbose:
-         print("Final pruning rate: {:.2f}%".format(pruning_perc))
-     return pruning_perc
-
- def arg_nonzero_min(a):
-     """
-     Return the smallest nonzero value in the list a, together with its index.
-     :param a: list of values
-     :return: (min_value, min_index)
-     """
-     if not a:
-         return
-
-     min_ix, min_v = None, None
-     # find the first nonzero entry (if none exists, every value is zero)
-     for i, e in enumerate(a):
-         if e != 0:
-             min_ix = i
-             min_v = e
-             break
-     if min_ix is None:
-         print('Warning: all zero')
-         return np.inf, np.inf
-
-     # search for the smallest nonzero
-     for i, e in enumerate(a):
-         if e < min_v and e != 0:
-             min_v = e
-             min_ix = i
-
-     return min_v, min_ix
-
- def prune_one_filter(model, masks):
-     """
-     Prune the single least important filter, measured by the scaled L2 norm
-     of its kernel weights, across all conv layers.
-     :param model: network containing MaskedConv2d layers
-     :param masks: current list of masks (one per conv layer), or [] on the first call
-     :return: updated list of masks
-     """
-     NO_MASKS = False
-     # construct masks if there are none yet
-     if not masks:
-         masks = []
-         NO_MASKS = True
-
-     values = []
-     for p in model.parameters():
-         if len(p.data.size()) == 4:
-             p_np = p.data.cpu().numpy()
-
-             # construct an all-ones mask on the first call
-             if NO_MASKS:
-                 masks.append(np.ones(p_np.shape).astype('float32'))
-
-             # find the scaled l2 norm for each filter in this layer
-             value_this_layer = np.square(p_np).sum(axis=1).sum(axis=1).sum(axis=1) / (p_np.shape[1] * p_np.shape[2] * p_np.shape[3])
-
-             # normalization (important)
-             value_this_layer = value_this_layer / np.sqrt(np.square(value_this_layer).sum())
-             min_value, min_ind = arg_nonzero_min(list(value_this_layer))
-             values.append([min_value, min_ind])
-
-     assert len(masks) == len(values), "something wrong here"
-
-     values = np.array(values)  # [[min_value, min_ind], [min_value, min_ind], [min_value, min_ind]]
-
-     # zero the mask of the filter with the globally smallest score
-     to_prune_layer_ind = np.argmin(values[:, 0])
-     to_prune_filter_ind = int(values[to_prune_layer_ind, 1])
-     masks[to_prune_layer_ind][to_prune_filter_ind] = 0.
-
-     return masks
-
- def filter_prune(model, pruning_perc):
-     """
-     Main pruning loop: keep pruning one filter at a time until the overall
-     pruning rate reaches the requested percentage.
-     :param model: network to prune
-     :param pruning_perc: target pruning percentage
-     :return: final list of masks
-     """
-     masks = []
-     current_pruning_perc = 0
-
-     while current_pruning_perc < pruning_perc:
-         masks = prune_one_filter(model, masks)
-         model.set_masks(masks)
-         current_pruning_perc = prune_rate(model, verbose=False)
-         print('{:.2f} pruned'.format(current_pruning_perc))
-
-     return masks
-
- def main():
-     epochs = 2
-     batch_size = 64
-     torch.manual_seed(0)
-
-     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-     train_loader = torch.utils.data.DataLoader(
-         datasets.MNIST('D:/ai_data/mnist_dataset', train=True, download=False,
-                        transform=transforms.Compose([
-                            transforms.ToTensor(),
-                            transforms.Normalize((0.1307,), (0.3081,))
-                        ])),
-         batch_size=batch_size, shuffle=True)
-     test_loader = torch.utils.data.DataLoader(
-         datasets.MNIST('D:/ai_data/mnist_dataset', train=False, download=False, transform=transforms.Compose([
-             transforms.ToTensor(),
-             transforms.Normalize((0.1307,), (0.3081,))
-         ])),
-         batch_size=1000, shuffle=True)
-
-     model = ConvNet().to(device)
-     optimizer = torch.optim.Adadelta(model.parameters())
-
-     for epoch in range(1, epochs + 1):
-         train(model, device, train_loader, optimizer, epoch)
-         _, acc = test(model, device, test_loader)
-
-     print('\npruning 50%')
-     mask = filter_prune(model, 50)
-     model.set_masks(mask)
-     _, acc = test(model, device, test_loader)
-
-     # finetune: one more training epoch with the pruned weights
-     print('\nfinetune')
-     train(model, device, train_loader, optimizer, epoch)
-     _, acc = test(model, device, test_loader)
-
- main()
(3) Accuracy and pruning-rate information:
- Train epoch 1: 60000/60000, [-------------------------------------------------->] 100%
- Test: average loss: 0.0505, accuracy: 9833/10000 (98%)
- Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%
- Test: average loss: 0.0311, accuracy: 9893/10000 (99%)
-
- pruning 50%
- 0.66 pruned
- 1.32 pruned
- 1.65 pruned
- 1.98 pruned
- 2.31 pruned
- 2.64 pruned
- 2.98 pruned
- 3.64 pruned
- 3.97 pruned
- 4.63 pruned
- 4.64 pruned
- 4.65 pruned
- 4.98 pruned
- 5.31 pruned
- 5.32 pruned
- 5.65 pruned
- 6.31 pruned
- 6.97 pruned
- 7.30 pruned
- 7.63 pruned
- 8.30 pruned
- 8.31 pruned
- 8.97 pruned
- 9.30 pruned
- 9.96 pruned
- 10.29 pruned
- 10.95 pruned
- 11.61 pruned
- 11.94 pruned
- 12.60 pruned
- 13.27 pruned
- 13.93 pruned
- 14.26 pruned
- 14.92 pruned
- 15.25 pruned
- 15.26 pruned
- 15.59 pruned
- 16.25 pruned
- 16.91 pruned
- 17.57 pruned
- 17.90 pruned
- 18.23 pruned
- 18.90 pruned
- 19.56 pruned
- 19.89 pruned
- 20.55 pruned
- 20.88 pruned
- 21.54 pruned
- 21.87 pruned
- 21.88 pruned
- 22.54 pruned
- 22.87 pruned
- 23.53 pruned
- 24.20 pruned
- 24.21 pruned
- 24.87 pruned
- 25.20 pruned
- 25.86 pruned
- 26.19 pruned
- 26.20 pruned
- 26.86 pruned
- 27.19 pruned
- 27.52 pruned
- 28.18 pruned
- 28.51 pruned
- 29.18 pruned
- 29.51 pruned
- 29.52 pruned
- 29.85 pruned
- 29.86 pruned
- 30.52 pruned
- 30.85 pruned
- 31.51 pruned
- 32.17 pruned
- 32.83 pruned
- 33.16 pruned
- 33.82 pruned
- 34.16 pruned
- 34.82 pruned
- 35.15 pruned
- 35.48 pruned
- 36.14 pruned
- 36.47 pruned
- 37.13 pruned
- 37.79 pruned
- 37.80 pruned
- 38.13 pruned
- 38.79 pruned
- 38.80 pruned
- 39.13 pruned
- 39.15 pruned
- 39.81 pruned
- 40.14 pruned
- 40.47 pruned
- 40.48 pruned
- 41.14 pruned
- 41.47 pruned
- 41.80 pruned
- 41.81 pruned
- 42.47 pruned
- 43.13 pruned
- 43.46 pruned
- 43.79 pruned
- 44.46 pruned
- 44.79 pruned
- 44.80 pruned
- 45.46 pruned
- 45.79 pruned
- 45.80 pruned
- 46.46 pruned
- 46.79 pruned
- 47.12 pruned
- 47.78 pruned
- 47.79 pruned
- 47.80 pruned
- 48.13 pruned
- 48.79 pruned
- 49.13 pruned
- 49.79 pruned
- 49.80 pruned
- 50.46 pruned
-
- Test: average loss: 1.6824, accuracy: 6513/10000 (65%)
-
- finetune
- Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%
- Test: average loss: 0.0324, accuracy: 9889/10000 (99%)
As the log shows, testing right after pruning gives only 65% accuracy, which is very low, but a single additional epoch of fine-tuning brings the model back to roughly its original accuracy (99%).
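One caveat (my own note, not from the video): because the mask is only applied once inside set_mask, the gradient updates during fine-tuning can turn pruned weights back into nonzero values, so the final pruning rate may drift below 50%. If strict sparsity must be preserved, the masks returned by filter_prune can simply be re-applied after fine-tuning, for example at the end of main():
- # re-apply the stored masks after fine-tuning and report the resulting sparsity
- model.set_masks(mask)
- prune_rate(model, verbose=True)
- _, acc = test(model, device, test_loader)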