ONNX Runtime 中的图优化

ONNX Runtime 提供了各种图优化以提高性能。图优化本质上是图级别的转换,范围从小的图简化和节点消除到更复杂的节点融合和布局优化。

图优化根据其复杂性和功能分为几个类别(或级别)。它们可以在线离线执行。在线模式下,优化在执行推理之前完成;而在离线模式下,运行时将优化后的图保存到磁盘。ONNX Runtime 提供了 Python、C#、C++ 和 C API,以启用不同的优化级别并选择离线或在线模式。

下面我们提供有关优化级别、在线/离线模式以及控制它们的各种 API 的详细信息。

目录

图优化级别

图优化分为三个级别

  1. 基本
  2. 扩展
  3. 布局优化

属于一个级别的优化在应用前一个级别的优化之后执行(例如,在应用基本优化之后应用扩展优化)。

所有优化默认启用。

基本图优化

这些是保留语义的图重写,用于移除冗余节点和冗余计算。它们在图分区之前运行,因此适用于所有执行提供者。可用的基本图优化如下:

  • 常量折叠:静态计算图中仅依赖于常量初始化器的部分。这消除了在运行时计算它们的需要。

  • 冗余节点消除:在不改变图结构的情况下移除所有冗余节点。目前支持的此类优化如下:
    • Identity 消除
    • Slice 消除
    • Unsqueeze 消除
    • Dropout 消除
  • 语义保留节点融合:将多个节点融合/折叠成一个节点。例如,Conv Add 融合将 Add 操作符折叠为 Conv 操作符的偏置。目前支持的此类优化如下:
    • Conv Add 融合
    • Conv Mul 融合
    • Conv BatchNorm 融合
    • Relu Clip 融合
    • Reshape 融合

扩展图优化

这些优化包括复杂的节点融合。它们在图分区之后运行,并且仅适用于分配给 CPU、CUDA 或 ROCm 执行提供者的节点。可用的扩展图优化如下:

优化项 执行提供者 备注
GEMM 激活融合 CPU  
Matmul Add 融合 CPU  
Conv 激活融合 CPU  
GELU 融合 CPU, CUDA, ROCm  
层归一化融合 CPU, CUDA, ROCm  
BERT 嵌入层融合 CPU, CUDA, ROCm 融合 BERT 嵌入层、层归一化和注意力掩码长度
注意力融合* CPU, CUDA, ROCm  
跳跃层归一化融合 CPU, CUDA, ROCm 融合全连接层的偏置、跳跃连接和层归一化
偏置 GELU 融合 CPU, CUDA, ROCm 融合全连接层的偏置和 GELU 激活
GELU 近似* CUDA, ROCm 默认禁用。通过 kOrtSessionOptionsEnableGeluApproximation 启用。
Approximations (click to expand)

为优化 BERT 的性能,在 GELU 近似和注意力融合中,对 CUDA 和 ROCm 执行提供者使用了近似方法。根据我们的评估,对精度的影响可以忽略不计:SQuAD v1.1 上 BERT 模型的 F1 分数几乎相同 (87.05 vs 87.03)。

布局优化

这些优化改变了适用节点的数据布局,以实现更高的性能改进。它们在图分区之后运行,并且仅适用于分配给 CPU 执行提供者的节点。可用的布局优化如下:

  • NCHWc 优化器:通过使用 NCHWc 布局而不是 NCHW 布局来优化图。

在线/离线模式

所有优化既可以在线执行,也可以离线执行。在线模式下,在初始化推理会话时,我们会在执行模型推理之前应用所有已启用的图优化。每次启动会话时都应用所有优化可能会增加模型启动时间开销(特别是对于复杂模型),这在生产场景中可能至关重要。这就是离线模式可以带来巨大好处的地方。在离线模式下,执行图优化后,ONNX Runtime 会将结果模型序列化到磁盘。随后,我们可以通过使用已优化模型并禁用所有优化来减少启动时间。

备注:

  • 在离线模式下运行时,请确保使用与模型推理将运行的目标机器完全相同的选项(例如,执行提供者、优化级别)和硬件(例如,您不能在仅配备 CPU 的机器上运行为 GPU 执行提供者预优化的模型)。
  • 当启用布局优化时,离线模式只能在与保存离线模型时的环境兼容的硬件上使用。例如,如果模型针对 AVX2 进行了布局优化,则离线模型将需要支持 AVX2 的 CPU。

用法

级别

ONNX Runtime 定义了 GraphOptimizationLevel 枚举,用于确定启用上述哪些优化级别。选择一个级别将启用该级别的优化以及所有前置级别的优化。例如,启用扩展优化也会启用基本优化。这些级别到枚举的映射如下:

  • GraphOptimizationLevel::ORT_DISABLE_ALL -> 禁用所有优化
  • GraphOptimizationLevel::ORT_ENABLE_BASIC -> 启用基本优化
  • GraphOptimizationLevel::ORT_ENABLE_EXTENDED -> 启用基本和扩展优化
  • GraphOptimizationLevel::ORT_ENABLE_ALL -> 启用所有可用优化,包括布局优化

离线模式

要启用将优化后的模型序列化到磁盘,请设置 SessionOptions 选项 optimized_model_filepath

Python API 示例

import onnxruntime as rt

sess_options = rt.SessionOptions()

# Set graph optimization level
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# To enable model serialization after graph optimization set this
sess_options.optimized_model_filepath = "<model_output_path\optimized_model.onnx>"

session = rt.InferenceSession("<model_path>", sess_options)

C API 示例

  const OrtApi* Ort::g_api = OrtGetApi(ORT_API_VERSION);
  OrtEnv* env;
  g_ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env);
  OrtSessionOptions* session_options;
  g_ort->CreateSessionOptions(&session_options)

  // Set graph optimization level
  g_ort->SetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_EXTENDED);

  // To enable model serialization after graph optimization set this
  const ORTCHAR_T* optimized_model_path = ORT_TSTR("optimized_model_path");
  g_ort->SetOptimizedModelFilePath(session_options, optimized_model_path);

  OrtSession* session;
  const ORTCHAR_T* model_path = ORT_TSTR("model_path");
  g_ort->CreateSession(env, model_path, session_options, &session);

C# API 示例

SessionOptions so = new SessionOptions();

// Set graph optimization level
so.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;

// To enable model serialization after graph optimization set this
so.OptimizedModelFilePath = "model_output_path\optimized_model.onnx"

var session = new InferenceSession(modelPath, so);

C++ API 示例

Ort::SessionOptions session_options;

// Set graph optimization level
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

// To enable model serialization after graph optimization set this
session_options.SetOptimizedModelFilePath("optimized_file_path");

auto session_ = Ort::Session(env, "model_file_path", session_options);