• Implementing an RNN prediction model in PyTorch and running inference on the corresponding ONNX model from C++


    Implementing the RNN model in PyTorch

    Code

    import torch
    import torch.nn as nn
    
    class RNN(nn.Module):
        def __init__(self, seq_len, input_size, hidden_size, output_size, num_layers, device):
            super(RNN, self).__init__()
            self._seq_len = seq_len
            self._input_size = input_size
            self._output_size = output_size
            self._hidden_size = hidden_size
            self._device = device
            self._num_layers = num_layers
    
            # batch_first=True: input/output tensors are (batch, seq, feature)
            self.rnn = nn.RNN(
                input_size=input_size,
                hidden_size=self._hidden_size,
                num_layers=self._num_layers,
                batch_first=True
            )
    
            # The hidden states of all time steps are flattened into one linear layer.
            self.fc = nn.Linear(self._seq_len * self._hidden_size, self._output_size)
    
        def forward(self, x, hidden_prev):
            # x: (batch, seq_len, input_size); hidden_prev: (num_layers, batch, hidden_size)
            out, hidden_prev = self.rnn(x, hidden_prev)
            out = out.contiguous().view(out.shape[0], -1)  # (batch, seq_len * hidden_size)
            out = self.fc(out)                             # (batch, output_size)
            return out, hidden_prev
    
    seq_len = 10
    batch_size = 20
    input_size = 10
    output_size = 10
    hidden_size = 32
    num_layers = 2
    model = RNN(seq_len, input_size, hidden_size, output_size, num_layers, "cpu")
    hidden_prev = torch.zeros(num_layers, batch_size, hidden_size).to("cpu")
    model.eval()
    
    input_names = ["input", "hidden_prev_in"]
    output_names = ["output", "hidden_prev_out"]
    
    # Trace one forward pass with dummy data before exporting.
    x = torch.randn((batch_size, seq_len, input_size))
    y, hidden_prev = model(x, hidden_prev)
    print(x.shape)
    print(hidden_prev.shape)
    print(y.shape)
    print(hidden_prev.shape)
    
    # Mark the batch dimension as dynamic: axis 0 for input/output,
    # axis 1 for the hidden-state tensors (num_layers, batch, hidden_size).
    torch.onnx.export(model, (x, hidden_prev), 'RNN.onnx', verbose=True, input_names=input_names, output_names=output_names,
      dynamic_axes={'input':[0], 'hidden_prev_in':[1], 'output':[0], 'hidden_prev_out':[1]} )
    
    # Load the exported file back (into a separate variable, so `model` still
    # refers to the PyTorch module) and verify it is a well-formed ONNX graph.
    import onnx
    onnx_model = onnx.load("RNN.onnx")
    print("load model done.")
    onnx.checker.check_model(onnx_model)
    print(onnx.helper.printable_graph(onnx_model.graph))
    print("check model done.")
    

    Output

    torch.Size([20, 10, 10])
    torch.Size([2, 20, 32])
    torch.Size([20, 10])
    torch.Size([2, 20, 32])
    /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:2041: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input input
      "No names were found for specified dynamic axes of provided input."
    /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:2041: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input hidden_prev
      "No names were found for specified dynamic axes of provided input."
    /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:2041: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input output
      "No names were found for specified dynamic axes of provided input."
    /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/symbolic_opset9.py:4322: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with RNN_TANH can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model. 
      + "or define the initial states (h0/c0) as inputs of the model. "
    /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/_internal/jit_utils.py:258: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
      _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
    /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:688: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
      graph, params_dict, GLOBALS.export_onnx_opset_version
    /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:1179: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
      graph, params_dict, GLOBALS.export_onnx_opset_version
    Exported graph: graph(%input : Float(*, 10, 10, strides=[100, 10, 1], requires_grad=0, device=cpu),
          %hidden_prev.1 : Float(2, *, 32, strides=[640, 32, 1], requires_grad=1, device=cpu),
          %fc.weight : Float(10, 320, strides=[320, 1], requires_grad=1, device=cpu),
          %fc.bias : Float(10, strides=[1], requires_grad=1, device=cpu),
          %onnx::RNN_58 : Float(1, 32, 10, strides=[320, 10, 1], requires_grad=0, device=cpu),
          %onnx::RNN_59 : Float(1, 32, 32, strides=[1024, 32, 1], requires_grad=0, device=cpu),
          %onnx::RNN_60 : Float(1, 64, strides=[64, 1], requires_grad=0, device=cpu),
          %onnx::RNN_62 : Float(1, 32, 32, strides=[1024, 32, 1], requires_grad=0, device=cpu),
          %onnx::RNN_63 : Float(1, 32, 32, strides=[1024, 32, 1], requires_grad=0, device=cpu),
          %onnx::RNN_64 : Float(1, 64, strides=[64, 1], requires_grad=0, device=cpu)):
      %/rnn/Transpose_output_0 : Float(10, *, 10, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name="/rnn/Transpose"](%input), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %onnx::RNN_13 : Tensor? = prim::Constant(), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/Constant_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/rnn/Constant"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/Constant_1_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/rnn/Constant_1"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/Constant_2_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/rnn/Constant_2"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/Slice_output_0 : Float(1, *, 32, device=cpu) = onnx::Slice[onnx_name="/rnn/Slice"](%hidden_prev.1, %/rnn/Constant_1_output_0, %/rnn/Constant_2_output_0, %/rnn/Constant_output_0), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/RNN_output_0 : Float(10, 1, *, 32, device=cpu), %/rnn/RNN_output_1 : Float(1, *, 32, device=cpu) = onnx::RNN[activations=["Tanh"], hidden_size=32, onnx_name="/rnn/RNN"](%/rnn/Transpose_output_0, %onnx::RNN_58, %onnx::RNN_59, %onnx::RNN_60, %onnx::RNN_13, %/rnn/Slice_output_0), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/Constant_3_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/rnn/Constant_3"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/Squeeze_output_0 : Float(10, *, 32, device=cpu) = onnx::Squeeze[onnx_name="/rnn/Squeeze"](%/rnn/RNN_output_0, %/rnn/Constant_3_output_0), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/Constant_4_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/rnn/Constant_4"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/Constant_5_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/rnn/Constant_5"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/Constant_6_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={2}, onnx_name="/rnn/Constant_6"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/Slice_1_output_0 : Float(1, *, 32, device=cpu) = onnx::Slice[onnx_name="/rnn/Slice_1"](%hidden_prev.1, %/rnn/Constant_5_output_0, %/rnn/Constant_6_output_0, %/rnn/Constant_4_output_0), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/RNN_1_output_0 : Float(10, 1, *, 32, device=cpu), %/rnn/RNN_1_output_1 : Float(1, *, 32, device=cpu) = onnx::RNN[activations=["Tanh"], hidden_size=32, onnx_name="/rnn/RNN_1"](%/rnn/Squeeze_output_0, %onnx::RNN_62, %onnx::RNN_63, %onnx::RNN_64, %onnx::RNN_13, %/rnn/Slice_1_output_0), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/Constant_7_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/rnn/Constant_7"](), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/Squeeze_1_output_0 : Float(10, *, 32, device=cpu) = onnx::Squeeze[onnx_name="/rnn/Squeeze_1"](%/rnn/RNN_1_output_0, %/rnn/Constant_7_output_0), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/rnn/Transpose_1_output_0 : Float(*, 10, 32, strides=[320, 32, 1], requires_grad=1, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name="/rnn/Transpose_1"](%/rnn/Squeeze_1_output_0), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %hidden_prev : Float(2, *, 32, strides=[640, 32, 1], requires_grad=1, device=cpu) = onnx::Concat[axis=0, onnx_name="/rnn/Concat"](%/rnn/RNN_output_1, %/rnn/RNN_1_output_1), scope: __main__.RNN::/torch.nn.modules.rnn.RNN::rnn # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:478:0
      %/Shape_output_0 : Long(3, strides=[1], device=cpu) = onnx::Shape[onnx_name="/Shape"](%/rnn/Transpose_1_output_0), scope: __main__.RNN:: # /zengli/20230320/ao/test/test_onnx_rnn.py:25:0
      %/Constant_output_0 : Long(device=cpu) = onnx::Constant[value={0}, onnx_name="/Constant"](), scope: __main__.RNN:: # /zengli/20230320/ao/test/test_onnx_rnn.py:25:0
      %/Gather_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name="/Gather"](%/Shape_output_0, %/Constant_output_0), scope: __main__.RNN:: # /zengli/20230320/ao/test/test_onnx_rnn.py:25:0
      %onnx::Unsqueeze_50 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()
      %/Unsqueeze_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name="/Unsqueeze"](%/Gather_output_0, %onnx::Unsqueeze_50), scope: __main__.RNN::
      %/Constant_1_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={-1}, onnx_name="/Constant_1"](), scope: __main__.RNN::
      %/Concat_output_0 : Long(2, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name="/Concat"](%/Unsqueeze_output_0, %/Constant_1_output_0), scope: __main__.RNN:: # /zengli/20230320/ao/test/test_onnx_rnn.py:25:0
      %/Reshape_output_0 : Float(*, *, strides=[320, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/Reshape"](%/rnn/Transpose_1_output_0, %/Concat_output_0), scope: __main__.RNN:: # /zengli/20230320/ao/test/test_onnx_rnn.py:25:0
      %output : Float(*, 10, strides=[10, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1, onnx_name="/fc/Gemm"](%/Reshape_output_0, %fc.weight, %fc.bias), scope: __main__.RNN::/torch.nn.modules.linear.Linear::fc # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/linear.py:114:0
      return (%output, %hidden_prev)
    
    load model done.
    graph torch_jit (
      %input[FLOAT, input_dynamic_axes_1x10x10]
      %hidden_prev.1[FLOAT, 2xhidden_prev.1_dim_1x32]
    ) initializers (
      %fc.weight[FLOAT, 10x320]
      %fc.bias[FLOAT, 10]
      %onnx::RNN_58[FLOAT, 1x32x10]
      %onnx::RNN_59[FLOAT, 1x32x32]
      %onnx::RNN_60[FLOAT, 1x64]
      %onnx::RNN_62[FLOAT, 1x32x32]
      %onnx::RNN_63[FLOAT, 1x32x32]
      %onnx::RNN_64[FLOAT, 1x64]
    ) {
      %/rnn/Transpose_output_0 = Transpose[perm = [1, 0, 2]](%input)
      %/rnn/Constant_output_0 = Constant[value = ]()
      %/rnn/Constant_1_output_0 = Constant[value = ]()
      %/rnn/Constant_2_output_0 = Constant[value = ]()
      %/rnn/Slice_output_0 = Slice(%hidden_prev.1, %/rnn/Constant_1_output_0, %/rnn/Constant_2_output_0, %/rnn/Constant_output_0)
      %/rnn/RNN_output_0, %/rnn/RNN_output_1 = RNN[activations = ['Tanh'], hidden_size = 32](%/rnn/Transpose_output_0, %onnx::RNN_58, %onnx::RNN_59, %onnx::RNN_60, %, %/rnn/Slice_output_0)
      %/rnn/Constant_3_output_0 = Constant[value = ]()
      %/rnn/Squeeze_output_0 = Squeeze(%/rnn/RNN_output_0, %/rnn/Constant_3_output_0)
      %/rnn/Constant_4_output_0 = Constant[value = ]()
      %/rnn/Constant_5_output_0 = Constant[value = ]()
      %/rnn/Constant_6_output_0 = Constant[value = ]()
      %/rnn/Slice_1_output_0 = Slice(%hidden_prev.1, %/rnn/Constant_5_output_0, %/rnn/Constant_6_output_0, %/rnn/Constant_4_output_0)
      %/rnn/RNN_1_output_0, %/rnn/RNN_1_output_1 = RNN[activations = ['Tanh'], hidden_size = 32](%/rnn/Squeeze_output_0, %onnx::RNN_62, %onnx::RNN_63, %onnx::RNN_64, %, %/rnn/Slice_1_output_0)
      %/rnn/Constant_7_output_0 = Constant[value = ]()
      %/rnn/Squeeze_1_output_0 = Squeeze(%/rnn/RNN_1_output_0, %/rnn/Constant_7_output_0)
      %/rnn/Transpose_1_output_0 = Transpose[perm = [1, 0, 2]](%/rnn/Squeeze_1_output_0)
      %hidden_prev = Concat[axis = 0](%/rnn/RNN_output_1, %/rnn/RNN_1_output_1)
      %/Shape_output_0 = Shape(%/rnn/Transpose_1_output_0)
      %/Constant_output_0 = Constant[value = ]()
      %/Gather_output_0 = Gather[axis = 0](%/Shape_output_0, %/Constant_output_0)
      %onnx::Unsqueeze_50 = Constant[value = ]()
      %/Unsqueeze_output_0 = Unsqueeze(%/Gather_output_0, %onnx::Unsqueeze_50)
      %/Constant_1_output_0 = Constant[value = ]()
      %/Concat_output_0 = Concat[axis = 0](%/Unsqueeze_output_0, %/Constant_1_output_0)
      %/Reshape_output_0 = Reshape[allowzero = 0](%/rnn/Transpose_1_output_0, %/Concat_output_0)
      %output = Gemm[alpha = 1, beta = 1, transB = 1](%/Reshape_output_0, %fc.weight, %fc.bias)
      return %output, %hidden_prev
    }
    check model done.
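
    In the printed graph the batch dimension appears as a dynamic axis (the * entries and the hidden_prev.1_dim_1 name). The exporter warning about RNN_TANH and batch sizes is addressed here because the initial hidden state is a model input, which is exactly the remedy the warning suggests. As a quick check (not in the original post), the model can be run with a batch size different from the export-time 20; a minimal sketch, assuming the onnxruntime Python package:

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("RNN.onnx")
    b = 3  # any batch size should work because the batch axis is dynamic
    x = np.random.randn(b, 10, 10).astype(np.float32)
    h0 = np.zeros((2, b, 32), dtype=np.float32)
    y, h1 = sess.run(["output", "hidden_prev_out"], {"input": x, "hidden_prev_in": h0})
    print(y.shape, h1.shape)  # expected: (3, 10) (2, 3, 32)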
    

    Calling the ONNX model from C++

    Code

    // Header dependencies for this snippet (ONNX Runtime C++ API, containers, I/O).
    #include <onnxruntime_cxx_api.h>
    #include <cassert>
    #include <cstdio>
    #include <iostream>
    #include <vector>
    
    using std::vector;
    
    vector<float> testOnnxRNN() {
        // Use VERBOSE logging to see in the console whether ops run on the CPU or the GPU.
        //Ort::Env env(ORT_LOGGING_LEVEL_VERBOSE, "test");
        Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "Default");
        Ort::SessionOptions session_options;
    
        session_options.SetIntraOpNumThreads(5); // execute ops on five threads for better speed
        // The second argument is GPU device_id = 0; keep this line commented out to run on the CPU.
        //OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0);
        session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
    
        // ONNX Runtime expects a wide-character path on Windows and a narrow one elsewhere.
        #ifdef _WIN32
            const wchar_t* model_path = L"C:\\Users\\xxx\\Desktop\\RNN.onnx";
            wprintf(L"%s\n", model_path);
        #else
            const char* model_path = "RNN.onnx"; // adjust to wherever RNN.onnx is stored
            printf("%s\n", model_path);
        #endif
    
        Ort::Session session(env, model_path, session_options);
        Ort::AllocatorWithDefaultOptions allocator;
    
        size_t num_input_nodes = session.GetInputCount();   // expected: 2
        size_t num_output_nodes = session.GetOutputCount(); // expected: 2
    
        // These names must match the input_names/output_names passed to torch.onnx.export.
        std::vector<const char*> input_node_names = { "input" , "hidden_prev_in" };
        std::vector<const char*> output_node_names = { "output" , "hidden_prev_out" };
    
        const int input_size = 10;
        const int output_size = 10;
        const int batch_size = 1; // the batch axis was exported as dynamic, so 1 works here
        const int seq_len = 10;
        const int num_layers = 2;
        const int hidden_size = 32;
    
        // input: (batch_size, seq_len, input_size), filled with the pattern i / (size + 1)
        std::vector<int64_t> input_node_dims = { batch_size, seq_len, input_size };
        size_t input_tensor_size = batch_size * seq_len * input_size;
        std::vector<float> input_tensor_values(input_tensor_size);
        for (unsigned int i = 0; i < input_tensor_size; i++) {
            input_tensor_values[i] = (float)i / (input_tensor_size + 1);
        }
        auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
        Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_tensor_values.data(), input_tensor_size, input_node_dims.data(), 3);
        assert(input_tensor.IsTensor());
    
        // hidden_prev_in: (num_layers, batch_size, hidden_size), same fill pattern
        std::vector<int64_t> hidden_prev_in_node_dims = { num_layers, batch_size, hidden_size };
        size_t hidden_prev_in_tensor_size = num_layers * batch_size * hidden_size;
        std::vector<float> hidden_prev_in_tensor_values(hidden_prev_in_tensor_size);
        for (unsigned int i = 0; i < hidden_prev_in_tensor_size; i++) {
            hidden_prev_in_tensor_values[i] = (float)i / (hidden_prev_in_tensor_size + 1);
        }
        auto mask_memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
        Ort::Value hidden_prev_in_tensor = Ort::Value::CreateTensor<float>(mask_memory_info, hidden_prev_in_tensor_values.data(), hidden_prev_in_tensor_size, hidden_prev_in_node_dims.data(), 3);
        assert(hidden_prev_in_tensor.IsTensor());
    
        std::vector<Ort::Value> ort_inputs;
        ort_inputs.push_back(std::move(input_tensor));
        ort_inputs.push_back(std::move(hidden_prev_in_tensor));
    
        vector<float> ret;
        try
        {
            auto output_tensors = session.Run(Ort::RunOptions{ nullptr }, input_node_names.data(), ort_inputs.data(), ort_inputs.size(), output_node_names.data(), 2);
            float* output = output_tensors[0].GetTensorMutableData<float>();
            float* hidden_prev_out = output_tensors[1].GetTensorMutableData<float>();
    
            // Collect and print the prediction for the single batch row.
            for (int i = 0; i < output_size; i++) {
                ret.emplace_back(output[i]);
                std::cout << output[i] << " ";
            }
            std::cout << "\n";
    
            // hidden_prev_out can be fed back as hidden_prev_in on the next call:
            //for (int i = 0; i < num_layers * batch_size * hidden_size; i++) {
            //    std::cout << hidden_prev_out[i] << "\t";
            //}
            //std::cout << "\n";
        }
        catch (const std::exception& e)
        {
            std::cout << e.what() << std::endl;
        }
        return ret;
    }
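
    Because both inputs are filled with the deterministic pattern i / (size + 1), the values printed below can be cross-checked from Python. A minimal sketch (not part of the original post), assuming the onnxruntime Python package and the same RNN.onnx file:

    import numpy as np
    import onnxruntime as ort

    batch_size, seq_len, input_size = 1, 10, 10
    num_layers, hidden_size = 2, 32

    # Same fill pattern as the C++ code: element i gets the value i / (total_size + 1).
    n_x = batch_size * seq_len * input_size
    x = (np.arange(n_x, dtype=np.float32) / (n_x + 1)).reshape(batch_size, seq_len, input_size)
    n_h = num_layers * batch_size * hidden_size
    h0 = (np.arange(n_h, dtype=np.float32) / (n_h + 1)).reshape(num_layers, batch_size, hidden_size)

    sess = ort.InferenceSession("RNN.onnx")
    y, h1 = sess.run(["output", "hidden_prev_out"], {"input": x, "hidden_prev_in": h0})
    print(y[0])  # should reproduce the ten values printed by the C++ program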
    
    

    Output

    C:\Users\xxx\Desktop\RNN.onnx
    0.00296116 0.104443 -0.104239 0.249864 -0.155839 0.019295 0.0458037 -0.0596341 -0.129019 -0.014682
    
  • Original article: https://blog.csdn.net/wydxry/article/details/132909323