ONNX Runtime 中的图优化
ONNX Runtime 提供了各种图优化来提高性能。图优化本质上是图级别的转换,范围从小型图简化和节点消除到更复杂的节点融合和布局优化。
图优化根据其复杂性和功能分为几个类别(或级别)。它们可以以在线或离线模式执行。在在线模式下,优化在执行推理之前完成;而在离线模式下,运行时将优化后的图保存到磁盘。ONNX Runtime 提供了 Python、C#、C++ 和 C API,可以启用不同的优化级别,并选择离线或在线模式。
下面我们详细介绍优化级别、在线/离线模式以及控制它们的各种 API。
目录
图优化级别
图优化分为三个级别
- 基本
- 扩展
- 布局优化
一个级别所属的优化在前一个级别的优化应用后执行(例如,扩展优化在基本优化应用后执行)。
所有优化默认启用。
基本图优化
这些是保留语义的图重写,用于移除冗余节点和冗余计算。它们在图划分之前运行,因此适用于所有执行提供程序。可用的基本图优化如下:
-
常量折叠:静态计算图中仅依赖于常量初始化器的部分。这消除了在运行时计算它们的需要。
- 冗余节点消除:移除所有冗余节点而不改变图结构。目前支持以下此类优化:
- Identity 消除
- Slice 消除
- Unsqueeze 消除
- Dropout 消除
- 保留语义的节点融合:将多个节点融合/折叠成一个节点。例如,Conv Add 融合将 Add 算子折叠为 Conv 算子的偏置项。目前支持以下此类优化:
- Conv Add 融合
- Conv Mul 融合
- Conv BatchNorm 融合
- Relu Clip 融合
- Reshape 融合
扩展图优化
这些优化包括复杂的节点融合。它们在图划分后运行,并且仅应用于分配给 CPU、CUDA 或 ROCm 执行提供程序的节点。可用的扩展图优化如下:
优化项 | 执行提供程序 | 说明 |
---|---|---|
GEMM 激活融合 | CPU | |
Matmul Add 融合 | CPU | |
Conv 激活融合 | CPU | |
GELU 融合 | CPU, CUDA, ROCm | |
层归一化融合 | CPU, CUDA, ROCm | |
BERT 嵌入层融合 | CPU, CUDA, ROCm | 融合 BERT 嵌入层、层归一化和注意力掩码长度 |
注意力融合* | CPU, CUDA, ROCm | |
跳过层归一化融合 | CPU, CUDA, ROCm | 融合全连接层的偏置项、跳跃连接和层归一化 |
偏置 GELU 融合 | CPU, CUDA, ROCm | 融合全连接层的偏置项和 GELU 激活 |
GELU 近似* | CUDA, ROCm | 默认禁用。使用 kOrtSessionOptionsEnableGeluApproximation 启用 |
Approximations (click to expand)
Approximations (click to expand)
为了优化 BERT 的性能,CUDA 和 ROCm 执行提供程序的 GELU 近似和注意力融合中使用了近似。根据我们的评估,对准确率的影响可以忽略不计:BERT 模型在 SQuAD v1.1 上的 F1 分数几乎相同 (87.05 vs 87.03)。
布局优化
这些优化改变了适用节点的数据布局,以实现更高的性能提升。它们在图划分后运行,并且仅应用于分配给 CPU 执行提供程序的节点。可用的布局优化如下:
- NCHWc 优化器:通过使用 NCHWc 布局代替 NCHW 布局来优化图。
在线/离线模式
所有优化都可以在线或离线执行。在在线模式下,初始化推理会话时,我们会在执行模型推理之前应用所有启用的图优化。每次启动会话时都应用所有优化可能会增加模型启动时间(特别是对于复杂模型),这在生产场景中至关重要。这时离线模式就能带来很多好处。在离线模式下,执行图优化后,ONNX Runtime 会将生成的模型序列化到磁盘。随后,我们可以通过使用已优化的模型并禁用所有优化来减少启动时间。
注意:
- 在离线模式下运行时,请确保使用与目标推理机器完全相同的选项(例如,执行提供程序、优化级别)和硬件(例如,您不能在仅配备 CPU 的机器上运行为 GPU 执行提供程序预先优化的模型)。
- 启用布局优化后,离线模式只能在保存离线模型时与环境兼容的硬件上使用。例如,如果模型已针对 AVX2 优化了布局,则离线模型将需要支持 AVX2 的 CPU。
用法
级别
ONNX Runtime 定义了 GraphOptimizationLevel
枚举来确定启用上述哪些优化级别。选择一个级别会启用该级别的优化以及所有先前级别的优化。例如,启用扩展优化也会启用基本优化。这些级别与枚举的映射如下:
- GraphOptimizationLevel::ORT_DISABLE_ALL -> 禁用所有优化
- GraphOptimizationLevel::ORT_ENABLE_BASIC -> 启用基本优化
- GraphOptimizationLevel::ORT_ENABLE_EXTENDED -> 启用基本和扩展优化
- GraphOptimizationLevel::ORT_ENABLE_ALL -> 启用所有可用优化,包括布局优化
离线模式
要启用将优化后的模型序列化到磁盘,请设置 SessionOptions 选项 optimized_model_filepath
。
Python API 示例
import onnxruntime as rt
sess_options = rt.SessionOptions()
# Set graph optimization level
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
# To enable model serialization after graph optimization set this
sess_options.optimized_model_filepath = "<model_output_path\optimized_model.onnx>"
session = rt.InferenceSession("<model_path>", sess_options)
C API 示例
const OrtApi* Ort::g_api = OrtGetApi(ORT_API_VERSION);
OrtEnv* env;
g_ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env);
OrtSessionOptions* session_options;
g_ort->CreateSessionOptions(&session_options)
// Set graph optimization level
g_ort->SetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_EXTENDED);
// To enable model serialization after graph optimization set this
const ORTCHAR_T* optimized_model_path = ORT_TSTR("optimized_model_path");
g_ort->SetOptimizedModelFilePath(session_options, optimized_model_path);
OrtSession* session;
const ORTCHAR_T* model_path = ORT_TSTR("model_path");
g_ort->CreateSession(env, model_path, session_option, &session);
C# API 示例
SessionOptions so = new SessionOptions();
// Set graph optimization level
so.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;
// To enable model serialization after graph optimization set this
so.OptimizedModelFilePath = "model_output_path\optimized_model.onnx"
var session = new InferenceSession(modelPath, so);
C++ API 示例
Ort::SessionOptions session_options;
// Set graph optimization level
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
// To enable model serialization after graph optimization set this
session_options.SetOptimizedModelFilePath("optimized_file_path");
auto session_ = Ort::Session(env, "model_file_path", session_options);