torch.jit.trace与torch.jit.script
Published:
Script mode通过torch.jit.trace
或者torch.jit.script
来调用。这两个函数都是将python代码转换为TorchScript的两种不同的方法。
torch.jit.trace
将一个特定的输入(通常是一个张量,需要我们提供一个input)传递给一个PyTorch模型,torch.jit.trace
会跟踪此input在model中的计算过程,然后将其转换为Torch脚本。这个方法适用于那些在静态图中可以完全定义的模型,例如具有固定输入大小的神经网络。通常用于转换预训练模型。
torch.jit.script
直接将Python函数(或者一个Python模块)通过python语法规则和编译转换为Torch脚本。torch.jit.script
更适用于动态图模型,这些模型的结构和输入可以在运行时发生变化。例如,对于RNN或者一些具有可变序列长度的模型,使用torch.jit.script
会更为方便。
在通常情况下,更应该倾向于使用torch.jit.trace
而不是torch.jit.script
。
在模型部署方面,onnx被大量使用。而导出onnx的过程,也是model进行torch.jit.trace
的过程,因此这里我们把torch的trace做稍微详细一点的介绍。
为了能够把模型编写的更能够被jit trace,需要对代码做一些妥协,例如:
1.如果model中有DataParallel
的子模块,或者model中有将tensors转换为numpy arrays,或者调用了opencv的函数等,这种情况下,model不是一个正确的在单个设备上、正确连接的graph,这种情况下,不管是使用torch.jit.script
还是torch.jit.trace
都不能trace出正确的TorchScript来。
2.model的输入输出应该是Union[Tensor, Tuple[Tensor], Dict[str, Tensor]]
的类型,而且在dict中的值,应该是同样的类型。但是对于model中间子模块的输入输出,可以是任意类型,例如dicts of Any, classes, kwargs以及python支持的都可以。对于model输入输出类型的限制是比较容易满足的,在Detectron2中,有类似的例子:
outputs = model(inputs) # inputs和outputs是python的类型, 例如dictsor classes
# torch.jit.trace(model, inputs) # 失败!trace只支持Union[Tensor,Tuple[Tensor], Dict[str, Tensor]]类型
adapter = TracingAdapter(model, inputs) # 使用Adapter,将modelinputs包装为trace支持的类型
traced = torch.jit.trace(adapter, adapter.flattened_inputs) # 现在以trace成功
# Traced model的输出只能是tuple tensors类型:
flattened_outputs = traced(*adapter.flattened_inputs)
# 再通过adapter转换为想要的输出类型
new_outputs = adapter.outputs_schema(flattened_outputs)
3.一些数值类型的问题。比如下面的代码片段
import torch
a=torch.tensor([1,2])
print(type(a.size(0)))
print(type(a.size()[0]))
print(type(a.shape[0]))
在eager mode下,这几个返回值的类型都是int型。上面代码的输出为
<class 'int'>
<class 'int'>
<class 'int'>
但是在trace mode下,这几个表达式的返回值类型都是Tensor
类型。因此,有些表达式使用不当,如果在trace过程中,一些shape表达式的返回值类型是int型,那么可能造成这块代码没有被trace。在代码中,可以通过使用torch.jit.is_tracing
来检查这块代码在trace mode下有没有被执行。
4.由于动态的control flow,造成模型没有被完整的trace。看下面的例子:
import torch
def f(x):
return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
m = torch.jit.trace(f, torch.tensor(3))
print(m.code)
输出为
def f(x: Tensor) -> Tensor:
return torch.sqrt(x)
可以看到trace后的model只保留了一条分支。因此由于输入造成的dynamic的control flow,trace后容易出现错误。
这种情况下,我们可以使用torch.jit.script
来进行TorchScript的转换。
import torch
def f(x):
return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
m = torch.jit.script(f)
print(m.code)
输出为
def f(x: Tensor) -> Tensor:
if bool(torch.gt(torch.sum(x), 0)):
_0 = torch.sqrt(x)
else:
_0 = torch.square(x)
return _0
在大多数情况下,我们应该使用torch.jit.trace
,但是像上面的这种dynamic control flow的情况,我们可以混合使用torch.jit.trace
和torch.jit.script
,在本文后面会进行阐述。
另外在一些blog中,对于dynamic control flow的定义是有错误的,例如if x[0] == 4: x += 1
是dynamic control flow,但是
model: nn.Sequential = ...
for m in model:
x = m(x)
以及
class A(nn.Module):
backbone: nn.Module
head: Optiona[nn.Module]
def forward(self, x):
x = self.backbone(x)
if self.head is not None:
x = self.head(x)
return x
都不是dynamic control flow。dynamic control flow是由于对输入条件的判断造成的不同分支的执行。
5.trace过程中,将变量trace成了常量。看下面一个例子
import torch
a, b = torch.rand(1), torch.rand(2)
def f1(x): return torch.arange(x.shape[0])
def f2(x): return torch.arange(len(x))
print(torch.jit.trace(f1, a)(b))
# 输出: tensor([0, 1])
# 可以看到trace后的model是没问题的,这里使用变量a作为torch.jit.trace的example input,然后将转换后的TorchScript用变量b作为输入,正常情况下,b的shape是2维的,因此返回值是tensor([0,1])是正确的
print(torch.jit.trace(f2, a)(b))
# 输出:
# TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
# tensor([0])
# 可以看到这个输出结果是错误的,b的维度是2维,输出应该是tensor([0,1]),这里torch.jit.trace也提示了,使用len可能会造成不正确的trace。
# 我们打印一下两者的区别
print(torch.jit.trace(f1, a).code, '\n',torch.jit.trace(f2, a).code)
# 输出
# def f1(x: Tensor) -> Tensor:
# _0 = ops.prim.NumToTensor(torch.size(x, 0))
# _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
# return _1
# def f2(x: Tensor) -> Tensor:
# _0 = torch.arange(1, dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
# return _0
# TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
# 从trace的code中可以看出,使用x.shape这种方式,在trace后的code里面,是有shape的一个变量值存在的,但是直接使用len这种方式,trace后的code里面,就直接是1
我们导出onnx的过程,也是进行torch.jit.trace的过程,在导出onnx的时候,有时候也会遇到
TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
这样的提示信息,这时候要检查一下代码中是不是有可能trace过程中,变量会被当做常量的情况,有可能会导致导出的onnx精度异常。
除了len
会导致trace错误,其他几个也会导致trace出现问题:
.item()
会在trace过程中将tensors转为int/float任何将torch类型转为numpy/python类型的代码
一些有问题的算子,例如advanced indexing
- torch.jit.trace不会对传入的device生效
import torch
def f(x):
return torch.arange(x.shape[0], device=x.device)
m = torch.jit.trace(f, torch.tensor([3]))
print(m.code)
# 输出
# def f(x: Tensor) -> Tensor:
# _0 = ops.prim.NumToTensor(torch.size(x, 0))
# _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
# return _1
print(m(torch.tensor([3]).cuda()).device)
# 输出:device(type='cpu')
trace不会对传入的cuda device生效。
为了保证trace的正确,我们可以通过一下的一些方法来尽量保证trace后的模型不会出错:
1.注意warnings信息。类似这样的TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
TraceWarnings信息,它会造成模型的结果有可能不正确,但是它只是个warning等级。
2.做单元测试。需要验证一下eager mode的模型输出与trace后的模型输出是否一致。
assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
3.避免一些特殊的情况。例如下面的代码
if x.numel() > 0:
output = self.layers(x)
else:
output = torch.zeros((0, C, H, W)) # 会创建一个空的输出
避免一些特殊情况比如空的输入输出之类的。
4.注意shape的使用。前面提到,tensor.size()
在trace过程中会返回Tensor
类型的数据,Tensor
类型会在计算过程中被添加到计算图中,应该避免将Tensor类型的shape转为了常量。主要注意以下两点:
- 使用
torch.size(0)
来代替len(tensor)
,因为torch.size(0)
返回的是Tensor
,len(tensor)
返回的是int
。对于自定义类,实现一个.size
方法或者使用.__len__()
方法来代替len()
,例如这个例子 - 不要使用
int()
或者torch.as_tensor
来转换size的类型,因为这些操作也会被视为常量。
5.混合tracing和scripting方法。可以使用torch.jit.script
来转换一些torch.jit.trace
不能搞定的小的代码片段,混合使用tracing和scripting,基本可以解决所有的问题。
混合使用tracing和scripting
tracing和scripting都有他们的问题,混合使用可以解决大部分问题。但是为了尽可能减小对于代码质量的负面影响,大部分情况下,都应该使用torch.jit.trace
,必要时才使用torch.jit.script
。
1.在使用torch.jit.trace
时,使用@script_if_tracing
装饰器可以让被装饰的函数使用scripting方式进行编译。
def forward(self, ...):
# ... some forward logic
@torch.jit.script_if_tracing
def _inner_impl(x, y, z, flag: bool):
# use control flow, etc.
return ...
output = _inner_impl(x, y, z, flag)
# ... other forward logic
但是使用@script_if_tracing
时,需要保证函数中没有pytorch的modules,如果有的话,需要做一些修改,例如下面的:
# 因为代码中有self.layers(),是一个pytorch的module,因此不能使用@script_if_tracing
if x.numel() > 0:
x = preprocess(x)
output = self.layers(x)
else:
# Create empty outputs
output = torch.zeros(...)
这里需要做如下修改:
# 需要将self.layers移出if判断,这时候可以用@script_if_tracing
if x.numel() > 0:
x = preprocess(x)
else:
# Create empty inputs
x = torch.zeros(...)
# 需要将self.layers()修改为支持empty的输入,或者将原先的条件判断加入到self.layers中
output = self.layers(x)
2.合并多次tracing的结果
使用torch.jit.script
生成的模型相比使用torch.jit.trace
有两个好处:
- 可以使用条件控制流,例如模型中使用一个bool值来控制forward的flow,在traced modules里面是不支持的
- 使用traced module,只能有一个forward()函数,但是使用scripted module,可以有多个前向计算的函数
class Detector(nn.Module):
do_keypoint: bool
def forward(self, img):
box = self.predict_boxes(img)
if self.do_keypoint:
kpts = self.predict_keypoint(img, box)
@torch.jit.export
def predict_boxes(self, img): pass
@torch.jit.export
def predict_keypoint(self, img, box): pass
对于这种有bool值的控制流,除了使用script,还可以多次进行trace,然后将结果合并。
det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
然后将他们的weight复制一遍,并合并两次trace的结果
det2.submodule.weight = det1.submodule.weight
class Wrapper(nn.ModuleList):
def forward(self, img, do_keypoint: bool):
if do_keypoint:
return self[0](img)
else:
return self[1](img)
exported = torch.jit.script(Wrapper([det1, det2]))
trace和script的性能
tracing总是会比scripting生成一样或者更简单的计算图,因此性能会更好一些。因为scripting会完整的表达python代码的逻辑,甚至一些不必要的代码也会如实表达。例如下面的例子:
class A(nn.Module):
def forward(self, x1, x2, x3):
z = [0, 1, 2]
xs = [x1, x2, x3]
for k in z: x1 += xs[k]
return x1
model = A()
print(torch.jit.script(model).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
# z = [0, 1, 2]
# xs = [x1, x2, x3]
# x10 = x1
# for _0 in range(torch.len(z)):
# k = z[_0]
# x10 = torch.add_(x10, xs[k])
# return x10
print(torch.jit.trace(model, [torch.tensor(1)] * 3).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
# x10 = torch.add_(x1, x1)
# x11 = torch.add_(x10, x2)
# return torch.add_(x11, x3)
总结
tracing具有明显的局限性:这篇文章的大部分篇幅都在谈论tracing的局限性以及如何解决这些问题。实际上,这正是tracing的优势所在:它有明确的局限性(和解决方案),因此你可以推理它是否有效。
相反,scripting更像是一个黑盒子:在尝试之前,没有人知道它是否有效。文章中没有提到如何修复scripting的任何诀窍:有很多诀窍,但不值得你花时间去探究和修复一个黑盒子。
tracing和scripting都会影响代码的编写方式,但tracing因为我们明确它的要求,对我们原始的代码造成的一些修改也不会太严重:
- 它限制了输入/输出格式,但仅限于最外层的模块。(如上所述,这个问题可以通过一个wrapper解决)。
- 它需要修改一些代码才能通用(例如在tracing时添加一些scripting),但这些修改只涉及受影响模块的内部实现,而不是它们的接口。
reference: https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/