TensorFlow 模型转换为 ONNX 入门

TensorFlow 模型(包括 keras 和 TFLite 模型)可以使用 tf2onnx 工具转换为 ONNX。

本教程的完整代码可在此处获取:此处

安装

首先在已安装 TensorFlow 的 Python 环境中安装 tf2onnx。

pip install tf2onnx (稳定版)

pip install git+https://github.com/onnx/tensorflow-onnx (GitHub 最新版)

转换模型

Keras 模型和 tf 函数

Keras 模型和 tf 函数可以直接在 Python 中转换

import tensorflow as tf
import tf2onnx
import onnx

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(4, activation="relu"))

input_signature = [tf.TensorSpec([3, 3], tf.float32, name='x')]
# Use from_function for tf functions
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=13)
onnx.save(onnx_model, "dst/path/model.onnx")

请参阅 Python API 参考 获取完整文档。

SavedModel

使用以下命令转换 TensorFlow Saved Model

python -m tf2onnx.convert --saved-model path/to/savedmodel --output dst/path/model.onnx --opset 13

path/to/savedmodel 应该是包含 saved_model.pb 的目录路径

请参阅 CLI 参考 获取完整文档。

TFLite

tf2onnx 支持转换 tflite 模型。

python -m tf2onnx.convert --tflite path/to/model.tflite --output dst/path/model.onnx --opset 13

注意:Opset 编号

如果使用的 ONNX opset 过低,某些 TensorFlow ops 将无法转换。使用与您的应用程序兼容的最大 opset。 有关完整的转换说明,请参阅 tf2onnx README

验证转换后的模型

使用以下命令安装 onnxruntime

pip install onnxruntime

使用以下模板在 Python 中测试您的模型

import onnxruntime as ort
import numpy as np

# Change shapes and types to match model
input1 = np.zeros((1, 100, 100, 3), np.float32)

# Start from ORT 1.10, ORT requires explicitly setting the providers parameter if you want to use execution providers
# other than the default CPU provider (as opposed to the previous behavior of providers getting set/registered by default
# based on the build flags) when instantiating InferenceSession.
# Following code assumes NVIDIA GPU is available, you can specify other execution providers or don't include providers parameter
# to use default CPU provider.
sess = ort.InferenceSession("dst/path/model.onnx", providers=["CUDAExecutionProvider"])

# Set first argument of sess.run to None to use all model outputs in default order
# Input/output names are printed by the CLI and can be set with --rename-inputs and --rename-outputs
# If using the python API, names are determined from function arg names or TensorSpec names.
results_ort = sess.run(["output1", "output2"], {"input1": input1})

import tensorflow as tf
model = tf.saved_model.load("path/to/savedmodel")
results_tf = model(input1)

for ort_res, tf_res in zip(results_ort, results_tf):
    np.testing.assert_allclose(ort_res, tf_res, rtol=1e-5, atol=1e-5)

print("Results match")

转换失败

如果您的模型转换失败,请阅读我们的 README故障排除指南。如果仍然失败,请随时 在 GitHub 上打开一个 issue。欢迎为 tf2onnx 做出贡献!

后续步骤