• 【Pytorch】(十四)C++ 加载TorchScript 模型



    【Pytorch】(十三)PyTorch模型部署: TorchScript

    (十四)C++ 加载TorchScript 模型

    以下内容将介绍如何在C++环境下加载和运行TorchScript 模型。

    Step 1: 将PyTorch模型转换为TorchScript

    将resnet18模型的一个实例以及示例输入传递给torch.jit.trace函数, 将模型转换为TorchScript:

    import torch
    import torchvision
    
    # An instance of your model.
    model = torchvision.models.resnet18()
    
    # An example input you would normally provide to your model's forward() method.
    example = torch.rand(1, 3, 224, 224)
    
    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
    traced_script_module = torch.jit.trace(model, example)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    Step 2: 将TorchScript序列化为文件

    序列化TorchScript并保存:

    traced_script_module.save("traced_resnet_model.pt")
    
    • 1

    这将在工作目录中生成traced_resnet_model.pt文件。

    Step 3: C++程序中加载TorchScript模型

    要在C++中加载序列化的TorchScript模型,必须依赖于PyTorch C++API(也称为LibTorch)。最新的稳定版本的LibTorch可以从PyTorch官网下载。

    以下命令可以下载CPU版本的:

    wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
    unzip libtorch-shared-with-deps-latest.zip
    
    • 1
    • 2

    下载并解压缩后,可以得到一个具有以下目录结构的文件夹:

    libtorch/
      bin/
      include/
      lib/
      share/
    
    • 1
    • 2
    • 3
    • 4
    • 5

    lib/文件夹包含必须链接的共享库,

    include/文件夹包含程序需要包含的头文件,

    share/文件夹包含必要的CMake配置。

    下面将使用CMake和LibTorch构建一个C++应用程序,该应用程序加载并执行一个序TorchScript模型。

    example-app.cpp

    #include  
    #include 
    #include 
    
    int main(int argc, const char* argv[]) {
      if (argc != 2) {
        std::cerr << "usage: example-app \n";
        return -1;
      }
    
      torch::jit::script::Module module;
      try {
      	// 对TorchScript进行反序列化,该函数以文件路径作为输入
        module = torch::jit::load(argv[1]);
      }
      catch (const c10::Error& e) {
        std::cerr << "error loading the model\n";
        return -1;
      }
    
      std::cout << "ok\n";
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    CMakeLists.txt

    cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
    project(custom_ops)
    
    find_package(Torch REQUIRED)
    
    add_executable(example-app example-app.cpp)
    target_link_libraries(example-app "${TORCH_LIBRARIES}")
    set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    假设我们的示例目录如下所示:

    example-app/
      CMakeLists.txt
      example-app.cpp
    
    • 1
    • 2
    • 3

    我们可以运行以下命令,构建应用程序:

    mkdir build
    cd build
    cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
    cmake --build . --config Release
    
    • 1
    • 2
    • 3
    • 4

    注意:GCC版本需要不小于9,不然编译会出错。其中/path/to/libtorch应该是解压缩的libtorch的完整路径。

    如果一切顺利,打印的信息会是这样的:

    root@4b5a67132e81:/example-app# mkdir build
    root@4b5a67132e81:/example-app# cd build
    root@4b5a67132e81:/example-app/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
    -- The C compiler identification is GNU 5.4.0
    -- The CXX compiler identification is GNU 5.4.0
    -- Check for working C compiler: /usr/bin/cc
    -- Check for working C compiler: /usr/bin/cc -- works
    -- Detecting C compiler ABI info
    -- Detecting C compiler ABI info - done
    -- Detecting C compile features
    -- Detecting C compile features - done
    -- Check for working CXX compiler: /usr/bin/c++
    -- Check for working CXX compiler: /usr/bin/c++ -- works
    -- Detecting CXX compiler ABI info
    -- Detecting CXX compiler ABI info - done
    -- Detecting CXX compile features
    -- Detecting CXX compile features - done
    -- Looking for pthread.h
    -- Looking for pthread.h - found
    -- Looking for pthread_create
    -- Looking for pthread_create - not found
    -- Looking for pthread_create in pthreads
    -- Looking for pthread_create in pthreads - not found
    -- Looking for pthread_create in pthread
    -- Looking for pthread_create in pthread - found
    -- Found Threads: TRUE
    -- Configuring done
    -- Generating done
    -- Build files have been written to: /example-app/build
    root@4b5a67132e81:/example-app/build# make
    Scanning dependencies of target example-app
    [ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
    [100%] Linking CXX executable example-app
    [100%] Built target example-app
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    运行程序:

    root@4b5a67132e81:/example-app/build# ./example-app /traced_resnet_model.pt
    ok
    
    • 1
    • 2

    打印ok说明加载成功。

    Step 4: C++程序中运行TorchScript模型

    将Step 1相同的 inputs 输入到C++加载的模型:

    #include  
    #include 
    #include 
    
    int main(int argc, const char* argv[]) {
      if (argc != 2) {
        std::cerr << "usage: example-app \n";
        return -1;
      }
      torch::jit::script::Module module;
      try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load(argv[1]);
      }
      catch (const c10::Error& e) {
        std::cerr << "error loading the model\n";
        return -1;
      }
    
      std::cout << "ok\n";
      // Create a vector of inputs.
      std::vector<torch::jit::IValue> inputs;
      inputs.push_back(torch::ones({1, 3, 224, 224}));
      
      // Execute the model and turn its output into a tensor.
      at::Tensor output = module.forward(inputs).toTensor();
      std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    root@4b5a67132e81:/example-app/build# make
    Scanning dependencies of target example-app
    [ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
    [100%] Linking CXX executable example-app
    [100%] Built target example-app
    root@4b5a67132e81:/example-app/build# ./example-app traced_resnet_model.pt
    -0.2698 -0.0381  0.4023 -0.3010 -0.0448
    [ Variable[CPUFloatType]{1,5} ]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    可以看到,C++环境下模型的输出与Python环境下的相同:

    tensor([-0.2698, -0.0381,  0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
    
    • 1

    更多功能的实现可查阅:

    PyTorch C++API文档:

    PyTorch Python API文档:

    TorchScript Pytorch官方文档

    参考:
    https://pytorch.org/tutorials/advanced/cpp_export.html

  • 相关阅读:
    正则表达式2
    记录一次慢SQL优化:大表关联小表->拆解为单表查询
    Maven POM:掌握项目对象模型的艺术
    【STM32学习】通用定时器的应用实验
    【状态估计】无迹卡尔曼滤波(UKF)应用于FitzHugh-Nagumo神经元动力学研究(Matlab代码实现)
    Java 格式化时间与时间戳与时间间隔
    Chaos Vantage最低配置要求,需要什么显卡
    贪心算法解决雷达站建站问题
    本地连接服务器mysql数据库慢
    08架构管理之架构检查方法
  • 原文地址:https://blog.csdn.net/weixin_44378835/article/details/138161199