ONNX Runtime 中的图优化

ONNX Runtime 提供了各种图优化来提高性能。图优化本质上是图级别的转换,范围从小型的图简化和节点消除到更复杂的节点融合和布局优化。

图优化根据其复杂性和功能分为几个类别(或级别)。它们可以在线离线执行。在在线模式下,优化在执行推理之前完成,而在离线模式下,运行时将优化的图保存到磁盘。 ONNX Runtime 提供了 Python、C#、C++ 和 C API 来启用不同的优化级别,并选择离线与在线模式。

下面我们提供有关优化级别、在线/离线模式以及控制它们的各种 API 的详细信息。

目录

图优化级别

图优化分为三个级别

  1. 基本
  2. 扩展
  3. 布局优化

一个级别中的优化是在应用前一级别的优化之后执行的(例如,扩展优化是在应用基本优化之后应用的)。

默认情况下,所有优化均已启用。

基本图优化

这些是语义保留的图重写,可删除冗余节点和冗余计算。它们在图分区之前运行,因此适用于所有执行提供程序。可用的基本图优化如下:

  • 常量折叠:静态计算图中仅依赖于常量初始值设定项的部分。这消除了在运行时计算它们的需要。

  • 冗余节点消除:删除所有冗余节点而不更改图结构。当前支持以下此类优化:
    • 恒等消除
    • 切片消除
    • Unsqueeze 消除
    • Dropout 消除
  • 语义保留节点融合:将多个节点融合/折叠为单个节点。例如,Conv Add 融合将 Add 运算符折叠为 Conv 运算符的偏置。当前支持以下此类优化:
    • Conv Add 融合
    • Conv Mul 融合
    • Conv BatchNorm 融合
    • Relu Clip 融合
    • Reshape 融合

扩展图优化

这些优化包括复杂的节点融合。它们在图分区之后运行,并且仅应用于分配给 CPU 或 CUDA 或 ROCm 执行提供程序的节点。可用的扩展图优化如下:

优化 执行提供程序 注释
GEMM 激活融合 CPU  
Matmul Add 融合 CPU  
Conv 激活融合 CPU  
GELU 融合 CPU、CUDA、ROCm  
层归一化融合 CPU、CUDA、ROCm  
BERT 嵌入层融合 CPU、CUDA、ROCm 融合 BERT 嵌入层、层归一化和注意力掩码长度
注意力融合* CPU、CUDA、ROCm  
跳过层归一化融合 CPU、CUDA、ROCm 融合全连接层的偏置、跳过连接和层归一化
Bias GELU 融合 CPU、CUDA、ROCm 融合全连接层的偏置和 GELU 激活
GELU 近似* CUDA、ROCm 默认禁用。使用 kOrtSessionOptionsEnableGeluApproximation 启用
Approximations (click to expand)

为了优化 BERT 的性能,GELU 近似和注意力融合在 CUDA 和 ROCm 执行提供程序中使用了近似。根据我们的评估,对精度的影响可以忽略不计:SQuAD v1.1 上 BERT 模型的 F1 分数几乎相同(87.05 vs 87.03)。

布局优化

这些优化更改适用节点的数据布局,以实现更高的性能改进。它们在图分区之后运行,并且仅应用于分配给 CPU 执行提供程序的节点。可用的布局优化如下:

  • NCHWc 优化器:通过使用 NCHWc 布局而不是 NCHW 布局来优化图。

在线/离线模式

所有优化都可以在线或离线执行。在在线模式下,当初始化推理会话时,我们还会在执行模型推理之前应用所有已启用的图优化。每次启动会话时都应用所有优化可能会增加模型启动时间的开销(特别是对于复杂模型),这在生产场景中可能至关重要。这就是离线模式可以带来很多好处的地方。在离线模式下,在执行图优化之后,ONNX Runtime 将生成的模型序列化到磁盘。随后,我们可以通过使用已优化的模型并禁用所有优化来减少启动时间。

注意:

  • 在离线模式下运行时,请确保使用与模型推理将在其上运行的目标机器完全相同的选项(例如,执行提供程序、优化级别)和硬件(例如,您不能在仅配备 CPU 的机器上运行为 GPU 执行提供程序预优化的模型)。
  • 当启用布局优化时,离线模式只能在与保存离线模型时环境兼容的硬件上使用。例如,如果模型具有针对 AVX2 优化的布局,则离线模型将需要支持 AVX2 的 CPU。

用法

级别

ONNX Runtime 定义了 GraphOptimizationLevel 枚举,以确定将启用上述哪个优化级别。选择一个级别将启用该级别的优化,以及所有先前级别的优化。例如,启用扩展优化也会启用基本优化。这些级别到枚举的映射如下:

  • GraphOptimizationLevel::ORT_DISABLE_ALL -> 禁用所有优化
  • GraphOptimizationLevel::ORT_ENABLE_BASIC -> 启用基本优化
  • GraphOptimizationLevel::ORT_ENABLE_EXTENDED -> 启用基本和扩展优化
  • GraphOptimizationLevel::ORT_ENABLE_ALL -> 启用所有可用的优化,包括布局优化

离线模式

要启用将优化模型序列化到磁盘,请设置 SessionOptions 选项 optimized_model_filepath

Python API 示例

import onnxruntime as rt

sess_options = rt.SessionOptions()

# Set graph optimization level
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# To enable model serialization after graph optimization set this
sess_options.optimized_model_filepath = "<model_output_path\optimized_model.onnx>"

session = rt.InferenceSession("<model_path>", sess_options)

C API 示例

  const OrtApi* Ort::g_api = OrtGetApi(ORT_API_VERSION);
  OrtEnv* env;
  g_ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env);
  OrtSessionOptions* session_options;
  g_ort->CreateSessionOptions(&session_options)

  // Set graph optimization level
  g_ort->SetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_EXTENDED);

  // To enable model serialization after graph optimization set this
  const ORTCHAR_T* optimized_model_path = ORT_TSTR("optimized_model_path");
  g_ort->SetOptimizedModelFilePath(session_options, optimized_model_path);

  OrtSession* session;
  const ORTCHAR_T* model_path = ORT_TSTR("model_path");
  g_ort->CreateSession(env, model_path, session_option, &session);

C# API 示例

SessionOptions so = new SessionOptions();

// Set graph optimization level
so.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;

// To enable model serialization after graph optimization set this
so.OptimizedModelFilePath = "model_output_path\optimized_model.onnx"

var session = new InferenceSession(modelPath, so);

C++ API 示例

Ort::SessionOptions session_options;

// Set graph optimization level
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

// To enable model serialization after graph optimization set this
session_options.SetOptimizedModelFilePath("optimized_file_path");

auto session_ = Ort::Session(env, "model_file_path", session_options);