FLOP and Params Calculation Tools for Neural Networks

    1. FlopCountAnalysis

    pip install fvcore
    
    import torch
    from torchvision.models import resnet152
    from fvcore.nn import FlopCountAnalysis, parameter_count_table
    
    model = resnet152(num_classes=1000)
    
    # FlopCountAnalysis expects the example inputs as a tuple
    tensor = (torch.rand(1, 3, 224, 224),)
    
    # Analyze FLOPs
    flops = FlopCountAnalysis(model, tensor)
    print("FLOPs: ", flops.total())
    
    def print_model_parm_nums(model):
        # Total number of parameters, reported in millions
        total = sum(param.nelement() for param in model.parameters())
        print('  + Number of params: %.2fM' % (total / 1e6))
    
    print_model_parm_nums(model)
    
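    The same FlopCountAnalysis object can also break the count down per operator and per module, and the parameter_count_table imported above prints a per-layer parameter table. A short sketch continuing from the snippet above (note that fvcore counts one fused multiply-add as a single FLOP, so its totals are close to MAC counts):
    
    # Per-operator and per-module breakdown of the FLOP count
    print(flops.by_operator())   # e.g. totals for 'conv', 'linear', ...
    print(flops.by_module())     # FLOPs attributed to each named submodule
    
    # Per-layer parameter table from fvcore
    print(parameter_count_table(model))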

    2. flopth

    https://github.com/vra/flopth

    pip install flopth 
    

    Running on models in torchvision.models

    $ flopth -m alexnet 
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | module_name   | module_type       | in_shape    | out_shape   | params   | params_percent   | params_percent_vis             | flops    | flops_percent   | flops_percent_vis   |
    +===============+===================+=============+=============+==========+==================+================================+==========+=================+=====================+
    | features.0    | Conv2d            | (3,224,224) | (64,55,55)  | 23.296K  | 0.0381271%       |                                | 70.4704M | 9.84839%        | ####                |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | features.1    | ReLU              | (64,55,55)  | (64,55,55)  | 0.0      | 0.0%             |                                | 193.6K   | 0.027056%       |                     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | features.2    | MaxPool2d         | (64,55,55)  | (64,27,27)  | 0.0      | 0.0%             |                                | 193.6K   | 0.027056%       |                     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | features.3    | Conv2d            | (64,27,27)  | (192,27,27) | 307.392K | 0.50309%         |                                | 224.089M | 31.3169%        | ###############     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | features.4    | ReLU              | (192,27,27) | (192,27,27) | 0.0      | 0.0%             |                                | 139.968K | 0.0195608%      |                     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | features.5    | MaxPool2d         | (192,27,27) | (192,13,13) | 0.0      | 0.0%             |                                | 139.968K | 0.0195608%      |                     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | features.6    | Conv2d            | (192,13,13) | (384,13,13) | 663.936K | 1.08662%         |                                | 112.205M | 15.6809%        | #######             |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | features.7    | ReLU              | (384,13,13) | (384,13,13) | 0.0      | 0.0%             |                                | 64.896K  | 0.00906935%     |                     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | features.8    | Conv2d            | (384,13,13) | (256,13,13) | 884.992K | 1.44841%         |                                | 149.564M | 20.9018%        | ##########          |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | features.9    | ReLU              | (256,13,13) | (256,13,13) | 0.0      | 0.0%             |                                | 43.264K  | 0.00604624%     |                     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | features.10   | Conv2d            | (256,13,13) | (256,13,13) | 590.08K  | 0.965748%        |                                | 99.7235M | 13.9366%        | ######              |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | features.11   | ReLU              | (256,13,13) | (256,13,13) | 0.0      | 0.0%             |                                | 43.264K  | 0.00604624%     |                     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | features.12   | MaxPool2d         | (256,13,13) | (256,6,6)   | 0.0      | 0.0%             |                                | 43.264K  | 0.00604624%     |                     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | avgpool       | AdaptiveAvgPool2d | (256,6,6)   | (256,6,6)   | 0.0      | 0.0%             |                                | 9.216K   | 0.00128796%     |                     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | classifier.0  | Dropout           | (9216)      | (9216)      | 0.0      | 0.0%             |                                | 0.0      | 0.0%            |                     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | classifier.1  | Linear            | (9216)      | (4096)      | 37.7528M | 61.7877%         | ############################## | 37.7487M | 5.27547%        | ##                  |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | classifier.2  | ReLU              | (4096)      | (4096)      | 0.0      | 0.0%             |                                | 4.096K   | 0.000572425%    |                     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | classifier.3  | Dropout           | (4096)      | (4096)      | 0.0      | 0.0%             |                                | 0.0      | 0.0%            |                     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | classifier.4  | Linear            | (4096)      | (4096)      | 16.7813M | 27.4649%         | #############                  | 16.7772M | 2.34465%        | #                   |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | classifier.5  | ReLU              | (4096)      | (4096)      | 0.0      | 0.0%             |                                | 4.096K   | 0.000572425%    |                     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    | classifier.6  | Linear            | (4096)      | (1000)      | 4.097M   | 6.70531%         | ###                            | 4.096M   | 0.572425%       |                     |
    +---------------+-------------------+-------------+-------------+----------+------------------+--------------------------------+----------+-----------------+---------------------+
    
    
    FLOPs: 715.553M
    Params: 61.1008M
    
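    As a sanity check, the features.0 row can be reproduced by hand. In torchvision's AlexNet, features.0 is Conv2d(3, 64, kernel_size=11, stride=4, padding=2) with a 55x55 output map, and the numbers in the table follow from counting one multiply-accumulate per weight (plus one for the bias) per output element:
    
    # Recompute the features.0 row of the AlexNet table above
    in_ch, out_ch, k, out_h, out_w = 3, 64, 11, 55, 55
    
    params = out_ch * (in_ch * k * k) + out_ch             # weights + biases = 23296 -> 23.296K
    flops = (in_ch * k * k + 1) * out_ch * out_h * out_w   # 70470400 -> 70.4704M
    
    print(params, flops)
    
    The ReLU row is simply the number of elements it processes: 64 * 55 * 55 = 193,600, i.e. the 193.6K shown in the table.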

    Running on custom models

    # file path: /tmp/my_model.py
    # model name:  MyModel
    import torch.nn as nn
    
    
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
            self.conv3 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
            self.conv4 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
    
        def forward(self, x1):
            x1 = self.conv1(x1)
            x1 = self.conv2(x1)
            x1 = self.conv3(x1)
            x1 = self.conv4(x1)
            return x1
    
    $ flopth -m MyModel -p /tmp/my_model.py -i 3 224 224
    +---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+
    | module_name   | module_type   | in_shape    | out_shape   |   params | params_percent   | params_percent_vis   | flops    | flops_percent   | flops_percent_vis   |
    +===============+===============+=============+=============+==========+==================+======================+==========+=================+=====================+
    | conv1         | Conv2d        | (3,224,224) | (3,224,224) |       84 | 25.0%            | ############         | 4.21478M | 25.0%           | ############        |
    +---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+
    | conv2         | Conv2d        | (3,224,224) | (3,224,224) |       84 | 25.0%            | ############         | 4.21478M | 25.0%           | ############        |
    +---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+
    | conv3         | Conv2d        | (3,224,224) | (3,224,224) |       84 | 25.0%            | ############         | 4.21478M | 25.0%           | ############        |
    +---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+
    | conv4         | Conv2d        | (3,224,224) | (3,224,224) |       84 | 25.0%            | ############         | 4.21478M | 25.0%           | ############        |
    +---------------+---------------+-------------+-------------+----------+------------------+----------------------+----------+-----------------+---------------------+
    
    FLOPs: 16.8591M
    Params: 336.0
    
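    Besides the command-line interface, flopth can be called from Python. A minimal sketch, assuming the flopth callable and its in_size keyword behave as described in the project README (the exact signature may differ between versions):
    
    from flopth import flopth
    
    # Reuse the MyModel definition from above
    model = MyModel()
    
    # Assumption: in_size takes a tuple of per-input shapes without the batch dimension
    flops, params = flopth(model, in_size=((3, 224, 224),))
    print(flops, params)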

    3. calflops

    https://github.com/MrYxJ/calculate-flops.pytorch/tree/main

    pip install calflops
    
    from calflops import calculate_flops
    from torchvision import models
    
    model = models.alexnet()
    batch_size = 1
    input_shape = (batch_size, 3, 224, 224)
    flops, macs, params = calculate_flops(model=model, 
                                          input_shape=input_shape,
                                          output_as_string=True,
                                          output_precision=4)
    print("Alexnet FLOPs:%s   MACs:%s   Params:%s \n" %(flops, macs, params))
    #Alexnet FLOPs:4.2892 GFLOPS   MACs:2.1426 GMACs   Params:61.1008 M 
    
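    Note the unit convention: calflops reports both FLOPs and MACs, counting each multiply-accumulate as one multiplication plus one addition, so FLOPs is roughly 2x MACs (4.2892 G vs 2.1426 G above). Several other tools report MAC counts under the name "FLOPs", so check the convention before comparing numbers across tools.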
    4. thop (pytorch-OpCounter)

    https://github.com/Lyken17/pytorch-OpCounter

    pip install thop
    
    import torch
    from torchvision.models import resnet50
    from thop import profile
    
    model = resnet50()
    input = torch.randn(1, 3, 224, 224)
    # profile returns multiply-accumulate operations (MACs) and the parameter count
    macs, params = profile(model, inputs=(input, ))
    
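    thop also provides a clever_format helper (shown in the pytorch-OpCounter README) for turning the raw counts into human-readable strings:
    
    from thop import clever_format
    
    # Format the values returned by profile above
    macs, params = clever_format([macs, params], "%.3f")
    print(macs, params)   # roughly 4.1G MACs and 25.6M params for resnet50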
    class YourModule(nn.Module):
        # your definition
        ...
    
    def count_your_model(model, x, y):
        # your counting rule here (x: input tuple, y: module output)
        ...
    
    model = ...  # a network that contains YourModule layers
    input = torch.randn(1, 3, 224, 224)
    macs, params = profile(model, inputs=(input, ),
                           custom_ops={YourModule: count_your_model})
    
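    As a concrete illustration of a custom rule, the sketch below counts a hypothetical Swish activation as two operations per output element. The Swish class and the per-element rule are assumptions made for this example; the accumulation into the module's total_ops buffer follows the pattern of thop's built-in hooks, but the exact type it expects may differ between thop versions:
    
    import torch
    import torch.nn as nn
    from thop import profile
    
    class Swish(nn.Module):
        def forward(self, x):
            return x * torch.sigmoid(x)
    
    def count_swish(module, inputs, output):
        # Assumption: one sigmoid + one multiply per output element
        module.total_ops += torch.DoubleTensor([2 * output.numel()])
    
    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), Swish())
    x = torch.randn(1, 3, 224, 224)
    macs, params = profile(net, inputs=(x,), custom_ops={Swish: count_swish})
    print(macs, params)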
Original post: https://blog.csdn.net/wp133716/article/details/134509989