• 【TorchScript】PyTorch模型转换为C++支持的模型


    任务简介:

    使用PyTorch训练的模型只能在Python环境中使用,在自动驾驶场景中,模型推理过程通常是在硬件设备上进行。TorchScript可以将PyTorch训练的模型转换为C++环境支持的模型,推理速度比Python环境更快。本文对整体转换流程做一个简单的记录,后续需要补充TorchScript的支持的各种语法规则以及注意点。


    TorchScript

    TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法。任何TorchScript程序都可以从Python进程中保存,并加载到没有Python依赖的进程中。

    1. 两种TorchScript模型创建方式

    TorchScript模型生成有torch.jit.trace和torch.jit.script两种方法。

    1.1 torch.jit.trace

    传入Module和符合的示例输入。它会调用Moduel并将操作记录下来,当Module运行时记录下操作,然后创建torch.jit.ScriptModule的实例。对于有控制流的模型,直接使用torch.jit.trace()并不能跟踪到控制流,因为它只是对操作进行了记录,对于没有运行到的操作并不会记录,trace方式生成模型的示例如下:

    class MyDecisionGate(torch.nn.Module):
        def forward(self, x: Tensor) -> Tensor:
            if x.sum() > 0:
                return x
            else:
                return -x
    
    class MyCell(torch.nn.Module):
        def __init__(self, dg):
            super(MyCell, self).__init__()
            self.dg = dg
            self.linear = torch.nn.Linear(4, 4)
    
        def forward(self, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]:
            new_h = torch.tanh(self.dg(self.linear(x)) + h)
            return new_h, new_h
    
    my_cell = MyCell(MyDecisionGate())
    x, h = torch.rand(3, 4), torch.rand(3, 4)
    traced_cell = torch.jit.trace(my_cell, (x, h))  # trace方式
    
    print(traced_cell.dg.code)
    print(traced_cell.code)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23

    输出:

    def forward(self,
        argument_1: Tensor) -> None:
      return None
    
    def forward(self,
        input: Tensor,
        h: Tensor) -> Tuple[Tensor, Tensor]:
      _0 = self.dg
      _1 = (self.linear).forward(input, )
      _2 = (_0).forward(_1, )
      _3 = torch.tanh(torch.add(_1, h, alpha=1))
      return (_3, _3)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    可以看到.code的输出,if-else的分支没有了,控制流会被擦除。

    1.2 torch.jit.script

    前面提到的问题,可以使用script compiler来解决,可以直接分析Python源代码来把它转化为TrochScript。如下:

    scripted_gate = torch.jit.script(MyDecisionGate())  # script方式
    my_cell = MyCell(scripted_gate)
    scripted_cell = torch.jit.script(my_cell)  # script方式
    
    print(scripted_gate.code)
    print(scripted_cell.code)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    输出:

    def forward(self,
        x: Tensor) -> Tensor:
      _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
      if _0:
        _1 = x
      else:
        _1 = torch.neg(x)
      return _1
    
    def forward(self,
        x: Tensor,
        h: Tensor) -> Tuple[Tensor, Tensor]:
      _0 = (self.dg).forward((self.linear).forward(x, ), )
      new_h = torch.tanh(torch.add(_0, h, alpha=1))
      return (new_h, new_h)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    可以看到控制流保存下来了。

    1.2 DenseTNT的TorchScript模型生成

    torch.jit.script()会转换传入Module的所有代码,在实际转换模型的过程中会增加修改代码的工作量,因此通常将torch.jit.trace()和torch.jit.script()进行混合使用,比较灵活。

    在需要使用控制流,如不定长的for循环、if-else分支时,在该函数上方输入@torch.jit.script 即可,如:

    @torch.jit.script
    def get_goal_2D(topk_lane_vector: Tensor, topk_points_mask: Tensor) -> Tensor:
        points = torch.zeros([1,2],device=topk_lane_vector.device)
        visit: Dict[int,bool]= {}
        for index_lane, lane_vector in enumerate(topk_lane_vector):
            for i, point in enumerate(lane_vector):
                if topk_points_mask[index_lane][i]:
                    hash: int = int(torch.round((point[0] + 500) * 100) * 1000000 + torch.round((point[1] + 500) * 100))
                    if hash not in visit:
                        visit[hash] = True
                        points = torch.cat([points,point.unsqueeze(0)],dim=0)
            point_num, divide_num = _get_subdivide_num(lane_vector, topk_points_mask[index_lane]) 
            if divide_num > 1:
                subdivide_points = _get_subdivide_points(lane_vector, point_num, divide_num)
                points = torch.cat([points,subdivide_points],dim=0)
        return points[1:]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    再使用torch.jit.trace()将实例化的model和输入传入,即可生成TorchScript模型。

    示例代码:

    model.eval()
    with torch.no_grad():
        traced_script_model = torch.jit.trace(model, script_inputs, strict=False)
        traced_script_model.save("models.densetnt.1/model_save/model.16_script.bin")
    print(traced_script_model.code)
    print('Finish converting model!!!')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
  • 相关阅读:
    统计学习方法-感知机
    弘辽科技:超级推荐爆款拉新怎么设置?爆款拉新怎么玩?
    Redis订阅发布
    源码安装LAMT架构
    【Java+SSM】校园外卖配送系统(外卖点餐系统、在线点餐系统)
    MySQL之分库分表(二)实践
    window.addEventListener相关参数介绍说明
    【虚拟仿真】Unity3D中实现3DUI,并且实现Button、InputField、Toggle等事件绑定
    Vue中的数据绑定
    Hexo博客使用aplayer音乐播放插件
  • 原文地址:https://blog.csdn.net/weixin_40633696/article/details/125557396