本文解释简单给一个模型列子记录如何计算该模型参数量与模型显存占用情况,该文直接调用torchvision库的模型文件构建模型model,在使用parameters()函数遍历,并在遍历情况下使用numel()函数记录模型参数量与显存占用。
代码如下:
import torchvision
mobilenet_v2=torchvision.models.mobilenet_v2()
if __name__ == '__main__':
mobilenet_v2 = torchvision.models.mobilenet_v2()
total = sum(p.numel() for p in mobilenet_v2.parameters()) # 统计个数
print("模型参数总量: %.2f million\t" % (total / 1e6), " 以float32模型内存占用:%.2f M" % (total * 4 / 1e6))
结果如下: