• 模型剪枝介绍


    Ref:https://www.cnblogs.com/the-art-of-ai/p/17500399.html

    1、背景介绍

            深度学习模型在图像识别、自然语言处理、语音识别等领域取得了显著的成果,但是这些模型往往需要大量的计算资源和存储空间。尤其是在移动设备和嵌入式系统等资源受限的环境下,这些模型的体积和计算复杂度往往成为了限制其应用的瓶颈。因此,如何在保持模型准确率的同时,尽可能地减少模型的体积和计算复杂度,成为了一个重要的研究方向。

            模型剪枝技术就是解决这个问题的一种有效方法。它通过对深度学习模型进行结构优化和参数削减,使得模型在保持准确率的前提下,具有更小的体积和更快的运行速度,从而更好地适应不同的任务和环境

    2、基本原理

            模型剪枝技术是指对深度学习模型进行结构优化和参数削减的一种技术。剪枝技术可以分为结构剪枝参数剪枝两种形式。

            结构剪枝是指从深度学习模型中删除一些不必要的结构单元,如神经元、卷积核、层等,以减少模型的计算复杂度和存储空间。常见的结构剪枝方法包括:通道剪枝、层剪枝、节点剪枝、过滤器剪枝等。

            参数剪枝是指从深度学习模型中删除一些不必要的权重参数,以减少模型的存储空间和计算复杂度,同时保持模型的准确率。常见的参数剪枝方法包括:L1正则化、L2正则化、排序剪枝、局部敏感哈希剪枝等。

    3、 技术原理

            模型剪枝技术的核心思想是在保持模型准确率的前提下,尽可能地减少模型的存储空间和计算复杂度。由于深度学习模型中的神经元、卷积核、权重参数等结构单元和参数往往存在冗余和不必要的部分,因此可以通过剪枝技术来减少这些冗余部分,从而达到减小模型体积和计算复杂度的效果。

            具体来说,模型剪枝技术的实现可以分为以下几个步骤:

            (1)初始化模型;首先,初始化一个深度学习模型并进行训练,获得一个基准模型;

            (2)选择剪枝量化方法和策略;根据具体的应用场景和需求,选择合适的剪枝方法和策略;常见的简直方法包括:结构剪枝和参数剪枝;常见的策略包括:全局剪枝和迭代剪枝;

            (3)剪枝模型;基于选择的剪枝方法和策略,对深度学习模型进行剪枝操作;具体来说删除一些不必要的结构单元和权重参数,或者将他们设置为0或者很小的值;

            (4)重新训练模型;剪枝操作可能会导致模型的准确率下降;因此需要重新对剪枝后的模型进行训练;以恢复模型的准确率;

            (5)微调模型;重新训练后,对模型进行微调;进一步提高模型的准确率;

           代码:

    1. import torch
    2. import torch.nn as nn
    3. import torch.optim as optim
    4. import torch.nn.functional as F
    5. from torchvision import datasets, transforms
    6. # 定义一个简单的卷积神经网络
    7. class SimpleCNN(nn.Module):
    8. def __init__(self):
    9. super(SimpleCNN, self).__init__()
    10. self.conv1 = nn.Conv2d(1, 4, kernel_size=3, padding=1) # 4个输出通道
    11. self.conv2 = nn.Conv2d(4, 8, kernel_size=3, padding=1) # 8个输出通道
    12. self.fc1 = nn.Linear(8 * 7 * 7, 64)
    13. self.fc2 = nn.Linear(64, 10)
    14. def forward(self, x):
    15. x = F.relu(self.conv1(x)) # 卷积层1 + ReLU激活函数
    16. x = F.max_pool2d(x, 2) # 最大池化层,池化核大小为2x2
    17. x = F.relu(self.conv2(x)) # 卷积层2 + ReLU激活函数
    18. x = F.max_pool2d(x, 2) # 最大池化层,池化核大小为2x2
    19. x = x.view(x.size(0), -1) # 展平操作,将多维张量展平成一维
    20. x = F.relu(self.fc1(x)) # 全连接层1 + ReLU激活函数
    21. x = self.fc2(x) # 全连接层2,输出10个类别
    22. return x
    23. # 实例化模型
    24. model = SimpleCNN()
    25. # 打印剪枝前的模型结构
    26. print("Model before pruning:")
    27. print(model)
    28. # 加载数据
    29. transform = transforms.Compose([
    30. transforms.ToTensor(), # 转换为张量
    31. transforms.Normalize((0.1307,), (0.3081,)) # 归一化
    32. ])
    33. train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) # 加载训练数据集
    34. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) # 创建数据加载器
    35. # 定义损失函数和优化器
    36. criterion = nn.CrossEntropyLoss() # 交叉熵损失函数
    37. optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器
    38. # 训练模型
    39. model.train() # 将模型设置为训练模式
    40. for epoch in range(1): # 训练一个epoch
    41. running_loss = 0.0
    42. for data, target in train_loader:
    43. optimizer.zero_grad() # 清零梯度
    44. outputs = model(data) # 前向传播
    45. loss = criterion(outputs, target) # 计算损失
    46. loss.backward() # 反向传播
    47. optimizer.step() # 更新参数
    48. running_loss += loss.item() * data.size(0) # 累加损失
    49. epoch_loss = running_loss / len(train_loader.dataset) # 计算平均损失
    50. print(f'Epoch {epoch + 1}, Loss: {epoch_loss:.4f}')
    51. # 通道剪枝
    52. # 获取卷积层的权重
    53. conv1_weights = model.conv1.weight.data.abs().sum(dim=[1, 2, 3]) # 计算每个通道的L1范数
    54. # 按照L1范数对通道进行排序
    55. sorted_channels = torch.argsort(conv1_weights)
    56. # 选择需要删除的通道
    57. num_prune = 2 # 假设我们要删除2个通道
    58. channels_to_prune = sorted_channels[:num_prune]
    59. print("Channels to prune:", channels_to_prune)
    60. # 删除指定通道的权重和偏置
    61. pruned_weights = torch.index_select(model.conv1.weight.data, 0, sorted_channels[num_prune:]) # 获取保留的权重
    62. pruned_bias = torch.index_select(model.conv1.bias.data, 0, sorted_channels[num_prune:]) # 获取保留的偏置
    63. # 创建一个新的卷积层,并将剪枝后的权重和偏置赋值给它
    64. model.conv1 = nn.Conv2d(in_channels=1, out_channels=4 - num_prune, kernel_size=3, padding=1)
    65. model.conv1.weight.data = pruned_weights
    66. model.conv1.bias.data = pruned_bias
    67. # 同时我们还需要调整conv2层的输入通道
    68. # 获取conv2层的权重并调整其输入通道
    69. conv2_weights = model.conv2.weight.data[:, sorted_channels[num_prune:], :, :] # 调整输入通道的权重
    70. # 创建一个新的卷积层,并将剪枝后的权重赋值给它
    71. model.conv2 = nn.Conv2d(in_channels=4 - num_prune, out_channels=8, kernel_size=3, padding=1)
    72. model.conv2.weight.data = conv2_weights
    73. # 打印剪枝后的模型结构
    74. print("Model after pruning:")
    75. print(model)
    76. # 定义新的优化器
    77. optimizer = optim.Adam(model.parameters(), lr=0.001)
    78. # 重新训练模型
    79. model.train() # 将模型设置为训练模式
    80. for epoch in range(1): # 训练一个epoch
    81. running_loss = 0.0
    82. for data, target in train_loader:
    83. optimizer.zero_grad() # 清零梯度
    84. outputs = model(data) # 前向传播
    85. loss = criterion(outputs, target) # 计算损失
    86. loss.backward() # 反向传播
    87. optimizer.step() # 更新参数
    88. running_loss += loss.item() * data.size(0) # 累加损失
    89. epoch_loss = running_loss / len(train_loader.dataset) # 计算平均损失
    90. print(f'Epoch {epoch + 1}, Loss: {epoch_loss:.4f}')
    91. # 加载测试数据
    92. test_dataset = datasets.MNIST('./data', train=False, transform=transform) # 加载测试数据集
    93. test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False) # 创建数据加载器
    94. # 评估模型
    95. model.eval() # 将模型设置为评估模式
    96. correct = 0
    97. total = 0
    98. with torch.no_grad(): # 关闭梯度计算
    99. for data, target in test_loader:
    100. outputs = model(data) # 前向传播
    101. _, predicted = torch.max(outputs.data, 1) # 获取预测结果
    102. total += target.size(0) # 总样本数
    103. correct += (predicted == target).sum().item() # 正确预测的样本数
    104. print(f'Accuracy: {100 * correct / total}%') # 打印准确率

            为了提高剪枝技术的性能和效率,可以考虑以下几个方面的优化:

    • 选择合适的剪枝策略和剪枝算法,以提高剪枝的效果和准确率。

    • 对剪枝后的模型进行微调或增量学习,以进一步提高模型的准确率和性能。

    • 使用并行计算和分布式计算技术,以加速剪枝和训练过程。

  • 相关阅读:
    模板引擎小结-原理
    h5修改钉钉双标题栏问题
    [ERROR] mariadbd: The table 'INNODB_BUFFER_PAGE' is full
    Docker 容器化(初学者的分享)
    B40 - 基于STM32单片机的电热蚊香蓝牙控制系统
    主流框架选择:React、Angular、Vue的详细比较
    Maven的聚合 继承 属性 版本管理 多环境资源配置 跳过测试
    Flink技术灵活使用总结(一)状态与状态后端
    【Egg从基础到进阶】二:安装本地Mysql
    2022前端面试题整理总结~
  • 原文地址:https://blog.csdn.net/qq_43642885/article/details/140354729