| Model | Params(M) | FLOPs(G) |
|---|---|---|
| alexnet | 61.10 | 0.71 |
| densenet121 | 7.98 | 2.90 |
| densenet161 | 28.68 | 7.85 |
| densenet169 | 14.15 | 3.44 |
| densenet201 | 20.01 | 4.39 |
| googlenet | 6.62 | 1.51 |
| inception_v3 | 23.83 | 5.75 |
| mnasnet0_5 | 2.22 | 0.12 |
| mnasnet0_75 | 3.17 | 0.23 |
| mnasnet1_0 | 4.38 | 0.34 |
| mnasnet1_3 | 6.28 | 0.56 |
| mobilenet_v2 | 3.50 | 0.33 |
| mobilenet_v3_large | 5.48 | 0.23 |
| mobilenet_v3_small | 2.54 | 0.06 |
| resnet101 | 44.55 | 7.87 |
| resnet152 | 60.19 | 11.60 |
| resnet18 | 11.69 | 1.82 |
| resnet34 | 21.80 | 3.68 |
| resnet50 | 25.56 | 4.13 |
| resnext101_32x8d | 88.79 | 16.54 |
| resnext50_32x4d | 25.03 | 4.29 |
| shufflenet_v2_x0_5 | 1.37 | 0.04 |
| shufflenet_v2_x1_0 | 2.28 | 0.15 |
| shufflenet_v2_x1_5 | 3.50 | 0.31 |
| shufflenet_v2_x2_0 | 7.39 | 0.60 |
| squeezenet1_0 | 1.25 | 0.82 |
| squeezenet1_1 | 1.24 | 0.35 |
| vgg11 | 132.86 | 7.61 |
| vgg11_bn | 132.87 | 7.64 |
| vgg13 | 133.05 | 11.31 |
| vgg13_bn | 133.05 | 11.36 |
| vgg16 | 138.36 | 15.47 |
| vgg16_bn | 138.37 | 15.52 |
| vgg19 | 143.67 | 19.63 |
| vgg19_bn | 143.68 | 19.69 |
| wide_resnet101_2 | 126.89 | 22.84 |
| wide_resnet50_2 | 68.88 | 11.46 |
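
The Params column can be reproduced without any profiling library by summing the element counts of a model's parameter tensors. A minimal sketch (shown here for resnet18 only) that should match the corresponding table row:

```python
from torchvision import models

# Build resnet18 without pretrained weights; the parameter count is
# determined by the architecture alone.
model = models.resnet18()

# Total number of scalar parameters, reported in millions as in the table.
total_params = sum(p.numel() for p in model.parameters())
print(f"resnet18 | {total_params / 1e6:.2f}")  # expected ~11.69
```

FLOPs, by contrast, depend on the input resolution, which is why the script below fixes the input to 1x3x224x224 (1x3x299x299 for inception_v3).
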
import torch
from torchvision import models
from torchstat import stat
from torchinfo import summary

# Collect every lowercase, callable model constructor exported by torchvision.models.
model_names = sorted(
    name
    for name in models.__dict__
    if name.islower()
    and not name.startswith("__")
    and callable(models.__dict__[name])
)

device = "cpu"

for name in model_names:
    print("=======================" + name + "=========================")
    model = models.__dict__[name]().to(device)

    # Standard ImageNet input size; inception_v3 expects 299x299 inputs.
    dsize = (1, 3, 224, 224)
    size_ = (3, 224, 224)
    if "inception" in name:
        dsize = (1, 3, 299, 299)
        size_ = (3, 299, 299)

    '''
    # thop: limited functionality, not recommended.
    from thop.profile import profile
    inputs = torch.randn(dsize).to(device)
    total_ops, total_params = profile(model, (inputs,), verbose=False)
    print(
        "%s | %.2f | %.2f" % (name, total_params / (1000 ** 2), total_ops / (1000 ** 3))
    )
    '''

    '''
    # Counts parameters directly, but cannot measure FLOPs.
    total_params = sum(param.numel() for param in model.parameters())
    print(total_params / (1000 ** 2))
    '''

    # torchstat: per-layer report of parameters, memory, and multiply-adds.
    stat(model, size_)

    # torchinfo: layer-by-layer summary with total params and mult-adds.
    summary(model, input_size=dsize)
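
To collect these numbers into a table instead of reading them off the printed reports, the object returned by torchinfo's summary can be used directly. A minimal sketch, assuming the returned ModelStatistics object exposes total_params and total_mult_adds (attribute names may vary across torchinfo versions, so treat this as illustrative):

```python
from torchvision import models
from torchinfo import summary

rows = []
for name in ["resnet18", "mobilenet_v2"]:  # small subset for illustration
    model = models.__dict__[name]()
    dsize = (1, 3, 299, 299) if "inception" in name else (1, 3, 224, 224)
    stats = summary(model, input_size=dsize, verbose=0)  # verbose=0 suppresses the printed report
    # total_mult_adds counts multiply-accumulate operations (MACs), which are
    # often reported directly as "FLOPs"; the FLOPs(G) values in the table
    # above appear to follow that convention.
    rows.append((name, stats.total_params / 1e6, stats.total_mult_adds / 1e9))

for name, params_m, macs_g in rows:
    print(f"| {name} | {params_m:.2f} | {macs_g:.2f} |")
```
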