PyTorch - Printing model structure, output shapes, and parameter counts with torchsummary/torchsummaryX/torchinfo


    1 torchsummary/torchsummaryX/torchinfo

    torchsummary on GitHub: https://github.com/sksq96/pytorch-summary

    torchsummaryX on GitHub: https://github.com/nmhkahn/torchsummaryX

    torchinfo on GitHub: https://github.com/TylerYep/torchinfo

    1.1 Installation

    Install torchsummary:

    pip install torchsummary
    

    Install torchsummaryX:

    pip install torchsummaryX
    

    Install torchinfo:

    Using pip:

    pip install torchinfo
    

    Using conda:

    conda install -c conda-forge torchinfo
    

    1.2 Usage

    1.2.1 Using torchsummary

    from torchvision import models
    from torchsummary import summary
    
    if __name__ == '__main__':
        resnet18 = models.resnet18().cuda()  # torchsummary defaults to device="cuda"; either move the model to GPU or pass device="cpu"
        summary(resnet18, (3, 224, 224))
    

    Output

    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
                Conv2d-1         [-1, 64, 112, 112]           9,408
           BatchNorm2d-2         [-1, 64, 112, 112]             128
                  ReLU-3         [-1, 64, 112, 112]               0
             MaxPool2d-4           [-1, 64, 56, 56]               0
                Conv2d-5           [-1, 64, 56, 56]          36,864
           BatchNorm2d-6           [-1, 64, 56, 56]             128
                  ReLU-7           [-1, 64, 56, 56]               0
                Conv2d-8           [-1, 64, 56, 56]          36,864
           BatchNorm2d-9           [-1, 64, 56, 56]             128
                 ReLU-10           [-1, 64, 56, 56]               0
           BasicBlock-11           [-1, 64, 56, 56]               0
               Conv2d-12           [-1, 64, 56, 56]          36,864
          BatchNorm2d-13           [-1, 64, 56, 56]             128
                 ReLU-14           [-1, 64, 56, 56]               0
               Conv2d-15           [-1, 64, 56, 56]          36,864
          BatchNorm2d-16           [-1, 64, 56, 56]             128
                 ReLU-17           [-1, 64, 56, 56]               0
           BasicBlock-18           [-1, 64, 56, 56]               0
               Conv2d-19          [-1, 128, 28, 28]          73,728
          BatchNorm2d-20          [-1, 128, 28, 28]             256
                 ReLU-21          [-1, 128, 28, 28]               0
               Conv2d-22          [-1, 128, 28, 28]         147,456
          BatchNorm2d-23          [-1, 128, 28, 28]             256
               Conv2d-24          [-1, 128, 28, 28]           8,192
          BatchNorm2d-25          [-1, 128, 28, 28]             256
                 ReLU-26          [-1, 128, 28, 28]               0
           BasicBlock-27          [-1, 128, 28, 28]               0
               Conv2d-28          [-1, 128, 28, 28]         147,456
          BatchNorm2d-29          [-1, 128, 28, 28]             256
                 ReLU-30          [-1, 128, 28, 28]               0
               Conv2d-31          [-1, 128, 28, 28]         147,456
          BatchNorm2d-32          [-1, 128, 28, 28]             256
                 ReLU-33          [-1, 128, 28, 28]               0
           BasicBlock-34          [-1, 128, 28, 28]               0
               Conv2d-35          [-1, 256, 14, 14]         294,912
          BatchNorm2d-36          [-1, 256, 14, 14]             512
                 ReLU-37          [-1, 256, 14, 14]               0
               Conv2d-38          [-1, 256, 14, 14]         589,824
          BatchNorm2d-39          [-1, 256, 14, 14]             512
               Conv2d-40          [-1, 256, 14, 14]          32,768
          BatchNorm2d-41          [-1, 256, 14, 14]             512
                 ReLU-42          [-1, 256, 14, 14]               0
           BasicBlock-43          [-1, 256, 14, 14]               0
               Conv2d-44          [-1, 256, 14, 14]         589,824
          BatchNorm2d-45          [-1, 256, 14, 14]             512
                 ReLU-46          [-1, 256, 14, 14]               0
               Conv2d-47          [-1, 256, 14, 14]         589,824
          BatchNorm2d-48          [-1, 256, 14, 14]             512
                 ReLU-49          [-1, 256, 14, 14]               0
           BasicBlock-50          [-1, 256, 14, 14]               0
               Conv2d-51            [-1, 512, 7, 7]       1,179,648
          BatchNorm2d-52            [-1, 512, 7, 7]           1,024
                 ReLU-53            [-1, 512, 7, 7]               0
               Conv2d-54            [-1, 512, 7, 7]       2,359,296
          BatchNorm2d-55            [-1, 512, 7, 7]           1,024
               Conv2d-56            [-1, 512, 7, 7]         131,072
          BatchNorm2d-57            [-1, 512, 7, 7]           1,024
                 ReLU-58            [-1, 512, 7, 7]               0
           BasicBlock-59            [-1, 512, 7, 7]               0
               Conv2d-60            [-1, 512, 7, 7]       2,359,296
          BatchNorm2d-61            [-1, 512, 7, 7]           1,024
                 ReLU-62            [-1, 512, 7, 7]               0
               Conv2d-63            [-1, 512, 7, 7]       2,359,296
          BatchNorm2d-64            [-1, 512, 7, 7]           1,024
                 ReLU-65            [-1, 512, 7, 7]               0
           BasicBlock-66            [-1, 512, 7, 7]               0
    AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
               Linear-68                 [-1, 1000]         513,000
    ================================================================
    Total params: 11,689,512
    Trainable params: 11,689,512
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.57
    Forward/backward pass size (MB): 62.79
    Params size (MB): 44.59
    Estimated Total Size (MB): 107.96
    ----------------------------------------------------------------
    
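    The Param # column can be reproduced by hand. Below is a minimal sketch (layer shapes taken from the table above): ResNet-18's convolutions use bias=False, and BatchNorm2d contributes a learnable weight and bias per channel.

    ```python
    # Per-layer parameter counts for the first ResNet-18 layers in the table above.
    def conv2d_params(c_in, c_out, kernel, bias=False):
        # weight shape is (c_out, c_in, kernel, kernel); ResNet convs have no bias
        return c_out * c_in * kernel * kernel + (c_out if bias else 0)

    def batchnorm2d_params(channels):
        # learnable gamma and beta per channel (running stats are buffers, not parameters)
        return 2 * channels

    print(conv2d_params(3, 64, 7))   # Conv2d-1:      9408
    print(batchnorm2d_params(64))    # BatchNorm2d-2: 128
    print(conv2d_params(64, 64, 3))  # Conv2d-5:      36864
    print(512 * 1000 + 1000)         # Linear-68:     513000 (weight + bias)
    ```

    Summing these per-layer counts over the whole table gives the reported total of 11,689,512.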

    1.2.2 Using torchsummaryX

    import torch
    import torch.nn as nn
    from torchsummaryX import summary
    
    class Net(nn.Module):
        def __init__(self,
                     vocab_size=20, embed_dim=300,
                     hidden_dim=512, num_layers=2):
            super().__init__()
    
            self.hidden_dim = hidden_dim
            self.embedding = nn.Embedding(vocab_size, embed_dim)
            self.encoder = nn.LSTM(embed_dim, hidden_dim,
                                   num_layers=num_layers)
            self.decoder = nn.Linear(hidden_dim, vocab_size)
    
        def forward(self, x):
            embed = self.embedding(x)
            out, hidden = self.encoder(embed)
            out = self.decoder(out)
            out = out.view(-1, out.size(2))
            return out, hidden
    
    if __name__ == '__main__':
        inputs = torch.zeros((100, 1), dtype=torch.long)  # [length, batch_size]
        summary(Net(), inputs)
    

    Output

    ===========================================================
                Kernel Shape   Output Shape   Params  Mult-Adds
    Layer                                                      
    0_embedding    [300, 20]  [100, 1, 300]     6000       6000
    1_encoder              -  [100, 1, 512]  3768320    3760128
    2_decoder      [512, 20]   [100, 1, 20]    10260      10240
    -----------------------------------------------------------
                           Totals
    Total params          3784580
    Trainable params      3784580
    Non-trainable params        0
    Mult-Adds             3776368
    ===========================================================
    
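    These totals can also be checked by hand. A sketch of the arithmetic, assuming the standard PyTorch LSTM parameterization (four gates per layer, each with input weights W_ih, recurrent weights W_hh, and two bias vectors):

    ```python
    def lstm_params(input_size, hidden_size, num_layers):
        total, in_dim = 0, input_size
        for _ in range(num_layers):
            # 4 gates x (W_ih + W_hh + b_ih + b_hh)
            total += 4 * (in_dim * hidden_size + hidden_size * hidden_size + 2 * hidden_size)
            in_dim = hidden_size  # deeper layers take the hidden state as input
        return total

    embedding = 20 * 300                # 0_embedding: 6000
    encoder = lstm_params(300, 512, 2)  # 1_encoder:   3768320
    decoder = 512 * 20 + 20             # 2_decoder:   10260 (weight + bias)
    print(embedding + encoder + decoder)  # Total params: 3784580
    ```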

    1.2.3 Using torchinfo

    from torchvision import models
    from torchinfo import summary
    
    if __name__ == '__main__':
        resnet18 = models.resnet18().cuda()  # optional: torchinfo infers the device from the model, so this also runs on CPU
        summary(resnet18, input_size=(1, 3, 244, 244))
    

    Output

    ==========================================================================================
    Layer (type:depth-idx)                   Output Shape              Param #
    ==========================================================================================
    ResNet                                   [1, 1000]                 --
    ├─Conv2d: 1-1                            [1, 64, 122, 122]         9,408
    ├─BatchNorm2d: 1-2                       [1, 64, 122, 122]         128
    ├─ReLU: 1-3                              [1, 64, 122, 122]         --
    ├─MaxPool2d: 1-4                         [1, 64, 61, 61]           --
    ├─Sequential: 1-5                        [1, 64, 61, 61]           --
    │    └─BasicBlock: 2-1                   [1, 64, 61, 61]           --
    │    │    └─Conv2d: 3-1                  [1, 64, 61, 61]           36,864
    │    │    └─BatchNorm2d: 3-2             [1, 64, 61, 61]           128
    │    │    └─ReLU: 3-3                    [1, 64, 61, 61]           --
    │    │    └─Conv2d: 3-4                  [1, 64, 61, 61]           36,864
    │    │    └─BatchNorm2d: 3-5             [1, 64, 61, 61]           128
    │    │    └─ReLU: 3-6                    [1, 64, 61, 61]           --
    │    └─BasicBlock: 2-2                   [1, 64, 61, 61]           --
    │    │    └─Conv2d: 3-7                  [1, 64, 61, 61]           36,864
    │    │    └─BatchNorm2d: 3-8             [1, 64, 61, 61]           128
    │    │    └─ReLU: 3-9                    [1, 64, 61, 61]           --
    │    │    └─Conv2d: 3-10                 [1, 64, 61, 61]           36,864
    │    │    └─BatchNorm2d: 3-11            [1, 64, 61, 61]           128
    │    │    └─ReLU: 3-12                   [1, 64, 61, 61]           --
    ├─Sequential: 1-6                        [1, 128, 31, 31]          --
    │    └─BasicBlock: 2-3                   [1, 128, 31, 31]          --
    │    │    └─Conv2d: 3-13                 [1, 128, 31, 31]          73,728
    │    │    └─BatchNorm2d: 3-14            [1, 128, 31, 31]          256
    │    │    └─ReLU: 3-15                   [1, 128, 31, 31]          --
    │    │    └─Conv2d: 3-16                 [1, 128, 31, 31]          147,456
    │    │    └─BatchNorm2d: 3-17            [1, 128, 31, 31]          256
    │    │    └─Sequential: 3-18             [1, 128, 31, 31]          8,448
    │    │    └─ReLU: 3-19                   [1, 128, 31, 31]          --
    │    └─BasicBlock: 2-4                   [1, 128, 31, 31]          --
    │    │    └─Conv2d: 3-20                 [1, 128, 31, 31]          147,456
    │    │    └─BatchNorm2d: 3-21            [1, 128, 31, 31]          256
    │    │    └─ReLU: 3-22                   [1, 128, 31, 31]          --
    │    │    └─Conv2d: 3-23                 [1, 128, 31, 31]          147,456
    │    │    └─BatchNorm2d: 3-24            [1, 128, 31, 31]          256
    │    │    └─ReLU: 3-25                   [1, 128, 31, 31]          --
    ├─Sequential: 1-7                        [1, 256, 16, 16]          --
    │    └─BasicBlock: 2-5                   [1, 256, 16, 16]          --
    │    │    └─Conv2d: 3-26                 [1, 256, 16, 16]          294,912
    │    │    └─BatchNorm2d: 3-27            [1, 256, 16, 16]          512
    │    │    └─ReLU: 3-28                   [1, 256, 16, 16]          --
    │    │    └─Conv2d: 3-29                 [1, 256, 16, 16]          589,824
    │    │    └─BatchNorm2d: 3-30            [1, 256, 16, 16]          512
    │    │    └─Sequential: 3-31             [1, 256, 16, 16]          33,280
    │    │    └─ReLU: 3-32                   [1, 256, 16, 16]          --
    │    └─BasicBlock: 2-6                   [1, 256, 16, 16]          --
    │    │    └─Conv2d: 3-33                 [1, 256, 16, 16]          589,824
    │    │    └─BatchNorm2d: 3-34            [1, 256, 16, 16]          512
    │    │    └─ReLU: 3-35                   [1, 256, 16, 16]          --
    │    │    └─Conv2d: 3-36                 [1, 256, 16, 16]          589,824
    │    │    └─BatchNorm2d: 3-37            [1, 256, 16, 16]          512
    │    │    └─ReLU: 3-38                   [1, 256, 16, 16]          --
    ├─Sequential: 1-8                        [1, 512, 8, 8]            --
    │    └─BasicBlock: 2-7                   [1, 512, 8, 8]            --
    │    │    └─Conv2d: 3-39                 [1, 512, 8, 8]            1,179,648
    │    │    └─BatchNorm2d: 3-40            [1, 512, 8, 8]            1,024
    │    │    └─ReLU: 3-41                   [1, 512, 8, 8]            --
    │    │    └─Conv2d: 3-42                 [1, 512, 8, 8]            2,359,296
    │    │    └─BatchNorm2d: 3-43            [1, 512, 8, 8]            1,024
    │    │    └─Sequential: 3-44             [1, 512, 8, 8]            132,096
    │    │    └─ReLU: 3-45                   [1, 512, 8, 8]            --
    │    └─BasicBlock: 2-8                   [1, 512, 8, 8]            --
    │    │    └─Conv2d: 3-46                 [1, 512, 8, 8]            2,359,296
    │    │    └─BatchNorm2d: 3-47            [1, 512, 8, 8]            1,024
    │    │    └─ReLU: 3-48                   [1, 512, 8, 8]            --
    │    │    └─Conv2d: 3-49                 [1, 512, 8, 8]            2,359,296
    │    │    └─BatchNorm2d: 3-50            [1, 512, 8, 8]            1,024
    │    │    └─ReLU: 3-51                   [1, 512, 8, 8]            --
    ├─AdaptiveAvgPool2d: 1-9                 [1, 512, 1, 1]            --
    ├─Linear: 1-10                           [1, 1000]                 513,000
    ==========================================================================================
    Total params: 11,689,512
    Trainable params: 11,689,512
    Non-trainable params: 0
    Total mult-adds (G): 2.27
    ==========================================================================================
    Input size (MB): 0.71
    Forward/backward pass size (MB): 48.20
    Params size (MB): 46.76
    Estimated Total Size (MB): 95.67
    ==========================================================================================
    
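    The size estimates at the bottom can be reproduced from the parameter count and the input shape. A sketch assuming 4-byte float32 elements (note that torchinfo divides by 10^6 per MB, while torchsummary's output above divides by 1024^2):

    ```python
    params = 11_689_512            # total parameter count from the table
    input_elems = 1 * 3 * 244 * 244  # the input_size passed to summary above

    input_mb = input_elems * 4 / 1e6  # float32 -> 4 bytes per element
    params_mb = params * 4 / 1e6
    print(round(input_mb, 2))   # Input size (MB): 0.71
    print(round(params_mb, 2))  # Params size (MB): 46.76
    ```

    The forward/backward pass size is estimated the same way, by summing the element counts of each layer's output tensors.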
  • Original article: https://blog.csdn.net/HW140701/article/details/125554060