用torchinfo库的sumary函数可以打印模型信息,示例如下。注意输入张量中不需要batch维度。
- import numpy as np
-
- import torch
- import torchvision
- import torchinfo
-
- model = torchvision.models.resnet50(pretrained=True)
- torchinfo.summary(model, (3, 224, 224), batch_dim=0,
- col_names=('input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'), verbose=1
- )