任务简介:
使用PyTorch训练的模型只能在Python环境中使用,在自动驾驶场景中,模型推理过程通常是在硬件设备上进行。TorchScript可以将PyTorch训练的模型转换为C++环境支持的模型,推理速度比Python环境更快。本文对整体转换流程做一个简单的记录,后续需要补充TorchScript的支持的各种语法规则以及注意点。
TorchScript:
TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法。任何TorchScript程序都可以从Python进程中保存,并加载到没有Python依赖的进程中。
TorchScript模型生成有torch.jit.trace和torch.jit.script两种方法。
传入Module和符合的示例输入。它会调用Moduel并将操作记录下来,当Module运行时记录下操作,然后创建torch.jit.ScriptModule的实例。对于有控制流的模型,直接使用torch.jit.trace()并不能跟踪到控制流,因为它只是对操作进行了记录,对于没有运行到的操作并不会记录,trace方式生成模型的示例如下:
class MyDecisionGate(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
if x.sum() > 0:
return x
else:
return -x
class MyCell(torch.nn.Module):
def __init__(self, dg):
super(MyCell, self).__init__()
self.dg = dg
self.linear = torch.nn.Linear(4, 4)
def forward(self, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]:
new_h = torch.tanh(self.dg(self.linear(x)) + h)
return new_h, new_h
my_cell = MyCell(MyDecisionGate())
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h)) # trace方式
print(traced_cell.dg.code)
print(traced_cell.code)
输出:
def forward(self,
argument_1: Tensor) -> None:
return None
def forward(self,
input: Tensor,
h: Tensor) -> Tuple[Tensor, Tensor]:
_0 = self.dg
_1 = (self.linear).forward(input, )
_2 = (_0).forward(_1, )
_3 = torch.tanh(torch.add(_1, h, alpha=1))
return (_3, _3)
可以看到.code的输出,if-else的分支没有了,控制流会被擦除。
前面提到的问题,可以使用script compiler来解决,可以直接分析Python源代码来把它转化为TrochScript。如下:
scripted_gate = torch.jit.script(MyDecisionGate()) # script方式
my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell) # script方式
print(scripted_gate.code)
print(scripted_cell.code)
输出:
def forward(self,
x: Tensor) -> Tensor:
_0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
if _0:
_1 = x
else:
_1 = torch.neg(x)
return _1
def forward(self,
x: Tensor,
h: Tensor) -> Tuple[Tensor, Tensor]:
_0 = (self.dg).forward((self.linear).forward(x, ), )
new_h = torch.tanh(torch.add(_0, h, alpha=1))
return (new_h, new_h)
可以看到控制流保存下来了。
torch.jit.script()会转换传入Module的所有代码,在实际转换模型的过程中会增加修改代码的工作量,因此通常将torch.jit.trace()和torch.jit.script()进行混合使用,比较灵活。
在需要使用控制流,如不定长的for循环、if-else分支时,在该函数上方输入@torch.jit.script
即可,如:
@torch.jit.script
def get_goal_2D(topk_lane_vector: Tensor, topk_points_mask: Tensor) -> Tensor:
points = torch.zeros([1,2],device=topk_lane_vector.device)
visit: Dict[int,bool]= {}
for index_lane, lane_vector in enumerate(topk_lane_vector):
for i, point in enumerate(lane_vector):
if topk_points_mask[index_lane][i]:
hash: int = int(torch.round((point[0] + 500) * 100) * 1000000 + torch.round((point[1] + 500) * 100))
if hash not in visit:
visit[hash] = True
points = torch.cat([points,point.unsqueeze(0)],dim=0)
point_num, divide_num = _get_subdivide_num(lane_vector, topk_points_mask[index_lane])
if divide_num > 1:
subdivide_points = _get_subdivide_points(lane_vector, point_num, divide_num)
points = torch.cat([points,subdivide_points],dim=0)
return points[1:]
再使用torch.jit.trace()将实例化的model和输入传入,即可生成TorchScript模型。
示例代码:
model.eval()
with torch.no_grad():
traced_script_model = torch.jit.trace(model, script_inputs, strict=False)
traced_script_model.save("models.densetnt.1/model_save/model.16_script.bin")
print(traced_script_model.code)
print('Finish converting model!!!')