• torch.jit.trace与torch.jit.script


    Script mode通过torch.jit.trace或者torch.jit.script来调用。这两个函数都是将python代码转换为TorchScript的两种不同的方法。

    torch.jit.trace将一个特定的输入(通常是一个张量,需要我们提供一个input)传递给一个PyTorch模型torch.jit.trace会跟踪此input在model中的计算过程,然后将其转换为Torch脚本。这个方法适用于那些在静态图中可以完全定义的模型,例如具有固定输入大小的神经网络。通常用于转换预训练模型。

    torch.jit.script直接将Python函数(或者一个Python模块)通过python语法规则和编译转换为Torch脚本。torch.jit.script更适用于动态图模型,这些模型的结构和输入可以在运行时发生变化。例如,对于RNN或者一些具有可变序列长度的模型,使用torch.jit.script会更为方便。

    在通常情况下,更应该倾向于使用torch.jit.trace而不是torch.jit.script

    在模型部署方面,onnx被大量使用。而导出onnx的过程,也是model进行torch.jit.trace的过程,因此这里我们把torch的trace做稍微详细一点的介绍。

    为了能够把模型编写的更能够被jit trace,需要对代码做一些妥协,例如:

    1.如果model中有DataParallel的子模块,或者model中有将tensors转换为numpy arrays,或者调用了opencv的函数等,这种情况下,model不是一个正确的在单个设备上、正确连接的graph,这种情况下,不管是使用torch.jit.script还是torch.jit.trace都不能trace出正确的TorchScript来。

    2.model的输入输出应该是Union[Tensor, Tuple[Tensor], Dict[str, Tensor]]的类型,而且在dict中的值,应该是同样的类型。但是对于model中间子模块的输入输出,可以是任意类型,例如dicts of Any, classes, kwargs以及python支持的都可以。对于model输入输出类型的限制是比较容易满足的,在Detectron2中,有类似的例子:

    outputs = model(inputs)   # inputs和outputs是python的类型, 例如dictsor classes
    # torch.jit.trace(model, inputs)  # 失败!trace只支持Union[Tensor,Tuple[Tensor], Dict[str, Tensor]]类型
    adapter = TracingAdapter(model, inputs)  # 使用Adapter,将modelinputs包装为trace支持的类型
    traced = torch.jit.trace(adapter, adapter.flattened_inputs)  # 现在以trace成功
    
    # Traced model的输出只能是tuple tensors类型:
    flattened_outputs = traced(*adapter.flattened_inputs)
    # 再通过adapter转换为想要的输出类型
    new_outputs = adapter.outputs_schema(flattened_outputs)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    3.一些数值类型的问题。比如下面的代码片段

    import torch
    a=torch.tensor([1,2])
    print(type(a.size(0)))
    print(type(a.size()[0]))
    print(type(a.shape[0]))
    
    • 1
    • 2
    • 3
    • 4
    • 5

    在eager mode下,这几个返回值的类型都是int型。上面代码的输出为

    
    
    
    
    • 1
    • 2
    • 3

    但是在trace mode下,这几个表达式的返回值类型都是Tensor类型。因此,有些表达式使用不当,如果在trace过程中,一些shape表达式的返回值类型是int型,那么可能造成这块代码没有被trace。在代码中,可以通过使用torch.jit.is_tracing来检查这块代码在trace mode下有没有被执行。

    4.由于动态的control flow,造成模型没有被完整的trace。看下面的例子:

    import torch
    
    def f(x):
        return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
    
    m = torch.jit.trace(f, torch.tensor(3))
    print(m.code)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    输出为

    def f(x: Tensor) -> Tensor:
      return torch.sqrt(x)
    
    • 1
    • 2

    可以看到trace后的model只保留了一条分支。因此由于输入造成的dynamic的control flow,trace后容易出现错误。

    这种情况下,我们可以使用torch.jit.script来进行TorchScript的转换。

    import torch
    
    def f(x):
        return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
    
    m = torch.jit.script(f)
    print(m.code)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    输出为

    def f(x: Tensor) -> Tensor:
      if bool(torch.gt(torch.sum(x), 0)):
        _0 = torch.sqrt(x)
      else:
        _0 = torch.square(x)
      return _0
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    在大多数情况下,我们应该使用torch.jit.trace,但是像上面的这种dynamic control flow的情况,我们可以混合使用torch.jit.tracetorch.jit.script,在本文后面会进行阐述。

    另外在一些blog中,对于dynamic control flow的定义是有错误的,例如if x[0] == 4: x += 1是dynamic control flow,但是

    model: nn.Sequential = ...
    for m in model:
      x = m(x)
    
    • 1
    • 2
    • 3

    以及

    class A(nn.Module):
      backbone: nn.Module
      head: Optiona[nn.Module]
      def forward(self, x):
        x = self.backbone(x)
        if self.head is not None:
            x = self.head(x)
        return x
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    都不是dynamic control flow。dynamic control flow是由于对输入条件的判断造成的不同分支的执行。

    5.trace过程中,将变量trace成了常量。看下面一个例子

    import torch
    a, b = torch.rand(1), torch.rand(2)
    
    def f1(x): return torch.arange(x.shape[0])
    def f2(x): return torch.arange(len(x))
    
    print(torch.jit.trace(f1, a)(b))
    # 输出: tensor([0, 1])
    # 可以看到trace后的model是没问题的,这里使用变量a作为torch.jit.trace的example input,然后将转换后的TorchScript用变量b作为输入,正常情况下,b的shape是2维的,因此返回值是tensor([0,1])是正确的
    
    print(torch.jit.trace(f2, a)(b))
    # 输出:
    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
    # tensor([0])
    # 可以看到这个输出结果是错误的,b的维度是2维,输出应该是tensor([0,1]),这里torch.jit.trace也提示了,使用len可能会造成不正确的trace。
    
    # 我们打印一下两者的区别
    print(torch.jit.trace(f1, a).code, '\n',torch.jit.trace(f2, a).code)
    # 输出
    # def f1(x: Tensor) -> Tensor:
    #   _0 = ops.prim.NumToTensor(torch.size(x, 0))
    #   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
    #   return _1
     
    #  def f2(x: Tensor) -> Tensor:
    #   _0 = torch.arange(1, dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
    #   return _0
    
    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
    
    # 从trace的code中可以看出,使用x.shape这种方式,在trace后的code里面,是有shape的一个变量值存在的,但是直接使用len这种方式,trace后的code里面,就直接是1
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31

    我们导出onnx的过程,也是进行torch.jit.trace的过程,在导出onnx的时候,有时候也会遇到

    TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
    
    • 1

    这样的提示信息,这时候要检查一下代码中是不是有可能trace过程中,变量会被当做常量的情况,有可能会导致导出的onnx精度异常。

    除了len会导致trace错误,其他几个也会导致trace出现问题:

    • .item()会在trace过程中将tensors转为int/float

    • 任何将torch类型转为numpy/python类型的代码

    • 一些有问题的算子,例如advanced indexing

    1. torch.jit.trace不会对传入的device生效
    import torch
    def f(x):
        return torch.arange(x.shape[0], device=x.device)
    m = torch.jit.trace(f, torch.tensor([3]))
    print(m.code)
    # 输出
    # def f(x: Tensor) -> Tensor:
    #   _0 = ops.prim.NumToTensor(torch.size(x, 0))
    #   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
    #   return _1
    print(m(torch.tensor([3]).cuda()).device)
    # 输出:device(type='cpu')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    trace不会对传入的cuda device生效。

    为了保证trace的正确,我们可以通过一下的一些方法来尽量保证trace后的模型不会出错:

    1.注意warnings信息。类似这样的TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results. TraceWarnings信息,它会造成模型的结果有可能不正确,但是它只是个warning等级。

    2.做单元测试。需要验证一下eager mode的模型输出与trace后的模型输出是否一致。

    assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
    
    • 1

    3.避免一些特殊的情况。例如下面的代码

    if x.numel() > 0:
      output = self.layers(x)
    else:
      output = torch.zeros((0, C, H, W))  # 会创建一个空的输出
    
    • 1
    • 2
    • 3
    • 4

    避免一些特殊情况比如空的输入输出之类的。

    4.注意shape的使用。前面提到,tensor.size()在trace过程中会返回Tensor类型的数据,Tensor类型会在计算过程中被添加到计算图中,应该避免将Tensor类型的shape转为了常量。主要注意以下两点:

    • 使用torch.size(0)来代替len(tensor),因为torch.size(0)返回的是Tensorlen(tensor)返回的是int。对于自定义类,实现一个.size方法或者使用.__len__()方法来代替len(),例如这个例子
    • 不要使用int()或者torch.as_tensor来转换size的类型,因为这些操作也会被视为常量。

    5.混合tracing和scripting方法。可以使用torch.jit.script来转换一些torch.jit.trace不能搞定的小的代码片段,混合使用tracing和scripting,基本可以解决所有的问题。

    混合使用tracing和scripting

    tracing和scripting都有他们的问题,混合使用可以解决大部分问题。但是为了尽可能减小对于代码质量的负面影响,大部分情况下,都应该使用torch.jit.trace,必要时才使用torch.jit.script

    1.在使用torch.jit.trace时,使用@script_if_tracing装饰器可以让被装饰的函数使用scripting方式进行编译。

    def forward(self, ...):
      # ... some forward logic
      @torch.jit.script_if_tracing
      def _inner_impl(x, y, z, flag: bool):
          # use control flow, etc.
          return ...
      output = _inner_impl(x, y, z, flag)
      # ... other forward logic
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    但是使用@script_if_tracing时,需要保证函数中没有pytorch的modules,如果有的话,需要做一些修改,例如下面的:

    # 因为代码中有self.layers(),是一个pytorch的module,因此不能使用@script_if_tracing
    if x.numel() > 0:
      x = preprocess(x)
      output = self.layers(x)
    else:
      # Create empty outputs
      output = torch.zeros(...)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    这里需要做如下修改:

    # 需要将self.layers移出if判断,这时候可以用@script_if_tracing
    if x.numel() > 0:
      x = preprocess(x)
    else:
      # Create empty inputs
      x = torch.zeros(...)
    # 需要将self.layers()修改为支持empty的输入,或者将原先的条件判断加入到self.layers中
    output = self.layers(x)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    2.合并多次tracing的结果

    使用torch.jit.script生成的模型相比使用torch.jit.trace有两个好处:

    • 可以使用条件控制流,例如模型中使用一个bool值来控制forward的flow,在traced modules里面是不支持的
    • 使用traced module,只能有一个forward()函数,但是使用scripted module,可以有多个前向计算的函数
    class Detector(nn.Module):
      do_keypoint: bool
    
      def forward(self, img):
          box = self.predict_boxes(img)
          if self.do_keypoint:
              kpts = self.predict_keypoint(img, box)
    
      @torch.jit.export
      def predict_boxes(self, img): pass
    
      @torch.jit.export
      def predict_keypoint(self, img, box): pass
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    对于这种有bool值的控制流,除了使用script,还可以多次进行trace,然后将结果合并。

    det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
    det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
    
    • 1
    • 2

    然后将他们的weight复制一遍,并合并两次trace的结果

    det2.submodule.weight = det1.submodule.weight
    class Wrapper(nn.ModuleList):
      def forward(self, img, do_keypoint: bool):
        if do_keypoint:
            return self[0](img)
        else:
            return self[1](img)
    exported = torch.jit.script(Wrapper([det1, det2]))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    trace和script的性能

    tracing总是会比scripting生成一样或者更简单的计算图,因此性能会更好一些。因为scripting会完整的表达python代码的逻辑,甚至一些不必要的代码也会如实表达。例如下面的例子:

    class A(nn.Module):
      def forward(self, x1, x2, x3):
        z = [0, 1, 2]
        xs = [x1, x2, x3]
        for k in z: x1 += xs[k]
        return x1
    model = A()
    print(torch.jit.script(model).code)
    # def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
    #   z = [0, 1, 2]
    #   xs = [x1, x2, x3]
    #   x10 = x1
    #   for _0 in range(torch.len(z)):
    #     k = z[_0]
    #     x10 = torch.add_(x10, xs[k])
    #   return x10
    print(torch.jit.trace(model, [torch.tensor(1)] * 3).code)
    # def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
    #   x10 = torch.add_(x1, x1)
    #   x11 = torch.add_(x10, x2)
    #   return torch.add_(x11, x3)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    总结

    tracing具有明显的局限性:这篇文章的大部分篇幅都在谈论tracing的局限性以及如何解决这些问题。实际上,这正是tracing的优势所在:它有明确的局限性(和解决方案),因此你可以推理它是否有效。

    相反,scripting更像是一个黑盒子:在尝试之前,没有人知道它是否有效。文章中没有提到如何修复scripting的任何诀窍:有很多诀窍,但不值得你花时间去探究和修复一个黑盒子。

    tracing和scripting都会影响代码的编写方式,但tracing因为我们明确它的要求,对我们原始的代码造成的一些修改也不会太严重:

    • 它限制了输入/输出格式,但仅限于最外层的模块。(如上所述,这个问题可以通过一个wrapper解决)。
    • 它需要修改一些代码才能通用(例如在tracing时添加一些scripting),但这些修改只涉及受影响模块的内部实现,而不是它们的接口。
  • 相关阅读:
    关于kafka常见名词解释,你了解多少?
    关于最近项目数字前端FLOW的一些总结
    el-table中保留分页选中
    【linux】Vmware安装CentOS 7.9
    【前端技巧】css篇
    Java设计模式——工厂模式讲解以及在JDK中源码分析
    okhttp添加公共参数
    聊聊jedis连接池参数配置
    数风流人物还看今朝|前后端分离微服务项目常用中间件以及指令
    全志R128芯片应用开发案例——驱动 WS2812 流水灯
  • 原文地址:https://blog.csdn.net/bornfree5511/article/details/133980467