ONNX Runtime 中的图优化
ONNX Runtime 提供了各种图优化来提高性能。图优化本质上是图级别的转换,范围从小型的图简化和节点消除到更复杂的节点融合和布局优化。
图优化根据其复杂性和功能分为几个类别(或级别)。它们可以在线或离线执行。在在线模式下,优化在执行推理之前完成,而在离线模式下,运行时将优化的图保存到磁盘。 ONNX Runtime 提供了 Python、C#、C++ 和 C API 来启用不同的优化级别,并选择离线与在线模式。
下面我们提供有关优化级别、在线/离线模式以及控制它们的各种 API 的详细信息。
目录
图优化级别
图优化分为三个级别
- 基本
- 扩展
- 布局优化
一个级别中的优化是在应用前一级别的优化之后执行的(例如,扩展优化是在应用基本优化之后应用的)。
默认情况下,所有优化均已启用。
基本图优化
这些是语义保留的图重写,可删除冗余节点和冗余计算。它们在图分区之前运行,因此适用于所有执行提供程序。可用的基本图优化如下:
-
常量折叠:静态计算图中仅依赖于常量初始值设定项的部分。这消除了在运行时计算它们的需要。
- 冗余节点消除:删除所有冗余节点而不更改图结构。当前支持以下此类优化:
- 恒等消除
- 切片消除
- Unsqueeze 消除
- Dropout 消除
- 语义保留节点融合:将多个节点融合/折叠为单个节点。例如,Conv Add 融合将 Add 运算符折叠为 Conv 运算符的偏置。当前支持以下此类优化:
- Conv Add 融合
- Conv Mul 融合
- Conv BatchNorm 融合
- Relu Clip 融合
- Reshape 融合
扩展图优化
这些优化包括复杂的节点融合。它们在图分区之后运行,并且仅应用于分配给 CPU 或 CUDA 或 ROCm 执行提供程序的节点。可用的扩展图优化如下:
优化 | 执行提供程序 | 注释 |
---|---|---|
GEMM 激活融合 | CPU | |
Matmul Add 融合 | CPU | |
Conv 激活融合 | CPU | |
GELU 融合 | CPU、CUDA、ROCm | |
层归一化融合 | CPU、CUDA、ROCm | |
BERT 嵌入层融合 | CPU、CUDA、ROCm | 融合 BERT 嵌入层、层归一化和注意力掩码长度 |
注意力融合* | CPU、CUDA、ROCm | |
跳过层归一化融合 | CPU、CUDA、ROCm | 融合全连接层的偏置、跳过连接和层归一化 |
Bias GELU 融合 | CPU、CUDA、ROCm | 融合全连接层的偏置和 GELU 激活 |
GELU 近似* | CUDA、ROCm | 默认禁用。使用 kOrtSessionOptionsEnableGeluApproximation 启用 |
Approximations (click to expand)
Approximations (click to expand)
为了优化 BERT 的性能,GELU 近似和注意力融合在 CUDA 和 ROCm 执行提供程序中使用了近似。根据我们的评估,对精度的影响可以忽略不计:SQuAD v1.1 上 BERT 模型的 F1 分数几乎相同(87.05 vs 87.03)。
布局优化
这些优化更改适用节点的数据布局,以实现更高的性能改进。它们在图分区之后运行,并且仅应用于分配给 CPU 执行提供程序的节点。可用的布局优化如下:
- NCHWc 优化器:通过使用 NCHWc 布局而不是 NCHW 布局来优化图。
在线/离线模式
所有优化都可以在线或离线执行。在在线模式下,当初始化推理会话时,我们还会在执行模型推理之前应用所有已启用的图优化。每次启动会话时都应用所有优化可能会增加模型启动时间的开销(特别是对于复杂模型),这在生产场景中可能至关重要。这就是离线模式可以带来很多好处的地方。在离线模式下,在执行图优化之后,ONNX Runtime 将生成的模型序列化到磁盘。随后,我们可以通过使用已优化的模型并禁用所有优化来减少启动时间。
注意:
- 在离线模式下运行时,请确保使用与模型推理将在其上运行的目标机器完全相同的选项(例如,执行提供程序、优化级别)和硬件(例如,您不能在仅配备 CPU 的机器上运行为 GPU 执行提供程序预优化的模型)。
- 当启用布局优化时,离线模式只能在与保存离线模型时环境兼容的硬件上使用。例如,如果模型具有针对 AVX2 优化的布局,则离线模型将需要支持 AVX2 的 CPU。
用法
级别
ONNX Runtime 定义了 GraphOptimizationLevel
枚举,以确定将启用上述哪个优化级别。选择一个级别将启用该级别的优化,以及所有先前级别的优化。例如,启用扩展优化也会启用基本优化。这些级别到枚举的映射如下:
- GraphOptimizationLevel::ORT_DISABLE_ALL -> 禁用所有优化
- GraphOptimizationLevel::ORT_ENABLE_BASIC -> 启用基本优化
- GraphOptimizationLevel::ORT_ENABLE_EXTENDED -> 启用基本和扩展优化
- GraphOptimizationLevel::ORT_ENABLE_ALL -> 启用所有可用的优化,包括布局优化
离线模式
要启用将优化模型序列化到磁盘,请设置 SessionOptions 选项 optimized_model_filepath
。
Python API 示例
import onnxruntime as rt
sess_options = rt.SessionOptions()
# Set graph optimization level
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
# To enable model serialization after graph optimization set this
sess_options.optimized_model_filepath = "<model_output_path\optimized_model.onnx>"
session = rt.InferenceSession("<model_path>", sess_options)
C API 示例
const OrtApi* Ort::g_api = OrtGetApi(ORT_API_VERSION);
OrtEnv* env;
g_ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env);
OrtSessionOptions* session_options;
g_ort->CreateSessionOptions(&session_options)
// Set graph optimization level
g_ort->SetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_EXTENDED);
// To enable model serialization after graph optimization set this
const ORTCHAR_T* optimized_model_path = ORT_TSTR("optimized_model_path");
g_ort->SetOptimizedModelFilePath(session_options, optimized_model_path);
OrtSession* session;
const ORTCHAR_T* model_path = ORT_TSTR("model_path");
g_ort->CreateSession(env, model_path, session_option, &session);
C# API 示例
SessionOptions so = new SessionOptions();
// Set graph optimization level
so.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;
// To enable model serialization after graph optimization set this
so.OptimizedModelFilePath = "model_output_path\optimized_model.onnx"
var session = new InferenceSession(modelPath, so);
C++ API 示例
Ort::SessionOptions session_options;
// Set graph optimization level
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
// To enable model serialization after graph optimization set this
session_options.SetOptimizedModelFilePath("optimized_file_path");
auto session_ = Ort::Session(env, "model_file_path", session_options);