使用自定义 ONNX 算子导出 PyTorch 模型
本文档解释了使用自定义 ONNX Runtime 算子导出 PyTorch 模型的过程。目的是导出包含 ONNX 不支持的算子的 PyTorch 模型,并扩展 ONNX Runtime 以支持这些自定义算子。
目录
导出内置 Contrib 算子
“Contrib 算子”指的是内置于大多数 ORT 软件包中的一组自定义算子。所有 Contrib 算子的符号函数应在 pytorch_export_contrib_ops.py 中定义。
要使用这些 Contrib 算子进行导出,请在调用 torch.onnx.export()
之前调用 pytorch_export_contrib_ops.register()
。例如:
from onnxruntime.tools import pytorch_export_contrib_ops
import torch
pytorch_export_contrib_ops.register()
torch.onnx.export(...)
导出自定义算子
要导出非 Contrib 算子或未包含在 pytorch_export_contrib_ops
中的自定义算子,需要编写并注册自定义算子的符号函数。
我们以 Inverse 算子为例:
from torch.onnx import register_custom_op_symbolic
def my_inverse(g, self):
return g.op("com.microsoft::Inverse", self)
# register_custom_op_symbolic('<namespace>::inverse', my_inverse, <opset_version>)
register_custom_op_symbolic('::inverse', my_inverse, 1)
<namespace>
是 torch 算子名称的一部分。对于标准 torch 算子,命名空间可以省略。
com.microsoft
应作为 ONNX Runtime 算子的自定义 opset 域使用。您可以在算子注册期间选择自定义 opset 版本。
有关编写符号函数的更多信息,请参阅 torch.onnx 文档。
使用自定义算子扩展 ONNX Runtime
下一步是在 ONNX Runtime 中添加算子 Schema 和内核实现。详见自定义算子。
端到端测试:导出和运行
一旦自定义算子在导出器中注册并在 ONNX Runtime 中实现,您应该能够导出它并使用 ONNX Runtime 运行它。
下面是一个示例脚本,用于导出 Inverse 算子并作为模型的一部分运行。
导出的模型包括 ONNX 标准算子和自定义算子的组合。
此测试还比较了 PyTorch 模型与 ONNX Runtime 输出的结果,以测试算子导出和实现。
import io
import numpy
import onnxruntime
import torch
class CustomInverse(torch.nn.Module):
def forward(self, x):
return torch.inverse(x) + x
x = torch.randn(3, 3)
# Export model to ONNX
f = io.BytesIO()
torch.onnx.export(CustomInverse(), (x,), f)
model = CustomInverse()
pt_outputs = model(x)
# Run the exported model with ONNX Runtime
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = dict((ort_sess.get_inputs()[i].name, input.cpu().numpy()) for i, input in enumerate((x,)))
ort_outputs = ort_sess.run(None, ort_inputs)
# Validate PyTorch and ONNX Runtime results
numpy.testing.assert_allclose(pt_outputs.cpu().numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)
默认情况下,自定义 opset 的 opset 版本将设置为 1
。如果您想将自定义算子导出到更高的 opset 版本,可以在调用导出 API 时使用 custom_opsets
参数指定自定义 opset 域和版本。请注意,这与默认 ONNX
域关联的 opset 版本不同。
torch.onnx.export(CustomInverse(), (x,), f, custom_opsets={"com.microsoft": 5})
请注意,您可以将自定义算子导出到任何大于等于注册时使用的 opset 版本。