使用 ONNX Runtime 加速 Phi-2、CodeLlama、Gemma 和其他 Gen AI 模型

作者:

Parinita Rahi, Sunghoon Choi, Yufeng Li, Kshama Pawar, Ashwini Khade, Ye Wang

2024 年 2 月 26 日

在速度和效率至上的快节奏环境中,ONNX Runtime (ORT) 使用户能够轻松地将生成式 AI 模型的强大功能集成到他们的应用程序和服务中,并通过改进的优化来提高推理速度并有效降低成本。这些优化包括最先进的融合和内核优化,以帮助提高模型性能。ONNX Runtime 1.17 最新版本改进了多种 Gen AI 模型的推理性能,包括 Phi-2、Mistral、CodeLlama、Orca-2 等。ONNX Runtime 是小型语言模型 (SLM) 从训练到推理的完整解决方案,与其他框架相比,显示出显著的加速。凭借对 float32、float16 和 int4 的支持,ONNX Runtime 的推理增强功能提供了最大的灵活性和性能。

在本博客中,我们将介绍最新的 GenAI 模型(如 Phi-2、Mistral、CodeLlama、SD-Turbo、SDXL-Turbo、Llama2 和 Orca-2)在训练和推理方面的显著优化加速。对于这些模型架构,与 PyTorch 和 Llama.cpp 等其他框架相比,ONNX Runtime 在各种批大小和提示长度下显著提高了性能。现在也可以使用 Olive 来实现使用 ONNX Runtime 的这些优化。

快速链接

Phi-2

Phi-2 是 Microsoft 开发的 27 亿参数 Transformer 模型。它是一个 SLM,具有出色的推理和语言理解能力。凭借其小巧的尺寸,Phi-2 为研究人员提供了一个绝佳的平台,他们可以探索各个方面,例如机制可解释性、安全改进以及针对不同任务的微调实验。

ONNX Runtime 1.17 引入了支持 Phi-2 模型的内核更改,包括针对 Phi-2 的 Attention、Multi-Head Attention、Grouped-Query Attention 和 RotaryEmbedding 的优化。具体来说,已添加对以下内容的支持:

  • Multi-Head Attention CPU 内核中的因果掩码
  • Attention 和 Rotary Embedding 内核中的 rotary_embedding_dim
  • Grouped-Query Attention 内核中的 bfloat16

支持基于 TorchDynamo 的 Phi-2 ONNX 导出,并且优化脚本构建在其之上。

对于 Phi-2 推理,对于所有提示长度,使用 float16 和 int4 量化的 ORT 比使用 float32、PyTorch 和 Llama.cpp 的 ORT 性能更好。

推理

ORT 使用 float16 的增益

优化的 CUDA 性能,用于提示吞吐量(即模型根据输入提示处理和生成响应的速率)比 PyTorch Compile 快高达 7.39 倍。我们还观察到,与 Llama.cpp 相比,ONNX Runtime 在更大的批大小和提示长度下明显更快。例如,对于批大小 = 16,提示长度 = 2048,速度快高达 13.08 倍

令牌生成吞吐量是生成的前 256 个令牌的平均吞吐量。使用 float16 的 ONNX Runtime 平均比 torch.compile 快6.6 倍最高可达 18.55 倍。它的性能也比 Llama.cpp 快高达 1.64 倍

Phi2 float16 prompt throughput comparison Phi2 float16 token generation throughput comparison

ORT 使用 int4 的增益

ORT 提供对 int4 量化的支持。使用 int4 量化的 ORT 与 PyTorch 相比,性能可以提高高达 20.48 倍。平均而言,它比 Llama.cpp 好 3.9 倍,对于大序列长度,速度快高达 13.42 倍。由于 GemV 的特殊内核,使用 int4 量化的 ONNX Runtime 通常在批大小为 1 时性能最佳。

Phi2 int4 prompt throughput comparison Phi2 int4 token generation throughput comparison
注意:torch.compile 在 4 位量化下效果不佳。此外,Llama.cpp 不使用 FlashAttention,其注意力实现对于大序列长度来说速度较慢。

  • Phi-2 基准测试在 1 个 A100 GPU 上完成(SKU:Standard_ND96amsr_A100_v4)。软件包:torch:2.3.0. dev20231221+cu121;pytorch-triton:2.2.0+e28a256d71;ort-nightly-gpu:1.17.0.dev20240118001;deepspeed:0.12
  • 批次是一组不同长度的输入句子;提示长度是指输入文本的大小或长度。

这是一个 使用 Olive 进行 Phi-2 优化的示例,它利用了本博客中强调的 ONNX Runtime 优化,使用了易于使用的硬件感知模型优化工具 Olive

训练

