使用自定义 ONNX 算子导出 PyTorch 模型

本文档解释了使用自定义 ONNX Runtime 算子导出 PyTorch 模型的过程。目标是导出包含 ONNX 中不支持的算子的 PyTorch 模型,并扩展 ONNX Runtime 以支持这些自定义算子。

目录

导出内置 Contrib 算子

“Contrib 算子”指的是大多数 ORT 包中内置的一组自定义算子。所有 Contrib 算子的符号函数应定义在 pytorch_export_contrib_ops.py 中。

要使用这些 Contrib 算子进行导出,请在调用 torch.onnx.export() 之前调用 pytorch_export_contrib_ops.register()。例如

from onnxruntime.tools import pytorch_export_contrib_ops
import torch

pytorch_export_contrib_ops.register()
torch.onnx.export(...)

导出自定义算子

要导出非 Contrib 算子或未包含在 pytorch_export_contrib_ops 中的自定义算子,需要编写并注册一个自定义算子符号函数。

我们以 Inverse 算子为例

from torch.onnx import register_custom_op_symbolic

def my_inverse(g, self):
    return g.op("com.microsoft::Inverse", self)

# register_custom_op_symbolic('<namespace>::inverse', my_inverse, <opset_version>)
register_custom_op_symbolic('::inverse', my_inverse, 1)

<namespace> 是 torch 算子名称的一部分。对于标准 torch 算子,可以省略 namespace。

应将 com.microsoft 用作 ONNX Runtime 算子的自定义 opset 域。您可以在算子注册期间选择自定义 opset 版本。

有关编写符号函数的更多信息,请参阅 torch.onnx 文档

使用自定义算子扩展 ONNX Runtime

下一步是在 ONNX Runtime 中添加算子 schema 和 kernel 实现。详细信息请参阅自定义算子

端到端测试:导出和运行

在自定义算子在导出器中注册并在 ONNX Runtime 中实现后,您应该能够将其导出并使用 ONNX Runtime 运行。

您可以在下方找到将 inverse 算子作为模型一部分导出和运行的示例脚本。

导出的模型包含 ONNX 标准算子和自定义算子的组合。

此测试还比较了 PyTorch 模型输出与 ONNX Runtime 输出,以测试算子导出和实现。

import io
import numpy
import onnxruntime
import torch


class CustomInverse(torch.nn.Module):
    def forward(self, x):
        return torch.inverse(x) + x

x = torch.randn(3, 3)

# Export model to ONNX
f = io.BytesIO()
torch.onnx.export(CustomInverse(), (x,), f)

model = CustomInverse()
pt_outputs = model(x)

# Run the exported model with ONNX Runtime
ort_sess = onnxruntime.InferenceSession(f.getvalue())
ort_inputs = dict((ort_sess.get_inputs()[i].name, input.cpu().numpy()) for i, input in enumerate((x,)))
ort_outputs = ort_sess.run(None, ort_inputs)

# Validate PyTorch and ONNX Runtime results
numpy.testing.assert_allclose(pt_outputs.cpu().numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)

默认情况下,自定义 opset 的 opset 版本将设置为 1。如果您想将自定义算子导出到更高的 opset 版本,可以在调用导出 API 时使用 custom_opsets argument 指定自定义 opset 域和版本。请注意,这与默认 ONNX 域关联的 opset 版本不同。

torch.onnx.export(CustomInverse(), (x,), f, custom_opsets={"com.microsoft": 5})

请注意,您可以将自定义算子导出到任何大于等于注册时使用的 opset 版本。