• pytorch模型可视化的方法总结



    这里主要介绍pytorch 模型的网络结构的可视化
    以 SRCNN 为例子来说明可视化的方法,以及参数量的计算

    模型所占内存 = (参数量内存,特征图内存),
    模型计算量 = (浮点数计算量)

    1. torchsummary

    torchinfo

    class SRCNN(nn.Module):
        def __init__(self, num_channels=1):
            super(SRCNN, self).__init__()
            self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
            self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
            self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
            self.relu = nn.ReLU(inplace=True)
    
        def forward(self, x):
            x = self.relu(self.conv1(x))
            x = self.relu(self.conv2(x))
            x = self.conv3(x)
            return x
    
    
    from torchinfo import summary
    if __name__ == "__main__":
        modelviz = SRCNN()
        # 打印模型结构
        print(modelviz)
        summary(modelviz, input_size=(8, 1, 8, 8), col_names=["kernel_size", "output_size", "num_params", "mult_adds"])
        for p in modelviz.parameters():
            if p.requires_grad:
                print(p.shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24

    可以得到的结果如下
    在这里插入图片描述
    具体什么含义呢?
    接下来详细解释:
    这里输入以 input_size=(8, 1, 8, 8) 为例子,
    1) kernel shape 和 output shape 就是滤波器的参数shape 和 中间层的一些输出的 shape
    2) Para # 表示的是有多少个参数,计算conv-2d 1-1的参数量,kernelshape = [9,9]:

    W + b        = 5248
    9*9*64 + 64 = 5248
    
    • 1
    • 2

    3) Multi-Adds : 统计的是浮点数运算, 计算conv-2d 1-1的计算量(浮点数运算次数):
    filter(h, w, bias, channel), input(h, w, channel)

    (9*9 + 1) * 64                *    (8 * 8 * 8) = 2686976
    
    • 1

    4) Total params, Total mult-adds (M) 就是对 上面参数的求和

    比如 5248+51232+801 = 57281
    
    • 1

    5)关于size:统计的是 参数 加上 中间层的 占用内存
    输入内存Input size (MB): 0.00

    8*1*8*8  * 4  / 10000008*1*8*8float,每个4Byte, 除以一百万  ,约等于 0    
    
    • 1

    中间特征内存Forward/backward pass size (MB): 0.40

    8*8*8 *1+64+64+32+32+1=  99328
    99328 * 4 / 1000000 = 0.397312
    
    • 1
    • 2

    参数weight内存Params size (MB): 0.23

    57281*4 / 1000000 = 0.229124
    
    • 1

    总内存Estimated Total Size (MB): 0.63

    0.4 + 0.23
    
    • 1

    2. graphviz, torchviz

    from torchviz import make_dot
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    modelviz = SRCNN().to(device)
    input = torch.rand(8, 1, 8, 8).to(device)
    out = modelviz(input)
    print(out.shape)
    
    # 1. 使用 torchviz 可视化
    g = make_dot(out)
    g.view()  # 直接在当前路径下保存 pdf 并打开
    # g.render(filename='netStructure/myNetModel', view=False, format='pdf')  # 保存 pdf 到指定路径不打开
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    可视化结果是一个pdf,如下:写了比较多的步骤,所以网络结构感觉不是很清晰
    在这里插入图片描述

    3. 保存成pt文件后使用netron可视化

    netron github:
    安装:

    pip install -i https://pypi.tuna.tsinghua.edu.cn/simple netron
    
    • 1

    代码:

    torch.save(modelviz, "modelviz.pt")
    
    import netron
    modelData = 'modelviz.pt'
    netron.start(modelData)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    点击链接在浏览器中打开
    在这里插入图片描述
    在这里插入图片描述

    4. tensorwatch

    import tensorwatch as tw
    # 3. 使用tensorwatch可视化
    print(tw.model_stats(modelviz, (8, 1, 8, 8)))
    tw.draw_model(modelviz, input)
    
    • 1
    • 2
    • 3
    • 4

    打印的结果如图,可以和 summary 进行对比
    在这里插入图片描述

    5. get_model_complexity_info计算 FLOPs和parameters

        # 4. get_model_complexity_info
        from ptflops import get_model_complexity_info
        macs, params = get_model_complexity_info(modelviz, ( 1, 8, 8), verbose=True, print_per_layer_stat=True)
        print(macs, params)
        params = float(params[:-3])
        macs = float(macs[:-4])
    
        print(macs * 8, params) # 8个图像的 FLOPs, 这里的结果 和 其他方法应该一致
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    结果:
    在这里插入图片描述

    6. 附上直接可以执行的code

    from torch import nn
    import torch
    from torchviz import make_dot
    import tensorwatch as tw
    from torchinfo import summary
    import netron
    
    class SRCNN(nn.Module):
        def __init__(self, num_channels=1):
            super(SRCNN, self).__init__()
            self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
            self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
            self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
            self.relu = nn.ReLU(inplace=True)
    
        def forward(self, x):
            x = self.relu(self.conv1(x))
            x = self.relu(self.conv2(x))
            x = self.conv3(x)
            return x
    
    
    
    if __name__ == "__main__":
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # device = 'cpu'
        modelviz = SRCNN().to(device)
        # 打印模型结构
        print(modelviz)
        summary(modelviz, input_size=(8, 1, 8, 8), col_names=["kernel_size", "output_size", "num_params", "mult_adds"])
        for p in modelviz.parameters():
            if p.requires_grad:
                print(p.shape)
        # 创建输入, 看看输出结果
    
        input = torch.rand(8, 1, 8, 8).to(device)
        out = modelviz(input)
        print('out:', out.shape)
        # 1. 使用 torchviz 可视化
        g = make_dot(out)
        g.view()  # 直接在当前路径下保存 pdf 并打开
        # g.render(filename='netStructure/myNetModel', view=False, format='pdf')  # 保存 pdf 到指定路径不打开
    
    
        # 2. 保存成pt文件后进行可视化
        torch.save(modelviz, "modelviz.pt")
        modelData = 'modelviz.pt'
        netron.start(modelData)
    
        # 3. 使用tensorwatch可视化
        # print(tw.model_stats(modelviz, (8, 1, 8, 8)))
        # tw.draw_model(modelviz, input)
    
    
        # 4. get_model_complexity_info
        from ptflops import get_model_complexity_info
        macs, params = get_model_complexity_info(modelviz, (1, 8, 8), verbose=True, print_per_layer_stat=True)
        print(macs, params)
        params = float(params[:-3])
        macs = float(macs[:-4])
    
        print(macs * 8, params) # 8个图像的 FLOPs, 这里的结果 和 其他方法应该一致
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62

    7. 参考

    超实用的7种 pytorch 网络可视化方法,进来收藏一波
    使用pytorchviz和Netron可视化pytorch网络结构

    https://cloud.tencent.com/developer/article/1842049

  • 相关阅读:
    86.(cesium篇)cesium叠加面接收阴影效果(gltf模型)
    c语言训练7
    Linux命令记载
    torch.cuda.OutOfMemoryError: CUDA out of memory.
    elementui 菜单选中优化
    找实习之从0开始的后端学习日记【9.19】
    巧用clang 的sanitize解决realloc,malloc,calloc失败
    数据结构与算法(Python)
    SpringBoot——原理(起步依赖+自动配置(概述和案例))
    物联网之ESP32与微信小程序实现指示灯、转向灯
  • 原文地址:https://blog.csdn.net/tywwwww/article/details/126782766