使用 ONNX Runtime 加速 Phi-2、CodeLlama、Gemma 及其他生成式 AI 模型

作者

Parinita Rahi, Sunghoon Choi, Yufeng Li, Kshama Pawar, Ashwini Khade, Ye Wang

2024年2月26日

在快速发展的时代,速度和效率至关重要。通过 ONNX Runtime (ORT),用户可以轻松地将生成式 AI 模型的能力集成到其应用和服务中,并且通过改进的优化,实现更快的推理速度并有效降低成本。这些优化包括最先进的融合和内核优化,有助于提高模型性能。最近发布的 ONNX Runtime 1.17 版本 提升了包括 Phi-2、Mistral、CodeLlama、Orca-2 等多种生成式 AI 模型的推理性能。ONNX Runtime 是小型语言模型 (SLM) 从训练到推理的完整解决方案,与其他框架相比,显示出显著的加速。ONNX Runtime 支持 float32、float16 和 int4,其推理增强功能提供了最大的灵活性和性能。

在本篇博客中,我们将介绍 ONNX Runtime 对 Phi-2、Mistral、CodeLlama、SD-Turbo、SDXL-Turbo、Llama2 和 Orca-2 等最新生成式 AI 模型在训练和推理方面的显著优化加速。对于这些模型架构,ONNX Runtime 在各种批量大小和提示长度下,与 PyTorch 和 Llama.cpp 等其他框架相比,均显著提升了性能。使用 ONNX Runtime 进行的这些优化现在也可以通过 Olive 获得。

快速链接

Phi-2

Phi-2 是微软开发的具有 27 亿参数的 Transformer 模型。它是一个小型语言模型 (SLM),展现出卓越的推理和语言理解能力。由于其小巧的尺寸,Phi-2 是研究人员进行各种探索的绝佳平台,例如机制可解释性、安全性改进以及在不同任务上的微调实验。

ONNX Runtime 1.17 引入了内核变更以支持 Phi-2 模型,包括对 Phi-2 的 Attention、Multi-Head Attention、Grouped-Query Attention 和 RotaryEmbedding 的优化。具体而言,已添加了对以下方面的支持:

  • Multi-Head Attention CPU 内核中的因果掩码
  • Attention 和 Rotary Embedding 内核中的 rotary_embedding_dim
  • Grouped-Query Attention 内核中的 bfloat16

支持基于 TorchDynamo 的 Phi-2 ONNX 导出,并且优化脚本在此基础上构建。

对于 Phi-2 推理,使用 float16 和 int4 量化的 ORT 在所有提示长度下,均优于使用 float32 的 ORT、PyTorch 和 Llama.cpp。

推理

ORT 在 float16 下的性能提升

优化后的 CUDA 在提示吞吐量(即模型根据输入提示处理和生成响应的速率)方面,比 PyTorch Compile 快 高达 7.39 倍。我们还观察到,ONNX Runtime 在更大的批量大小和提示长度下,比 Llama.cpp 显著更快。例如,在批量大小 =16、提示长度 =2048 时,它快 高达 13.08 倍

令牌生成吞吐量是生成的前 256 个令牌的平均吞吐量。使用 float16 的 ONNX Runtime 平均比 torch.compile 快 6.6 倍,最高可达 18.55 倍。它也比 Llama.cpp 快 高达 1.64 倍

Phi2 float16 prompt throughput comparison Phi2 float16 token generation throughput comparison

ORT 在 int4 下的性能提升

ORT 支持 int4 量化。使用 int4 量化的 ORT 与 PyTorch 相比,性能可提升 高达 20.48 倍。与 Llama.cpp 相比,平均提升 3.9 倍,对于大序列长度,快 高达 13.42 倍。由于采用了特殊的 GemV 内核,使用 int4 量化的 ONNX Runtime 通常在批量大小为 1 时性能最佳。

Phi2 int4 prompt throughput comparison Phi2 int4 token generation throughput comparison
注意:torch.compile 在 4 位量化下效果不佳。此外,Llama.cpp 不使用 FlashAttention,其 attention 实现对于大序列长度较慢。

  • Phi-2 基准测试在 1 个 A100 GPU (SKU: Standard_ND96amsr_A100_v4) 上完成。软件包:torch: 2.3.0. dev20231221+cu121; pytorch-triton: 2.2.0+e28a256d71;ort-nightly-gpu: 1.17.0.dev20240118001;deepspeed: 0.12
  • 批量(Batch)是指一组长度不同的输入句子;提示长度(Prompt length)是指输入文本的大小或长度。

这是一个使用 Olive 对 Phi-2 进行优化的示例,该示例利用本博客中强调的 ONNX Runtime 优化,并使用易于使用的硬件感知模型优化工具 Olive

训练

