• Computing a network model's parameter count and FLOPs in PyTorch


    Contents

    1. torchstat

    2. thop

    3. fvcore

    4. flops_counter

    5. Custom counting function


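    FLOPS vs. FLOPs:

    • FLOPS (all uppercase) stands for floating point operations per second. It describes computation speed and is a measure of hardware performance.
    • FLOPs (lowercase s, marking the plural) stands for floating point operations. It describes the amount of computation and can be used to gauge the complexity of an algorithm or model.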
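The two quantities are linked: dividing a model's FLOPs by a device's FLOPS gives an idealized lower bound on runtime. The sketch below illustrates this with a hypothetical helper name and made-up hardware numbers; real latency also depends on memory bandwidth and how well the hardware is utilized.

```python
def ideal_latency_seconds(model_flops, device_flops_per_second):
    """Lower-bound runtime if the device ran at its peak rate (an idealization)."""
    return model_flops / device_flops_per_second


# Assumed numbers for illustration only: a model needing 8e9 FLOPs per forward
# pass, on a device with a peak throughput of 4e12 FLOPS (4 TFLOPS).
latency = ideal_latency_seconds(8e9, 4e12)
print(f"ideal latency: {latency * 1e3:.2f} ms")
```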

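    Before introducing the torchstat and thop packages, here is a quick summary:

    • torchstat can count the parameters and computation of convolutional networks and fully connected networks.
    • thop can count the parameters and computation of convolutional networks, fully connected networks, and recurrent networks; code examples follow below.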

    1. torchstat

    pip install torchstat -i https://pypi.tuna.tsinghua.edu.cn/simple

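    In practice, we can call the torchstat package to count a model's parameters and FLOPs. Unless you modify some of the code inside the package, it only works for models whose input is a 3-channel image.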

    import torch
    import torch.nn as nn
    from torchstat import stat


    class Simple(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1, bias=False)
            self.conv2 = nn.Conv2d(16, 32, 3, 1, padding=1, bias=False)

        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            return x


    model = Simple()
    # Report the model's parameter count and FLOPs; (3, 244, 244) is the input image size
    stat(model, (3, 244, 244))
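The figures for this small model can be cross-checked by hand. The sketch below uses hypothetical helper names and assumes one multiply-accumulate (MAC) per weight per output position; whether a given tool reports MACs or 2 × MACs as "FLOPs" varies between packages, so compare with that in mind.

```python
def conv2d_params(in_channels, out_channels, kernel_size, bias=False):
    """Weights (plus optional biases) of a square-kernel convolution with groups=1."""
    weights = in_channels * out_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


def conv2d_macs(in_channels, out_channels, kernel_size, out_height, out_width):
    """MACs for one input sample: one MAC per weight per output position."""
    return in_channels * out_channels * kernel_size * kernel_size * out_height * out_width


# The Simple model above: stride 1 and padding 1 keep the 244 x 244 spatial size.
h = w = 244
params = conv2d_params(3, 16, 3) + conv2d_params(16, 32, 3)
macs = conv2d_macs(3, 16, 3, h, w) + conv2d_macs(16, 32, 3, h, w)
print(params)  # 5040
print(macs)    # 300061440
```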

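    With a small change to one line of the torchstat source, the package can also count the parameters and computation of fully connected networks. (Of course, working these out by hand for a fully connected network is quick too.) Open the torchstat source code and comment out the line circled in red in the figure; torchstat can then be used on fully connected networks.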
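As a rough illustration of how quick the manual calculation is, the sketch below counts parameters and MACs for a stack of fully connected layers. The helper names and the layer sizes are assumptions chosen for the example, and bias additions are left out of the MAC count.

```python
def linear_params(in_features, out_features, bias=True):
    """One weight per input/output pair, plus one bias per output."""
    return in_features * out_features + (out_features if bias else 0)


def linear_macs(in_features, out_features):
    """One multiply-accumulate per weight for a single input vector."""
    return in_features * out_features


# Hypothetical MLP: 784 -> 256 -> 10
layer_sizes = [784, 256, 10]
pairs = list(zip(layer_sizes[:-1], layer_sizes[1:]))
total_params = sum(linear_params(i, o) for i, o in pairs)
total_macs = sum(linear_macs(i, o) for i, o in pairs)
print(total_params)  # 203530
print(total_macs)    # 203264
```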

    2. thop

    pip install thop -i https://pypi.tuna.tsinghua.edu.cn/simple

    import torch
    import torch.nn as nn
    from thop import profile


    class Simple(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(10, 10)

        def forward(self, x):
            x = self.fc1(x)
            return x


    net = Simple()
    x = torch.randn(1, 10)  # batch size 1, input vector of length 10
    macs, params = profile(net, inputs=(x,))
    print(' FLOPs: ', macs * 2)  # as a rule of thumb, FLOPs is twice the MACs
    print('params: ', params)
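The factor of two comes from each multiply-accumulate being one multiplication plus one addition. A minimal sketch that counts both while computing a dot product (the function name is hypothetical, and the accumulator is assumed to start at zero so every term costs one add):

```python
def dot_with_op_counts(weights, inputs):
    """Compute a dot product and count the multiplies and adds it performs."""
    acc = 0.0
    multiplies = 0
    adds = 0
    for w, x in zip(weights, inputs):
        acc += w * x      # one multiply and one add: a single MAC
        multiplies += 1
        adds += 1
    return acc, multiplies, adds


result, multiplies, adds = dot_with_op_counts([1.0, 2.0, 3.0], [4.0, 5.0, 6.0])
macs = multiplies
flops = multiplies + adds
print(result, macs, flops)  # 32.0 3 6
```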

    3. fvcore

    pip install fvcore -i https://pypi.tuna.tsinghua.edu.cn/simple

    This is the recommended option.

    import torch
    from torchvision.models import resnet50
    from fvcore.nn import FlopCountAnalysis, parameter_count_table

    # Create the resnet50 network
    model = resnet50(num_classes=1000)

    # Create the input tensor for the network
    tensor = (torch.rand(1, 3, 224, 224),)

    # Analyze FLOPs
    flops = FlopCountAnalysis(model, tensor)
    print("FLOPs: ", flops.total())

    # Analyze parameters
    print(parameter_count_table(model))

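    The terminal output reports FLOPs of 4089184256 and roughly 25.6M parameters. (This parameter count differs slightly from my own calculation, mainly in the BN modules: only the two trainable parameters beta and gamma are counted here, not moving_mean and moving_var. See the issue I filed with the official repository for details.) Note that fvcore counts one fused multiply-add as one FLOP.
    The printed messages also show that the FLOPs calculation does not include BN layers, pooling layers, or plain add operations. (As far as I can tell there is no unified convention for counting FLOPs; nearly every FLOPs-counting project I have looked at on GitHub does it differently, though the results are broadly similar.)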
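The BN discrepancy mentioned above can be sketched per layer. In PyTorch terms, gamma and beta are the trainable weight and bias, while the moving mean and variance are stored as non-trainable buffers, so a count over trainable parameters skips them. The helper below is hypothetical and ignores any additional bookkeeping buffers.

```python
def batchnorm_counts(num_features):
    """Return (trainable, running_stats) element counts for one BN layer."""
    trainable = {"gamma": num_features, "beta": num_features}
    running_stats = {"moving_mean": num_features, "moving_var": num_features}
    return sum(trainable.values()), sum(running_stats.values())


# A BN layer over 64 channels (channel count chosen for illustration).
trainable, running = batchnorm_counts(64)
print(trainable)            # 128, what a trainable-parameter count includes
print(trainable + running)  # 256, if the moving statistics are also counted
```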

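    Note: I ran into a problem when using fvcore to compute a model's FLOPs, so here is the fix for the record. The error was raised at line 589 of jit_analysis.py. Debugging showed that the values in op_counts.values() were of type int32, while the calculation only accepts int, float, np.float64, and np.int64, so the values have to be cast explicitly.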

    4. flops_counter

    pip install ptflops -i https://pypi.tuna.tsinghua.edu.cn/simple

    This one (the ptflops package) also works well, and its results match those of fvcore.

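    import torch.nn as nn
    from ptflops import get_model_complexity_info

    # Placeholder model for illustration only; substitute your own network
    # that accepts a (112, 9, 9) input.
    model = nn.Conv2d(112, 16, kernel_size=3, padding=1)
    macs, params = get_model_complexity_info(model, (112, 9, 9), as_strings=True,
                                             print_per_layer_stat=True, verbose=True)
    print('{:<30} {:<8}'.format('Computational complexity: ', macs))
    print('{:<30} {:<8}'.format('Number of parameters: ', params))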
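With as_strings=True the values come back already formatted as text. When working with raw counts instead, such as the integer returned by fvcore above, a small formatter keeps the output readable. This helper is a hypothetical sketch and does not reproduce ptflops's own formatting.

```python
def format_count(value, precision=2):
    """Format a raw count with a K/M/G suffix (decimal units)."""
    for factor, suffix in ((1e9, "G"), (1e6, "M"), (1e3, "K")):
        if value >= factor:
            return f"{value / factor:.{precision}f} {suffix}"
    return str(value)


print(format_count(4089184256))  # 4.09 G (the resnet50 figure reported by fvcore above)
print(format_count(25_600_000))  # 25.60 M (the approximate parameter count above)
```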

    5. Custom counting function

    import torch


    def calc_flops(model, input):
        # Set to True to count a multiply-add as two operations instead of one.
        multiply_adds = False
        list_conv, list_bn, list_relu, list_linear = [], [], [], []
        list_pooling, list_lstm, list_fsmn = [], [], []
        handles = []  # hook handles, removed again once counting is done

        def conv_hook(self, input, output):
            batch_size, input_channels, input_height, input_width = input[0].size()
            output_channels, output_height, output_width = output[0].size()
            kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels // self.groups) * (
                2 if multiply_adds else 1)
            bias_ops = 1 if self.bias is not None else 0
            params = output_channels * (kernel_ops + bias_ops)
            flops = batch_size * params * output_height * output_width
            list_conv.append(flops)

        def linear_hook(self, input, output):
            # Number of input vectors passing through the layer (batch and any
            # time-step dimensions), so the batch size is not counted twice.
            num_vectors = input[0].nelement() // self.in_features
            weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
            bias_ops = self.bias.nelement() if self.bias is not None else 0
            flops = num_vectors * (weight_ops + bias_ops)
            list_linear.append(flops)

        def fsmn_hook(self, input, output):
            # For a custom FSMN layer; its registration is commented out below.
            batch_size = input[0].size(0) if input[0].dim() == 2 else 1
            weight_ops = self.filter.nelement() * (2 if multiply_adds else 1)
            num_steps = input[0].size(0)
            flops = num_steps * weight_ops
            flops *= batch_size
            list_fsmn.append(flops)

        def gru_cell(input_size, hidden_size, bias=True):
            total_ops = 0
            # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr})
            # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz})
            state_ops = (hidden_size + input_size) * hidden_size + hidden_size
            if bias:
                state_ops += hidden_size * 2
            total_ops += state_ops * 2
            # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn}))
            total_ops += (hidden_size + input_size) * hidden_size + hidden_size
            if bias:
                total_ops += hidden_size * 2
            # r hadamard: r * (~)
            total_ops += hidden_size
            # h' = (1 - z) * n + z * h: hadamard, hadamard, add
            total_ops += hidden_size * 3
            return total_ops

        def gru_hook(self, input, output):
            if self.batch_first:
                batch_size = input[0].size(0)
                num_steps = input[0].size(1)
            else:
                batch_size = input[0].size(1)
                num_steps = input[0].size(0)
            bias = self.bias
            input_size = self.input_size
            hidden_size = self.hidden_size
            num_layers = self.num_layers
            total_ops = gru_cell(input_size, hidden_size, bias)
            for i in range(num_layers - 1):
                total_ops += gru_cell(hidden_size, hidden_size, bias)
            total_ops *= batch_size
            total_ops *= num_steps
            # GRU counts are accumulated in the same list as LSTM counts.
            list_lstm.append(total_ops)

        def lstm_cell(input_size, hidden_size, bias):
            total_ops = 0
            # Each of the four gates: two matrix-vector products plus an activation.
            state_ops = (input_size + hidden_size) * hidden_size + hidden_size
            if bias:
                state_ops += hidden_size * 2
            total_ops += state_ops * 4
            # Cell-state update, then the output (hidden-state) computation.
            total_ops += hidden_size * 3
            total_ops += hidden_size
            return total_ops

        def lstm_hook(self, input, output):
            if self.batch_first:
                batch_size = input[0].size(0)
                num_steps = input[0].size(1)
            else:
                batch_size = input[0].size(1)
                num_steps = input[0].size(0)
            bias = self.bias
            input_size = self.input_size
            hidden_size = self.hidden_size
            num_layers = self.num_layers
            total_ops = lstm_cell(input_size, hidden_size, bias)
            for i in range(num_layers - 1):
                total_ops += lstm_cell(hidden_size, hidden_size, bias)
            total_ops *= batch_size
            total_ops *= num_steps
            list_lstm.append(total_ops)

        def bn_hook(self, input, output):
            list_bn.append(input[0].nelement())

        def relu_hook(self, input, output):
            list_relu.append(input[0].nelement())

        def pooling_hook(self, input, output):
            batch_size, input_channels, input_height, input_width = input[0].size()
            output_channels, output_height, output_width = output[0].size()
            # kernel_size may be an int or a (height, width) tuple.
            kernel_size = self.kernel_size
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size)
            kernel_ops = kernel_size[0] * kernel_size[1]
            bias_ops = 0
            params = output_channels * (kernel_ops + bias_ops)
            flops = batch_size * params * output_height * output_width
            list_pooling.append(flops)

        def register_hooks(net):
            # Recurse down to leaf modules and attach the matching hook.
            children = list(net.children())
            if not children:
                if isinstance(net, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                    handles.append(net.register_forward_hook(conv_hook))
                if isinstance(net, torch.nn.Linear):
                    handles.append(net.register_forward_hook(linear_hook))
                if isinstance(net, torch.nn.BatchNorm2d):
                    handles.append(net.register_forward_hook(bn_hook))
                if isinstance(net, (torch.nn.ReLU, torch.nn.PReLU)):
                    handles.append(net.register_forward_hook(relu_hook))
                if isinstance(net, (torch.nn.MaxPool2d, torch.nn.AvgPool2d)):
                    handles.append(net.register_forward_hook(pooling_hook))
                if isinstance(net, torch.nn.LSTM):
                    handles.append(net.register_forward_hook(lstm_hook))
                if isinstance(net, torch.nn.GRU):
                    handles.append(net.register_forward_hook(gru_hook))
                # if isinstance(net, FSMNZQ):
                #     handles.append(net.register_forward_hook(fsmn_hook))
                return
            for c in children:
                register_hooks(c)

        register_hooks(model)
        with torch.no_grad():
            _ = model(input)  # one forward pass triggers every hook
        for handle in handles:
            handle.remove()  # leave the model without hooks so repeated calls do not double count

        total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu)
                       + sum(list_pooling) + sum(list_lstm) + sum(list_fsmn))
        fsmn_flops = sum(list_fsmn) + sum(list_linear)
        lstm_flops = sum(list_lstm)
        params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('The network has {} params.'.format(params))
        print(total_flops, fsmn_flops, lstm_flops)
        print(' + Number of FLOPs: %.2f M' % (total_flops / 1000 ** 2))
        return total_flops


    if __name__ == '__main__':
        from torchvision.models import resnet18

        model = resnet18(num_classes=1000)
        input_tensor = torch.rand((1, 3, 224, 224))
        calc_flops(model, input_tensor)
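To see what the LSTM counting rule in the function above produces, it can be evaluated on its own. The sketch below copies the same arithmetic into a standalone function (renamed lstm_cell_ops here) and plugs in layer sizes, batch size, and sequence length that are assumptions chosen for illustration.

```python
def lstm_cell_ops(input_size, hidden_size, bias=True):
    """Per-sample, per-time-step operation count, following the rule used above."""
    # Each gate: (input + hidden) x hidden products plus hidden activation ops.
    state_ops = (input_size + hidden_size) * hidden_size + hidden_size
    if bias:
        state_ops += hidden_size * 2
    total_ops = state_ops * 4        # four gates
    total_ops += hidden_size * 3     # cell-state update
    total_ops += hidden_size         # hidden-state output
    return total_ops


# Assumed sizes: 10 input features, 20 hidden units, batch of 2, 5 time steps.
per_step = lstm_cell_ops(10, 20)
total = per_step * 2 * 5
print(per_step)  # 2720
print(total)     # 27200
```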

  • Original article: https://blog.csdn.net/qq_45100200/article/details/127728053