除了推理之外,ONNX Runtime 还为 Phi-2 和其他 LLM 提供训练加速。ORT 训练是 PyTorch 生态系统的一部分,可通过 torch-ort python 包作为 Azure Container for PyTorch (ACPT) 的一部分提供。它提供灵活且可扩展的硬件支持,相同的模型和 API 可与 NVIDIA 和 AMD GPU 一起使用。ORT 通过优化的内核和内存优化来加速训练,这在减少大型模型训练的端到端训练时间方面显示出显著的收益。这涉及到在模型中更改几行代码以使用 ORTModule API 对其进行包装。它还可以与 DeepSpeed 和 Megatron 等流行的加速库组合使用,以实现更快、更高效的训练。

Open AI 的 Triton 是一种特定领域的语言和编译器,用于编写高效的自定义深度学习原语。ORT 支持 Open AI Triton 集成 (ORT+Triton),其中所有逐元素运算符都转换为 Triton ops,并且 ORT 在 Triton 中创建自定义融合内核。

ORT 还执行稀疏性优化,以评估输入数据稀疏性并执行利用此稀疏性的图优化。这降低了计算 FLOP 要求并提高了性能。

基于低秩适配器 (LoRA) 的微调通过仅训练少量附加参数(适配器)同时冻结原始模型的权重来提高训练效率。这些适配器使模型适应特定任务。量化和 LoRA (QLoRA) 将量化与 LoRA 相结合,其中权重使用更少的位来表示,同时保持模型的性能和质量。ONNX Runtime 训练与 LoRA 和 QLoRA 结合使用,以提高 LLM 的内存效率和训练时间加速。LoRA 和 QLoRA 技术使 LLM 等超大型模型能够适应 GPU 内存,从而高效地完成训练。

使用 ORT 训练的 Phi-2 模型显示出优于 PyTorch Eager 模式和 torch.compile 的性能增益。Phi-2 是使用合成数据集和 Web 数据集的混合进行训练的。我们衡量了针对 ORT 和 ORT+Triton 模式的增益,并且增益随着批大小的增加而增加。该模型使用 DeepSpeed Stage-2 训练了 5 个 epoch,在 wikitext 数据集上增加了批大小。下图中总结了 V100 和 A100 的增益。

训练基准测试在 8 个 V100 上运行,并以迭代/秒为单位衡量吞吐量(越高越好)

Phi2 training throughput comparison

以下训练基准测试在 2 个 A100 上运行,并以迭代/秒为单位衡量吞吐量(越高越好)

Phi2 在 2 个 A100 上的训练基准测试 注意:使用了 PyTorch Stable 2.2.0 和 ONNXRuntime Training: Stable 1.17.0 版本。

Mistral

推理

Mistral7B 是一个预训练的生成文本 LLM,具有 70 亿个参数。ONNX Runtime 显著提高了 Mistral 在 float16 和 int4 模型下的推理性能。使用 float16,ONNX Runtime 比 Llama.cpp 快高达 9.46 倍。对于批大小为 1 的情况,令牌生成吞吐量通过 int4 量化得到了显著提高,比 PyTorch Eager 快高达 18.25 倍

Mistral float16 prompt throughput comparison Mistral float16 token generation throughput comparison Mistral int4 prompt throughput comparison Mistral int4 token generation throughput comparison

您现在可以在 Huggingface 上访问优化的 Mistral 模型,点击此处

训练

与 Phi-2 类似,Mistral 也受益于使用 ORT 进行的训练加速。我们使用以下配置训练了 Mistral-7B,以查看 ORT 的增益,包括与 LoRA 和 QLoRA 组合使用时的增益。该模型使用 DeepSpeed Stage-2 训练了 5 个 epoch,在 wikitext 数据集上批大小为 1。

Mistral training benchmarks

CodeLlama

Codellama-70B 是一个基于 Llama-2 平台开发的面向编程的模型。此模型可以生成代码并以自然语言生成围绕代码的讨论。由于 CodeLlama-70B 是一个微调的 Llama 模型,因此可以直接应用现有的优化。我们将 4 位量化的 ONNX 模型与 PyTorch Eager 和 Llama.cpp 进行了比较。对于提示吞吐量,对于所有批大小,ONNX Runtime 比 PyTorch Eager 快至少 1.4 倍。对于任何批大小,ONNX Runtime 生成令牌的平均速度比 PyTorch Eager 高 3.4 倍,对于批大小为 1,比 Llama.cpp 高 1.5 倍

CodeLLama int4 prompt throughput comparison CodeLLama int4 token generation throughput comparison

SD-Turbo 和 SDXL-Turbo

