import torch
from torch import nn
class MyModule(nn.Module):
    """Toy module whose forward's branching depends on a constructor flag.

    The branch condition (`self.return_b`) is *static* control flow: it is
    fixed once the instance is constructed, so a trace of any given instance
    follows a single, unambiguous path.
    """

    def __init__(self, return_b=False):
        super().__init__()
        # Decides whether forward returns one tensor or a pair of tensors.
        self.return_b = return_b

    def forward(self, x):
        shifted = x + 2
        # Guard clause: the common single-output case returns early.
        if not self.return_b:
            return shifted
        return shifted, x + 3
# Demo: trace vs. script on a module with static (attribute-driven) control flow.
model = MyModule(return_b=True)
# Tracing works: `return_b` is fixed for this instance, so only the taken
# branch is executed and recorded.
traced = torch.jit.trace(model, (torch.randn(10, ), ))
# Scripting fails: the compiler analyzes both `return` statements, whose
# arity/type differ (`(a, b)` vs. `a`) — NOTE(review): exact error message
# depends on the torch version; confirm locally.
scripted = torch.jit.script(model)
# Fragment: iterating over an nn.Sequential's children is *dynamic* control
# flow (a plain Python loop), which tracing unrolls for the traced instance.
model: nn.Sequential = ...  # placeholder — not runnable as written
for m in model:  # dynamic control flow
    # NOTE(review): `x` is undefined at this point; illustrative fragment only.
    x = m(x)
# Fragment (detectron2-style): TracingAdapter flattens rich structured
# inputs/outputs into tuples of tensors so torch.jit.trace can consume them.
# NOTE(review): `TracingAdapter`, `model`, `inputs` come from outside this
# file (presumably detectron2.export) — not runnable as-is.
outputs = model(inputs)  # inputs/outputs are rich structures
# torch.jit.trace(model, inputs)  # FAIL! unsupported input format
adapter = TracingAdapter(model, inputs)
traced = torch.jit.trace(adapter, adapter.flattened_inputs)  # can now trace the model
# The traced model can only produce flattened outputs (a tuple of tensors):
flattened_outputs = traced(*adapter.flattened_inputs)
# The adapter converts them back to the rich structure (new_outputs == outputs):
new_outputs = adapter.outputs_schema(flattened_outputs)
def f(x):
    # Data-dependent branch: tracing evaluates the condition once with the
    # example input and records only the ops of the branch actually taken.
    return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
# tensor(3) makes the condition true, so only the sqrt branch is captured.
m = torch.jit.trace(f, torch.tensor(3))
print(m.code)  # inspect the traced code — the square branch is absent
#--------------------------------------------
def f(x: torch.Tensor) -> torch.Tensor:
    """Return the element-wise square root of `x`.

    Fix: the original annotations used the bare name `Tensor`, which is never
    imported in this file, so evaluating the annotations raises NameError as
    soon as the `def` executes. `torch.Tensor` is always in scope here.
    """
    return torch.sqrt(x)
import torch  # redundant — torch is already imported at the top of the file
# Two random 1-D tensors of *different* lengths, used below to probe whether
# a traced function generalizes beyond its example input's shape.
a, b = torch.rand(1), torch.rand(2)
print(a, b)
def f1(x):
    # Length comes from the tensor's shape, which tracing keeps symbolic,
    # so the traced function generalizes to other input lengths.
    length = x.shape[0]
    return torch.arange(length)
def f2(x):
    # len(x) yields a plain Python int, which tracing bakes in as a constant.
    n = len(x)
    return torch.arange(n)
# f1 reads the length from x.shape, so the trace generalizes to b's length.
result = torch.jit.trace(f1, a)(b)
print(result)
# f2 uses len(x) — a Python int — which the trace bakes in as a constant,
# hence the TracerWarning and an output sized for `a`, not `b`.
result = torch.jit.trace(f2, a)(b)  # TracerWarning
print(result)
# Compare the generated code of both traces to see the constant-folding.
print(torch.jit.trace(f1, a).code, torch.jit.trace(f2, a).code)
# Bad example: getting the device
# Sanity check that a traced model generalizes: trace on input1, run on
# input2, and compare against eager execution.
# NOTE(review): `allclose`, `model`, `input1`, `input2` are undefined here
# (presumably torch.allclose) — illustrative fragment only.
assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
# Fragment of a forward(): explicitly handle empty inputs, which exported
# models must support without running the layers.
if x.numel() > 0:
    output = self.layers(x)
else:
    # Create an empty output of the expected shape instead.
    # NOTE(review): C, H, W must be in scope in the real method — fragment only.
    output = torch.zeros((0, C, H, W))