• pytorch 模型部署之Libtorch


    Python端生成pt模型文件

    net.load(model_path)
    net.eval()
    net.to("cuda")
    
    example_input = torch.rand(1, 3, 240, 320).to("cuda")
    traced_model = torch.jit.trace(net, example_input)
    traced_model.save("model.pt")
    
    
    output = traced_model(example_input)
    # 输出查看是否与c++输出一致。
    print(len(output))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    C++ 端进行调用

    c++环境配置
    libtorch常用API

    #include 
    #include 
    
    #include 
    
    int main() {
    	std::cout <<"cuda::is_available():" << torch::cuda::is_available() << std::endl;
        torch::Tensor tensor = torch::rand({2, 3}).to(at::kCUDA);
        std::cout << tensor << std::endl;
     
           torch::jit::script::Module module;
         
            module = torch::jit::load("/home/yang/Documents/demo/opencv/model.pt");
        
    
           // 创建一个示例输入
           std::vector<torch::jit::IValue> inputs;
           inputs.push_back(torch::rand({1, 3, 240, 320}).to(at::kCUDA));
    
           // 运行模型
          // at::Tensor output = module.forward(inputs).toTensor();
            //auto output = module.forward(inputs).toTensorList();
            auto out = module.forward(inputs);
     
            auto tpl = out.toTuple();
    
            auto out_ct_hm = tpl->elements()[0].toTensor();
            out_ct_hm.print();
            auto out_wh = tpl->elements()[1].toTensor();
            out_wh.print();
    
    
           // 打印输出
           //std::cout << output << "\n";
    
    }
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38

    可能出错的问题

    1. terminate called after throwing an instance of ‘c10::Error’
      what(): open file failed, file path: model.pt (FileAdapter at …/…/caffe2/serialize/file_adapter.cc:11)。 模型路径有问题,使用绝对路径解决。
    2. ‘c10::Error’ what(): isTensor() INTERNAL ASSERT FAILED。
      很明显,模型的输出应该不是一个 Tensor,可能是一个列表或者元组什么的
  • 相关阅读:
    网络安全大厂面试题汇总
    Reactor反应器模式
    9. JVM-方法区
    make编译出错Relocations in generic ELF (EM: 62)
    Java--字节内存流--ByteArrayInputStream与ByteArrayOutputStream
    域名系统DNS
    录制线上课程,有哪些形式,到底使用什么软件好?
    音视频开发需要你懂得的 H264 编码原理
    是使用local_setup.bash 还是 setup.bash
    Unix后记&寻找Shen Lin
  • 原文地址:https://blog.csdn.net/u011489887/article/details/133798294