使用自定义 ONNX 算子导出 PyTorch 模型
本文档解释了使用自定义 ONNX Runtime 算子导出 PyTorch 模型的过程。目标是导出包含 ONNX 中不支持的算子的 PyTorch 模型,并扩展 ONNX Runtime 以支持这些自定义算子。
目录
导出内置 Contrib 算子
“Contrib 算子”指的是大多数 ORT 包中内置的一组自定义算子。所有 Contrib 算子的符号函数应定义在 pytorch_export_contrib_ops.py 中。
要使用这些 Contrib 算子进行导出,请在调用 torch.onnx.export()
之前调用 pytorch_export_contrib_ops.register()
。例如
from onnxruntime.tools import pytorch_export_contrib_ops
import torch
pytorch_export_contrib_ops.register()
torch.onnx.export(...)
导出自定义算子
要导出非 Contrib 算子或未包含在 pytorch_export_contrib_ops
中的自定义算子,需要编写并注册一个自定义算子符号函数。
我们以 Inverse 算子为例
from torch.onnx import register_custom_op_symbolic
def my_inverse(g, self):
return g.op("com.microsoft::Inverse", self)
# register_custom_op_symbolic('<namespace>::inverse', my_inverse, <opset_version>)
register_custom_op_symbolic('::inverse', my_inverse, 1)
<namespace>
是 torch 算子名称的一部分。对于标准 torch 算子,可以省略 namespace。
应将 com.microsoft
用作 ONNX Runtime 算子的自定义 opset 域。您可以在算子注册期间选择自定义 opset 版本。
有关编写符号函数的更多信息,请参阅 torch.onnx 文档。
使用自定义算子扩展 ONNX Runtime
下一步是在 ONNX Runtime 中添加算子 schema 和 kernel 实现。详细信息请参阅自定义算子。
端到端测试:导出和运行
在自定义算子在导出器中注册并在 ONNX Runtime 中实现后,您应该能够将其导出并使用 ONNX Runtime 运行。
您可以在下方找到将 inverse 算子作为模型一部分导出和运行的示例脚本。
导出的模型包含 ONNX 标准算子和自定义算子的组合。
此测试还比较了 PyTorch 模型输出与 ONNX Runtime 输出,以测试算子导出和实现。
import io
import numpy
import onnxruntime
import torch
class CustomInverse(torch.nn.Module):
def forward(self, x):
return torch.inverse(x) + x
x = torch.randn(3, 3)
# Export model to ONNX
f = io.BytesIO()
torch.onnx.export(CustomInverse(), (x,), f)
model = CustomInverse()
pt_outputs = model(x)
# Run the exported model with ONNX Runtime
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = dict((ort_sess.get_inputs()[i].name, input.cpu().numpy()) for i, input in enumerate((x,)))
ort_outputs = ort_sess.run(None, ort_inputs)
# Validate PyTorch and ONNX Runtime results
numpy.testing.assert_allclose(pt_outputs.cpu().numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)
默认情况下,自定义 opset 的 opset 版本将设置为 1
。如果您想将自定义算子导出到更高的 opset 版本,可以在调用导出 API 时使用 custom_opsets argument
指定自定义 opset 域和版本。请注意,这与默认 ONNX
域关联的 opset 版本不同。
torch.onnx.export(CustomInverse(), (x,), f, custom_opsets={"com.microsoft": 5})
请注意,您可以将自定义算子导出到任何大于等于注册时使用的 opset 版本。