ONNX Runtime 中的图优化

ONNX Runtime 提供了各种图优化来提高性能。图优化本质上是图级别的转换,范围从小型图简化和节点消除到更复杂的节点融合和布局优化。

图优化根据其复杂性和功能分为几个类别(或级别)。它们可以以在线离线模式执行。在在线模式下,优化在执行推理之前完成;而在离线模式下,运行时将优化后的图保存到磁盘。ONNX Runtime 提供了 Python、C#、C++ 和 C API,可以启用不同的优化级别,并选择离线或在线模式。

下面我们详细介绍优化级别、在线/离线模式以及控制它们的各种 API。

目录

图优化级别

图优化分为三个级别

  1. 基本
  2. 扩展
  3. 布局优化

一个级别所属的优化在前一个级别的优化应用后执行(例如,扩展优化在基本优化应用后执行)。

所有优化默认启用。

基本图优化

这些是保留语义的图重写,用于移除冗余节点和冗余计算。它们在图划分之前运行,因此适用于所有执行提供程序。可用的基本图优化如下:

  • 常量折叠:静态计算图中仅依赖于常量初始化器的部分。这消除了在运行时计算它们的需要。

  • 冗余节点消除:移除所有冗余节点而不改变图结构。目前支持以下此类优化:
    • Identity 消除
    • Slice 消除
    • Unsqueeze 消除
    • Dropout 消除
  • 保留语义的节点融合:将多个节点融合/折叠成一个节点。例如,Conv Add 融合将 Add 算子折叠为 Conv 算子的偏置项。目前支持以下此类优化:
    • Conv Add 融合
    • Conv Mul 融合
    • Conv BatchNorm 融合
    • Relu Clip 融合
    • Reshape 融合

扩展图优化

这些优化包括复杂的节点融合。它们在图划分后运行,并且仅应用于分配给 CPU、CUDA 或 ROCm 执行提供程序的节点。可用的扩展图优化如下:

优化项 执行提供程序 说明
GEMM 激活融合 CPU  
Matmul Add 融合 CPU  
Conv 激活融合 CPU  
GELU 融合 CPU, CUDA, ROCm  
层归一化融合 CPU, CUDA, ROCm  
BERT 嵌入层融合 CPU, CUDA, ROCm 融合 BERT 嵌入层、层归一化和注意力掩码长度
注意力融合* CPU, CUDA, ROCm  
跳过层归一化融合 CPU, CUDA, ROCm 融合全连接层的偏置项、跳跃连接和层归一化
偏置 GELU 融合 CPU, CUDA, ROCm 融合全连接层的偏置项和 GELU 激活
GELU 近似* CUDA, ROCm 默认禁用。使用 kOrtSessionOptionsEnableGeluApproximation 启用
Approximations (click to expand)

为了优化 BERT 的性能,CUDA 和 ROCm 执行提供程序的 GELU 近似和注意力融合中使用了近似。根据我们的评估,对准确率的影响可以忽略不计:BERT 模型在 SQuAD v1.1 上的 F1 分数几乎相同 (87.05 vs 87.03)。

布局优化

这些优化改变了适用节点的数据布局,以实现更高的性能提升。它们在图划分后运行,并且仅应用于分配给 CPU 执行提供程序的节点。可用的布局优化如下:

  • NCHWc 优化器:通过使用 NCHWc 布局代替 NCHW 布局来优化图。

在线/离线模式

所有优化都可以在线或离线执行。在在线模式下,初始化推理会话时,我们会在执行模型推理之前应用所有启用的图优化。每次启动会话时都应用所有优化可能会增加模型启动时间(特别是对于复杂模型),这在生产场景中至关重要。这时离线模式就能带来很多好处。在离线模式下,执行图优化后,ONNX Runtime 会将生成的模型序列化到磁盘。随后,我们可以通过使用已优化的模型并禁用所有优化来减少启动时间。

注意:

  • 在离线模式下运行时,请确保使用与目标推理机器完全相同的选项(例如,执行提供程序、优化级别)和硬件(例如,您不能在仅配备 CPU 的机器上运行为 GPU 执行提供程序预先优化的模型)。
  • 启用布局优化后,离线模式只能在保存离线模型时与环境兼容的硬件上使用。例如,如果模型已针对 AVX2 优化了布局,则离线模型将需要支持 AVX2 的 CPU。

用法

级别

ONNX Runtime 定义了 GraphOptimizationLevel 枚举来确定启用上述哪些优化级别。选择一个级别会启用该级别的优化以及所有先前级别的优化。例如,启用扩展优化也会启用基本优化。这些级别与枚举的映射如下:

  • GraphOptimizationLevel::ORT_DISABLE_ALL -> 禁用所有优化
  • GraphOptimizationLevel::ORT_ENABLE_BASIC -> 启用基本优化
  • GraphOptimizationLevel::ORT_ENABLE_EXTENDED -> 启用基本和扩展优化
  • GraphOptimizationLevel::ORT_ENABLE_ALL -> 启用所有可用优化,包括布局优化

离线模式

要启用将优化后的模型序列化到磁盘,请设置 SessionOptions 选项 optimized_model_filepath

Python API 示例

import onnxruntime as rt

sess_options = rt.SessionOptions()

# Set graph optimization level
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# To enable model serialization after graph optimization set this
sess_options.optimized_model_filepath = "<model_output_path\optimized_model.onnx>"

session = rt.InferenceSession("<model_path>", sess_options)

C API 示例

  const OrtApi* Ort::g_api = OrtGetApi(ORT_API_VERSION);
  OrtEnv* env;
  g_ort->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env);
  OrtSessionOptions* session_options;
  g_ort->CreateSessionOptions(&session_options)

  // Set graph optimization level
  g_ort->SetSessionGraphOptimizationLevel(session_options, ORT_ENABLE_EXTENDED);

  // To enable model serialization after graph optimization set this
  const ORTCHAR_T* optimized_model_path = ORT_TSTR("optimized_model_path");
  g_ort->SetOptimizedModelFilePath(session_options, optimized_model_path);

  OrtSession* session;
  const ORTCHAR_T* model_path = ORT_TSTR("model_path");
  g_ort->CreateSession(env, model_path, session_option, &session);

C# API 示例

SessionOptions so = new SessionOptions();

// Set graph optimization level
so.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;

// To enable model serialization after graph optimization set this
so.OptimizedModelFilePath = "model_output_path\optimized_model.onnx"

var session = new InferenceSession(modelPath, so);

C++ API 示例

Ort::SessionOptions session_options;

// Set graph optimization level
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

// To enable model serialization after graph optimization set this
session_options.SetOptimizedModelFilePath("optimized_file_path");

auto session_ = Ort::Session(env, "model_file_path", session_options);