• python-pytorch 如何使用python库Netron查看模型结构(以pytorch官网模型为例)0.9.1


    参照模型

    以pytorch官网的tutorial为观察对象,链接是https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html

    模型代码如下

    import torch.nn as nn
    import torch.nn.functional as F
    
    class RNN(nn.Module):
        def __init__(self, input_size, hidden_size, output_size):
            super(RNN, self).__init__()
    
            self.hidden_size = hidden_size
    
            self.i2h = nn.Linear(input_size, hidden_size)
            self.h2h = nn.Linear(hidden_size, hidden_size)
            self.h2o = nn.Linear(hidden_size, output_size)
            self.softmax = nn.LogSoftmax(dim=1)
    
        def forward(self, input, hidden):
            hidden = F.tanh(self.i2h(input) + self.h2h(hidden))
            output = self.h2o(hidden)
            output = self.softmax(output)
            return output, hidden
    
        def initHidden(self):
            return torch.zeros(1, self.hidden_size)
    
    n_hidden = 128
    rnn = RNN(n_letters, n_hidden, n_categories)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    安装Netron

    pip install netron即可

    其他安装方式参考链接
    https://blog.csdn.net/m0_49963403/article/details/136242313

    写netron代码

    随便找一个地方打个点,如sample方法中

    import netron
    max_length = 20
    
    # Sample from a category and starting letter
    def sample(category, start_letter='A'):
        with torch.no_grad():  # no need to track history in sampling
            category_tensor = categoryTensor(category)
            input = inputTensor(start_letter)
            hidden = rnn.initHidden()
    
            output_name = start_letter
    
            for i in range(max_length):
    #             print("category_tensor",category_tensor.size())
    #             print("input[0]",input[0].size())
    #             print("hidden",hidden.size())
                
                output, hidden = rnn(category_tensor, input[0], hidden)
                torch.onnx.export(rnn,(category_tensor, input[0], hidden) , f='AlexNet1.onnx')   #导出 .onnx 文件
                netron.start('AlexNet1.onnx') #展示结构图
            
                break
    #             print("output",output.size())
    #             print("hidden",hidden.size())
    #             print("====================")
            
                topv, topi = output.topk(1)
                topi = topi[0][0]
                if topi == n_letters - 1:
                    break
                else:
                    letter = all_letters[topi]
                    output_name += letter
                input = inputTensor(letter)
    
            return output_name
    
    # Get multiple samples from one category and multiple starting letters
    def samples(category, start_letters='ABC'):
        for start_letter in start_letters:
            print(sample(category, start_letter))
            break
    
    samples('Russian', 'RUS')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44

    运行查看结果

    结果是在浏览器中,运行成功后会显示:
    Serving ‘AlexNet.onnx’ at http://localhost:8080

    打开这个网页就可以看见模型结构,如下图

    在这里插入图片描述

    需要关注的地方

    如果模型是一个参数的情况下,如下使用就可以了

    import torch
    from torchvision.models import AlexNet
    import netron
    model = AlexNet()
    input = torch.ones((1,3,224,224))
    torch.onnx.export(model, input, f='AlexNet.onnx')
    netron.start('AlexNet.onnx')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    如果模型有多个参数的情况下,则需要如下用括号括起来,如本文中的例子

    torch.onnx.export(rnn,(category_tensor, input[0], hidden) , f='AlexNet1.onnx')   #导出 .onnx 文件
    netron.start('AlexNet1.onnx') #展示结构图
    
    • 1
    • 2
  • 相关阅读:
    DNS部署与安全
    [运维|数据库] PostgreSQL数据库对MySQL的 READS SQL DATA 修饰符处理
    tutorial/detailed_workflow.ipynb 量化金融Qlib库
    听,引擎的声音「GitHub 热点速览 v.22.33」
    Elasticsearch:使用向量化和 FFI/madvise 加速 Lucene
    (c语言)二维数组求最大值
    富格林:安全落实防备诱导欺诈建议
    8.2_[Java 方法]-深入带参 数组/对象 作为参数的方法以及 值传递/引用传递 的区别
    python运营商身份证二要素查验接口、身份证实名认证接口
    Linux网络编程- IO多路复用
  • 原文地址:https://blog.csdn.net/m0_60688978/article/details/138190557