ONNX Runtime 在与 SD TurboSDXL Turbo 一起使用时,可提供推理性能优势,并且还使模型可以在 Python 以外的语言(如 C# 和 Java)中访问。对于评估的所有(批大小,步数)组合,ONNX Runtime 实现了比 PyTorch 更高的吞吐量,SDXL Turbo 模型的吞吐量提高了高达 229%,SD Turbo 模型的吞吐量提高了 120%。ONNX Runtime CUDA 尤其擅长处理动态形状,但对于静态形状,它也显示出比 PyTorch 显著的优势。

Stable Diffusion XL Turbo Speedup

要详细了解如何使用 ONNX Runtime 加速 SD-Turbo 和 SDXL-Turbo 推理,请查看我们最近与 Hugging Face 合作撰写的博客

Llama-2

我们发布了一篇单独的博客,介绍了 ORT 在 Llama-2 推理方面的改进,点击此处查看。此外,Llama-2-7B 和 Llama-2-13B 在 ORT 训练方面表现出良好的增益,尤其是在与 LoRA 和 QLoRA 结合使用时。这些脚本可以用作示例,以使用 Optimum 和 ORT 微调 Llama-2。以下数字是 Llama-2 模型使用 ORT 和 DeepSpeed Stage-2 训练 5 个 epoch,在 wikitext 数据集上批大小为 1 的结果。

Llama2 training benchmarks

Orca-2

推理

Orca-2 是一个仅供研究的系统,它在诸如推理用户提供的数据、理解文本、解决数学问题和总结文本等任务中给出一次性答案。Orca-2 有两个版本(70 亿和 130 亿参数);它们都是通过在定制的高质量人工数据上微调相应的 Llama-2 基础模型而制成的。ONNX Runtime 有助于优化 Orca-2 推理,以使用图形融合和内核优化,例如 Llama-2 的优化。

ORT 使用 int4 的增益

Orca-2-7B int4 量化性能比较表明,提示吞吐量性能提高了高达 26 倍,令牌生成吞吐量提高了高达 16.5 倍(与 PyTorch 相比)。与 Llama.cpp 相比,提示吞吐量提高了 4.75 倍以上,令牌生成吞吐量提高了 3.64 倍

Orca2 7b int4 prompt throughput comparison Orca2 7b int4 token generation throughput comparison Orca2 13b int4 prompt throughput comparison Orca2 13b int4 token generation throughput comparison

Orca-2 7b 与 ONNX Runtime float16 的性能比较也显示出提示和令牌生成吞吐量的显著增益。

Orca2 7b float16 prompt throughput comparison Orca2 7b float16 token generation throughput comparison Orca2 13b float16 prompt throughput comparison Orca2 13b float16 token generation throughput comparison

Orca-2 基准测试在 1 个 A100 GPU 上完成,SKU:Standard_ND96amsr_A100_v4,软件包 torch 2.2.0、triton 2.2.0、onnxruntime-gpu 1.17.0、deepspeed 0.13.2、llama.cpp - commit 594fca3fefe27b8e95cfb1656eb0e160ad15a793、transformers 4.37.2

训练

Orca-2-7B 也受益于使用 ORT 进行的训练加速。我们针对 512 的序列长度训练了 Orca-2-7B 模型,启用了 LoRA 和稀疏性优化,并看到了良好的性能增益。以下数字是 Orca-2-7B 模型使用 ORT 和 DeepSpeed Stage-2 训练 5 个 epoch,在 wikitext 数据集上批大小为 1 的结果。

Orca2 训练基准测试 使用 ACPT 镜像:nightly-ubuntu2004-cu118-py38-torch230dev:20240131

Gemma

Gemma 是一个轻量级、开放模型的系列,它基于 Google 用于创建 Gemini 模型的研究和技术构建。它有两种尺寸:2B 和 7B。每种尺寸都发布了预训练和指令调优变体。ONNX Runtime 可用于优化和高效运行任何开源模型。我们针对 Gemma-2B 模型进行了基准测试,使用 float16 的 ONNX Runtime 比 PyTorch Compile 快高达 7.47 倍,比 Llama.cpp 快高达 3.47 倍。使用 int4 量化的 ORT 比 PyTorch Eager 快高达 19.81 倍,比 Llama.cpp 快 2.62 倍

Gemma2b int4 token generation throughput comparison Gemma2b token generation throughput comparison

结论

总之,ONNX Runtime (ORT) 为多种模型提供了显著的性能改进,包括 Phi-2、Mistral、CodeLlama、SDXL-Turbo、Llama-2、Orca-2 和 Gemma。ORT 提供最先进的融合和内核优化,包括对 float16 和 int4 量化的支持,从而实现更快的推理速度和更低的成本。在提示和令牌生成吞吐量方面,ORT 优于 PyTorch 和 Llama.cpp 等其他框架。ORT 在 LLM 训练方面也显示出显著优势,批大小越大,增益越大,并且与最先进的技术很好地结合,从而实现高效的大型模型训练。