• 统计官方模型的参数量和计算量


    ModelParams(M)FLOPs(G)
    alexnet61.100.71
    densenet1217.982.90
    densenet16128.687.85
    densenet16914.153.44
    densenet20120.014.39
    googlenet6.621.51
    inception_v323.835.75
    mnasnet0_52.220.12
    mnasnet0_753.170.23
    mnasnet1_04.380.34
    mnasnet1_36.280.56
    mobilenet_v23.500.33
    mobilenet_v3_large5.480.23
    mobilenet_v3_small2.540.06
    resnet10144.557.87
    resnet15260.1911.60
    resnet1811.691.82
    resnet3421.803.68
    resnet5025.564.13
    resnext101_32x8d88.7916.54
    resnext50_32x4d25.034.29
    shufflenet_v2_x0_51.370.04
    shufflenet_v2_x1_02.280.15
    shufflenet_v2_x1_53.500.31
    shufflenet_v2_x2_07.390.60
    squeezenet1_01.250.82
    squeezenet1_11.240.35
    vgg11132.867.61
    vgg11_bn132.877.64
    vgg13133.0511.31
    vgg13_bn133.0511.36
    vgg16138.3615.47
    vgg16_bn138.3715.52
    vgg19143.6719.63
    vgg19_bn143.6819.69
    wide_resnet101_2126.8922.84
    wide_resnet50_268.8811.46
    import torch
    from torchvision import models
    
    model_names = sorted(
        name
        for name in models.__dict__
        if name.islower()
        and not name.startswith("__")  # and "inception" in name
        and callable(models.__dict__[name])
    )
    
    device = "cpu"
    
    
    for name in model_names:
        print("======================="+name+"=========================")
        model = models.__dict__[name]().to(device)
        dsize = (1, 3, 224, 224)
        size_ = (3, 224, 224)
        if "inception" in name:
            dsize = (1, 3, 299, 299)
            size_ = (3, 224, 224)
    
        '''
        # thop,功能单一,不推荐
        from thop.profile import profile
        inputs = torch.randn(dsize).to(device)
        total_ops, total_params = profile(model, (inputs,), verbose=False)
        print(
            "%s | %.2f | %.2f" % (name, total_params / (1000 ** 2), total_ops / (1000 ** 3))
        )
        '''
    
        '''
        # 直接测量参数量,不能统计计算量。
        total_params = sum(param.numel() for param in model.parameters())
        print(total_params / (1000 ** 2))
        '''
    
        # torchstat,打印内容最详细,但没打印层名,且支持某些层的种类有限(https://github.com/Swall0w/torchstat/blob/master/detail.md)。对于包含这些层的网络,计算结果比实际偏低。测量的是GPU模型。
        from torchstat import stat
        stat(model.to(device), size_)
        
        # torchinfo。框架比较新,还在更新中,各层的名称比较明确,但打印信息不如torchstat全面。测量的是CPU模型。
        from torchinfo import summary
        summary(model, input_size=dsize)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
  • 相关阅读:
    JavaScript 40 JavaScript For In
    Python selenium模块的常用方法【更新中】
    Real-Time Rendering——18.4 Optimization优化
    计算机网络的基础知识
    如何像开发人员一样思考_成为一个问题解决者
    Enviro 3 - Sky and Weather
    绿色债券数据集2016-2021(含交易代码、债券简称、发行规模&期限等多指标数据)
    vue实现el-menu与el-tabs联动
    形象谈JVM-第二章-认识编译器
    c#把DataTable的数据存到Text文件
  • 原文地址:https://blog.csdn.net/tfcy694/article/details/127984916