除了推理,ONNX Runtime 还为 Phi-2 和其他 LLM 提供了训练加速。ORT 训练是 PyTorch 生态系统的一部分,可通过 torch-ort python 包作为 Azure Container for PyTorch (ACPT) 的一部分获取。它提供灵活且可扩展的硬件支持,相同的模型和 API 可在 NVIDIA 和 AMD GPU 上工作。ORT 通过优化的内核和内存优化加速训练,在减少大型模型训练的端到端训练时间方面显示出显著的提升。这仅需修改模型中的几行代码,使用 ORTModule API 进行封装即可。它还可以与 DeepSpeed 和 Megatron 等流行的加速库组合使用,以实现更快、更高效的训练。

Open AI 的 Triton 是一种领域特定语言和编译器,用于编写高效的自定义深度学习原语。ORT 支持 Open AI Triton 集成(ORT+Triton),将所有逐元素运算符转换为 Triton op,并在 Triton 中创建自定义融合内核。

ORT 还执行稀疏性优化,评估输入数据的稀疏性并利用这种稀疏性进行图优化。这减少了计算 FLOP 需求并提高了性能。

基于低秩适配器 (LoRA) 的微调通过仅训练少量附加参数(即适配器)同时冻结原始模型权重,使得训练更加高效。这些适配器使模型适应特定任务。量化和 LoRA (QLoRA) 将量化与 LoRA 结合,其中权重使用更少的位表示,同时保持模型的性能和质量。ONNX Runtime 训练可以与 LoRA 和 QLoRA 结合,从而提高 LLM 的内存效率和训练时间。LoRA 和 QLoRA 技术使得 LLM 等非常大的模型能够适应 GPU 内存,从而高效地完成训练。

使用 ORT 训练的 Phi-2 模型与 PyTorch Eager 模式和 torch.compile 相比显示出性能提升。Phi-2 使用合成数据集和网络数据集的混合进行训练。我们衡量了 ORT 和 ORT+Triton 模式下的性能提升,且提升随着批量大小的增加而增加。该模型使用 DeepSpeed Stage-2 在 wikitext 数据集上训练了 5 个 epoch,批量大小逐渐增加。V100 和 A100 上的提升总结在下面的图表中。

训练基准测试在 8 个 V100 上运行,衡量吞吐量(迭代/秒,越高越好)

Phi2 training throughput comparison

以下训练基准测试在 2 个 A100 上运行,衡量吞吐量(迭代/秒,越高越好)

Phi2 在 2 个 A100 上的训练基准测试 注:使用了 PyTorch Stable 2.2.0 和 ONNXRuntime Training Stable 1.17.0 版本。

Mistral

推理

Mistral7B 是一个预训练的生成式文本 LLM,拥有 70 亿参数。ONNX Runtime 对 float16 和 int4 模型的 Mistral 均显著提高了推理性能。对于 float16,ONNX Runtime 比 Llama.cpp 快 高达 9.46 倍。使用 int4 量化时,批量大小为 1 的令牌生成吞吐量显著提高,比 PyTorch Eager 快 高达 18.25 倍

Mistral float16 prompt throughput comparison Mistral float16 token generation throughput comparison Mistral int4 prompt throughput comparison Mistral int4 token generation throughput comparison

您现在可以在 Huggingface 上获取优化后的 Mistral 模型,点击此处

训练

与 Phi-2 类似,Mistral 也受益于使用 ORT 进行训练加速。我们使用以下配置训练了 Mistral-7B,以观察使用 ORT 时的性能提升,包括与 LoRA 和 QLoRA 组合时。该模型使用 DeepSpeed Stage-2 在 wikitext 数据集上训练了 5 个 epoch,批量大小为 1。

Mistral training benchmarks

CodeLlama

Codellama-70B 是一个基于 Llama-2 平台开发的专注于编程的模型。该模型可以生成代码,并用自然语言讨论代码。由于 CodeLlama-70B 是一个经过微调的 Llama 模型,现有的优化可以直接应用。我们将一个 4 位量化的 ONNX 模型与 PyTorch Eager 和 Llama.cpp 进行了比较。对于提示吞吐量,ONNX Runtime 在所有批量大小下都比 PyTorch Eager 快至少 1.4 倍。ONNX Runtime 生成令牌的平均速度比任何批量大小的 PyTorch Eager 高 3.4 倍,比批量大小为 1 的 Llama.cpp 高 1.5 倍。

CodeLLama int4 prompt throughput comparison CodeLLama int4 token generation throughput comparison

SD-Turbo 和 SDXL-Turbo

