• torch.jit.trace与torch.jit.script的区别


    术语

    1. Tochscript:狭义概念导出图形的表示/格式;广义概念为导出模型的方法;
    2. (Torch)Scriptable:可以用torch.jit.script导出模型
    3. Traceable:可以用torch.jit.trace导出模型

    什么时候用torch.jit.trace(结论:首选)

    1. torch.jit.trace一种导出方法;它运行具有某些张量输入的模型,并“跟踪/记录”所有执行到图形中的操作。
    2. 在模型内部的数据类型只有张量,且没有for if while等控制流,选择torch.jit.trace
    3. 支持python的预处理和动态行为;
    4. torch.jit.trace编译function并返回一个可执行文件,该可执行文件将使用即时编译进行优化。
    5. 大项目优先选择torch.jit.trace,特别是是图像检测和分割的算法;

    优点

    1. 不会损害代码质量;
    2. 2.它的主要限制可以通过与torch.jit.script混合来解决

    什么时候用torch.jit.script(结论:必要时)

    1. 定义:一种模型导出方法,其实编译python的模型源码,得到可执行的图;
    2. 在模型内部的数据类型只有张量,且没有for if while等控制流,也可以选择torch.jit.script
    3. 不支持python的预处理和动态行为;
    4. 必须做一下类型标注;
    5. torch.jit.script在编译function或 nn.Module 脚本将检查源代码,使用 TorchScript 编译器将其编译为 TorchScript 代码。

    错误举例

    import torch
    from torch import nn
    
    
    class MyModule(nn.Module):
        def __init__(self, return_b=False):
            super().__init__()
            self.return_b = return_b
    
        def forward(self, x):
            a = x + 2
            if self.return_b:  #属于静态控制
                b = x + 3
                return a, b
            return a
    
    
    model = MyModule(return_b=True)
    
    # Will work  成功
    traced = torch.jit.trace(model, (torch.randn(10, ), ))
    
    # Will fail 失败
    scripted = torch.jit.script(model)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 总结:控制流是静态的,torch.jit.trace将正常工作

    动态控制

    1. if x[0] == 4: x += 1 is a dynamic control flow.
    model: nn.Sequential = ...
    for m in model:  # 动态控制
      x = m(x) 
    
    • 1
    • 2
    • 3

    输入和输出有丰富类型的模型需要格外注意

    outputs = model(inputs)   # inputs/outputs are rich structure
    # torch.jit.trace(model, inputs)  # FAIL! unsupported format
    adapter = TracingAdapter(model, inputs)
    traced = torch.jit.trace(adapter, adapter.flattened_inputs)  # Can now trace the model
    
    # Traced model can only produce flattened outputs (tuple of tensors):
    flattened_outputs = traced(*adapter.flattened_inputs)
    # Adapter knows how to convert it back to the rich structure (new_outputs == outputs):
    new_outputs = adapter.outputs_schema(flattened_outputs)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    QA

      1. JIT要求python的代码要是低级的;详情 因为更多动态高级的python语法,jit不支持.具体哪些支持哪些没支持官方也没有详细的列表; JIT should not force users to write ugly code #48108
      1. 错误示例:动态控制流:对于动态控制流torch.jit.trace只会编译一个分支,在其他分支处理的时候会报错;
    def f(x):
        return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
    m = torch.jit.trace(f, torch.tensor(3))
    print(m.code) # 可以打印出trace的情况
    #--------------------------------------------
    def f(x: Tensor) -> Tensor:
      return torch.sqrt(x)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
      1. 错误示例:将变量视为常量
    import torch
    
    a, b = torch.rand(1), torch.rand(2)
    print(a,b)
    
    def f1(x): return torch.arange(x.shape[0])
    def f2(x): return torch.arange(len(x))
    result = torch.jit.trace(f1, a)(b)
    print(result)
    
    result =torch.jit.trace(f2, a)(b) # TracerWarning
    print(result) #
    
    print(torch.jit.trace(f1, a).code, torch.jit.trace(f2, a).code)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • cuX01R

    • 错误示例:获取设备

    解决错误的方法

      1. 严格消除警告信息,才C++运行的时候会报错
      1. 局部单元测试
      • 单元测试一样要做在导出模型后,这样避免在应用模型的时候(C++运行)出错;
    assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
    
    • 1
      1. 避免非必要的动态控制,例如:
    if x.numel() > 0:
      output = self.layers(x)
    else:
      output = torch.zeros((0, C, H, W))  # Create empty outputs
    
    • 1
    • 2
    • 3
    • 4
  • 相关阅读:
    CPU+GPU掌舵AI大算力时代,中国企业能否从巨头碗中分一杯羹?
    Openssl教程
    26. 【Linux教程】Linux 查看环境变量
    【Redux】Redux 基本使用
    【揭秘】那些你可能没发现的高质量免费学习资源网站
    【Unity】零基础实现塔防游戏中敌人沿固定路径移动的功能
    【机器学习基础】正则化
    怎么练习黑客技术不会犯法?这6个网站也许可以帮到你,收藏就完事了
    【首测】两款OpenCV 人工智能深度相机OAK PoE
    基于webapi的websocket聊天室(番外二)
  • 原文地址:https://blog.csdn.net/weixin_32393347/article/details/125899693