• Mastering PyTorch Model Compression, Pruning, and Quantization


    When building and deploying deep learning models, we need to consider the number of parameters, the size of the model weights, inference speed, and computational cost. This article shares a tutorial on model compression, pruning, and quantization in PyTorch.

    Weight Compression

    A model is trained with float32 weights, but deployment does not require that much numerical precision. Converting the weights to float16 before saving roughly halves the size of the weight file (a reduction of about 50%).

    • Step 1: Train and save the model
    import torch
    import timm
    
    model = timm.create_model('mobilevit_xxs', pretrained=False, num_classes=8)
    model.load_state_dict(torch.load('model_mobilevit_xxs.pth'))
    
    • Step 2: Convert the data type and save
    params = torch.load('model_mobilevit_xxs.pth')  # float32 weights
    for key in params.keys():
        params[key] = params[key].half()  # convert to float16
    
    torch.save(params, 'model_mobilevit_xxs_half.pth')
    
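    To run inference with the half-precision checkpoint, the weights can be cast back to float32 when loading (many CPU ops handle float16 poorly). A minimal sketch, reusing the file names from the steps above:

    import torch
    import timm
    
    # load the float16 checkpoint and cast the tensors back to float32
    params = torch.load('model_mobilevit_xxs_half.pth')
    params = {k: v.float() for k, v in params.items()}
    
    model = timm.create_model('mobilevit_xxs', pretrained=False, num_classes=8)
    model.load_state_dict(params)
    model.eval()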

    Weight Pruning

    After training, redundant weights can be pruned away. There are several pruning strategies:

    • Random pruning by a given ratio
    • Pruning by weight magnitude

    https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

    The example code is as follows:

    import torch
    import torch.nn.utils.prune as prune
    import timm
    
    model = timm.create_model('mobilevit_xxs', pretrained=False, num_classes=8)
    model.load_state_dict(torch.load('model_mobilevit_xxs.pth'))
    
    # select the layer to prune
    module = model.head.fc
    
    # in practice you would pick one of the methods below; applying
    # several in sequence stacks their pruning masks
    
    # random_unstructured: zero out 30% of the weights at random
    prune.random_unstructured(module, name="weight", amount=0.3)
    
    # l1_unstructured: zero out the 30% of weights with the smallest absolute value
    prune.l1_unstructured(module, name="weight", amount=0.3)
    
    # ln_structured: remove 50% of the rows of the weight matrix, ranked by L2 norm
    prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
    
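    As written above, pruning only attaches a binary mask (weight_mask) and keeps the original values in weight_orig; the module's weight is recomputed from the two on every forward pass. To fold the mask into the weight tensor permanently and check the resulting sparsity, a minimal sketch continuing from the code above:

    # make the pruning permanent: bake the mask into the weight
    # and remove the re-parametrization
    prune.remove(module, 'weight')
    
    # fraction of weights that are now exactly zero
    sparsity = 100.0 * float(torch.sum(module.weight == 0)) / module.weight.nelement()
    print(f'sparsity of head.fc.weight: {sparsity:.2f}%')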

    When using weight pruning, keep in mind:

    • Pruning does not shrink the saved model: the zeroed weights are still stored, it only increases sparsity (see the size check after this list);
    • Pruning does not by itself make inference faster: it only reduces the nominal amount of computation, unless the runtime can exploit the sparsity;
    • The pruning ratio affects model accuracy and needs to be tested and validated;
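
    As a quick check of the first point, save the state dict after pruning and compare file sizes; once prune.remove has been applied (as in the sketch above), the pruned checkpoint should be essentially the same size as the original:

    import os
    
    # the zeros are still stored as dense float32 values
    torch.save(model.state_dict(), 'model_mobilevit_xxs_pruned.pth')
    
    print('original:', os.path.getsize('model_mobilevit_xxs.pth'))
    print('pruned:  ', os.path.getsize('model_mobilevit_xxs_pruned.pth'))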

    Weight Quantization

    Quantization turns 32-bit multiply-accumulate operations into 8-bit ones, shrinking the model weights and lowering the memory requirements.

    https://pytorch.org/docs/stable/quantization.html

    Post Training Dynamic Quantization (Eager Mode)

    import torch
    
    # define a floating point model
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.fc1 = torch.nn.Linear(100, 40)
    
        def forward(self, x):
            x = self.fc1(x)
            return x
    
    # create a model instance
    model_fp32 = M()
    torch.save(model_fp32.state_dict(), 'tmp_float32.pth')
    
    # create a quantized model instance
    model_int8 = torch.quantization.quantize_dynamic(
        model_fp32,  # the original model
        {torch.nn.Linear},  # a set of layers to dynamically quantize
        dtype=torch.qint8)  # the target dtype for quantized weights
    
    # run the model; the input's last dimension must match fc1's in_features
    input_fp32 = torch.randn(4, 100)
    res = model_int8(input_fp32)
    torch.save(model_int8.state_dict(), 'tmp_int8.pth')
    
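    To confirm the size reduction, compare the two checkpoints saved above. Dynamic quantization stores the Linear weights as int8 plus a scale and zero point, so the quantized file should be roughly a quarter of the float32 one:

    import os
    
    print('fp32:', os.path.getsize('tmp_float32.pth'), 'bytes')
    print('int8:', os.path.getsize('tmp_int8.pth'), 'bytes')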

    Post Training Static Quantization

    import torch
    
    # define a floating point model where some layers could be statically quantized
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            # QuantStub converts tensors from floating point to quantized
            self.quant = torch.quantization.QuantStub()
            self.conv = torch.nn.Conv2d(1, 100, 1)
            self.relu = torch.nn.ReLU()
            # DeQuantStub converts tensors from quantized to floating point
            self.dequant = torch.quantization.DeQuantStub()
    
        def forward(self, x):
            # manually specify where tensors will be converted from floating
            # point to quantized in the quantized model
            x = self.quant(x)
            x = self.conv(x)
            x = self.relu(x)
            # manually specify where tensors will be converted from quantized
            # to floating point in the quantized model
            x = self.dequant(x)
            return x
    
    # create a model instance
    model_fp32 = M()
    torch.save(model_fp32.state_dict(), 'tmp_float32.pth')
    
    # static quantization requires eval mode
    model_fp32.eval()
    
    # choose a quantization backend configuration (fbgemm targets x86 CPUs)
    model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    
    # fuse conv+relu into a single module, then insert observers
    model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
    model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)
    
    # calibration: run representative data through the model so the
    # observers can record activation ranges
    input_fp32 = torch.randn(4, 1, 4, 4)
    model_fp32_prepared(input_fp32)
    
    # convert to a quantized model and run it
    model_int8 = torch.quantization.convert(model_fp32_prepared)
    res = model_int8(input_fp32)
    torch.save(model_int8.state_dict(), 'tmp_int8.pth')
    
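    Printing the model at each stage is a quick sanity check: after prepare the modules carry observers, and after convert they are replaced by quantized equivalents (the fused conv+relu typically shows up as QuantizedConvReLU2d):

    # inspect the converted modules
    print(model_int8)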

    PyTorch's quantization support is still not fully mature: quantized models may run only on the CPU, and inference can even become slower. If you need quantization for deployment, using TensorRT together with a GPU is recommended.
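
    It is therefore worth benchmarking before committing to a quantized model. A minimal CPU latency check, assuming model_fp32, model_int8 and input_fp32 from the dynamic quantization example above:

    import time
    import torch
    
    def avg_latency(model, x, n=100):
        # average latency over n forward passes, without gradient tracking
        model.eval()
        with torch.no_grad():
            start = time.perf_counter()
            for _ in range(n):
                model(x)
            return (time.perf_counter() - start) / n
    
    print('fp32 latency:', avg_latency(model_fp32, input_fp32))
    print('int8 latency:', avg_latency(model_int8, input_fp32))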