ONNX Runtime 在与 SD TurboSDXL Turbo 一起使用时,提供了推理性能优势,并且使得这些模型可以通过 Python 以外的语言(如 C# 和 Java)进行访问。在所有评估的(批量大小,步数)组合下,ONNX Runtime 的吞吐量均高于 PyTorch,其中 SDXL Turbo 模型吞吐量提升 高达 229%,SD Turbo 模型提升 120%。ONNX Runtime CUDA 特别擅长处理动态形状,但在静态形状方面也比 PyTorch 具有显著优势。

Stable Diffusion XL Turbo Speedup

要了解更多关于使用 ONNX Runtime 加速 SD-Turbo 和 SDXL-Turbo 推理的信息,请查看我们最近与 Hugging Face 发布的博客,点击此处

Llama-2

我们在此处发布了一篇独立的博客,介绍使用 ORT 对 Llama-2 推理进行的改进。点击此处。此外,Llama-2-7B 和 Llama-2-13B 在使用 ORT 进行训练时显示出良好的提升,尤其是与 LoRA 和 QLoRA 结合时。这些脚本 可以作为使用 Optimum 和 ORT 微调 Llama-2 的示例。下面的数字是使用 DeepSpeed Stage-2 在 wikitext 数据集上训练了 5 个 epoch、批量大小为 1 的 Llama-2 模型在使用 ORT 进行训练时的结果。

Llama2 training benchmarks

Orca-2

推理

Orca-2 是一个仅用于研究的系统,可在用户提供数据推理、理解文本、解决数学问题和文本摘要等任务中提供一次性答案。Orca-2 有两个版本(70 亿和 130 亿参数);它们都是通过在定制的高质量人工数据上微调相应的 Llama-2 基础模型而构建的。ONNX Runtime 通过使用图形融合和内核优化(类似于 Llama-2 的优化)帮助优化 Orca-2 的推理。

ORT 在 int4 下的性能提升

Orca-2-7B int4 量化性能比较显示,提示吞吐量性能提升 高达 26 倍,令牌生成吞吐量比 PyTorch 提升 高达 16.5 倍。与 Llama.cpp 相比,提示吞吐量也提升了 4.75 倍以上,令牌生成吞吐量提升了 3.64 倍以上。

Orca2 7b int4 prompt throughput comparison Orca2 7b int4 token generation throughput comparison Orca2 13b int4 prompt throughput comparison Orca2 13b int4 token generation throughput comparison

使用 ONNX runtime float16 的 Orca-2 7b 性能比较也显示出在提示和令牌生成吞吐量方面的显著提升。

Orca2 7b float16 prompt throughput comparison Orca2 7b float16 token generation throughput comparison Orca2 13b float16 prompt throughput comparison Orca2 13b float16 token generation throughput comparison

Orca-2 基准测试在 1 个 A100 GPU (SKU: Standard_ND96amsr_A100_v4) 上完成,软件包:torch 2.2.0, triton 2.2.0, onnxruntime-gpu 1.17.0, deepspeed 0.13.2, llama.cpp - commit 594fca3fefe27b8e95cfb1656eb0e160ad15a793, transformers 4.37.2

训练

Orca-2-7B 也受益于使用 ORT 进行训练加速。我们使用 LoRA 并启用了稀疏性优化,对 Orca-2-7B 模型以 512 的序列长度进行了训练,看到了良好的性能提升。下面的数字是使用 DeepSpeed Stage-2 在 wikitext 数据集上训练了 5 个 epoch、批量大小为 1 的 Orca-2-7B 模型在使用 ORT 进行训练时的结果。

Orca2 训练基准测试 使用 ACPT 镜像:nightly-ubuntu2004-cu118-py38-torch230dev:20240131

Gemma

Gemma 是一个轻量级、开放的模型系列,基于 Google 用于创建 Gemini 模型的研究和技术构建而成。它提供两种尺寸:2B 和 7B。每种尺寸都发布了预训练和指令微调版本。ONNX Runtime 可用于优化并高效运行任何开源模型。我们针对 Gemma-2B 模型进行了基准测试,使用 float16 的 ONNX Runtime 比 PyTorch Compile 快 高达 7.47 倍,比 Llama.cpp 快 高达 3.47 倍。使用 int4 量化的 ORT 比 PyTorch Eager 快 高达 19.81 倍,比 Llama.cpp 快 2.62 倍。

Gemma2b int4 token generation throughput comparison Gemma2b token generation throughput comparison

结论

总之,ONNX Runtime (ORT) 为包括 Phi-2、Mistral、CodeLlama、SDXL-Turbo、Llama-2、Orca-2 和 Gemma 在内的多种模型提供了显著的性能提升。ORT 提供最先进的融合和内核优化,包括支持 float16 和 int4 量化,从而实现更快的推理速度和更低的成本。在提示和令牌生成吞吐量方面,ORT 优于 PyTorch 和 Llama.cpp 等其他框架。ORT 在训练 LLM 方面也显示出显著优势,批量大小越大提升越大,并且与最先进的技术很好地结合,实现高效的大型模型训练。