使用 ONNX Runtime 加速 Phi-2、CodeLlama、Gemma 及其他生成式 AI 模型
作者
Parinita Rahi, Sunghoon Choi, Yufeng Li, Kshama Pawar, Ashwini Khade, Ye Wang2024年2月26日
在快速发展的时代,速度和效率至关重要。通过 ONNX Runtime (ORT),用户可以轻松地将生成式 AI 模型的能力集成到其应用和服务中,并且通过改进的优化,实现更快的推理速度并有效降低成本。这些优化包括最先进的融合和内核优化,有助于提高模型性能。最近发布的 ONNX Runtime 1.17 版本 提升了包括 Phi-2、Mistral、CodeLlama、Orca-2 等多种生成式 AI 模型的推理性能。ONNX Runtime 是小型语言模型 (SLM) 从训练到推理的完整解决方案,与其他框架相比,显示出显著的加速。ONNX Runtime 支持 float32、float16 和 int4,其推理增强功能提供了最大的灵活性和性能。
在本篇博客中,我们将介绍 ONNX Runtime 对 Phi-2、Mistral、CodeLlama、SD-Turbo、SDXL-Turbo、Llama2 和 Orca-2 等最新生成式 AI 模型在训练和推理方面的显著优化加速。对于这些模型架构,ONNX Runtime 在各种批量大小和提示长度下,与 PyTorch 和 Llama.cpp 等其他框架相比,均显著提升了性能。使用 ONNX Runtime 进行的这些优化现在也可以通过 Olive 获得。
快速链接
Phi-2
Phi-2 是微软开发的具有 27 亿参数的 Transformer 模型。它是一个小型语言模型 (SLM),展现出卓越的推理和语言理解能力。由于其小巧的尺寸,Phi-2 是研究人员进行各种探索的绝佳平台,例如机制可解释性、安全性改进以及在不同任务上的微调实验。
ONNX Runtime 1.17 引入了内核变更以支持 Phi-2 模型,包括对 Phi-2 的 Attention、Multi-Head Attention、Grouped-Query Attention 和 RotaryEmbedding 的优化。具体而言,已添加了对以下方面的支持:
- Multi-Head Attention CPU 内核中的因果掩码
- Attention 和 Rotary Embedding 内核中的 rotary_embedding_dim
- Grouped-Query Attention 内核中的 bfloat16
支持基于 TorchDynamo 的 Phi-2 ONNX 导出,并且优化脚本在此基础上构建。
对于 Phi-2 推理,使用 float16 和 int4 量化的 ORT 在所有提示长度下,均优于使用 float32 的 ORT、PyTorch 和 Llama.cpp。
推理
ORT 在 float16 下的性能提升
优化后的 CUDA 在提示吞吐量(即模型根据输入提示处理和生成响应的速率)方面,比 PyTorch Compile 快 高达 7.39 倍。我们还观察到,ONNX Runtime 在更大的批量大小和提示长度下,比 Llama.cpp 显著更快。例如,在批量大小 =16、提示长度 =2048 时,它快 高达 13.08 倍。
令牌生成吞吐量是生成的前 256 个令牌的平均吞吐量。使用 float16 的 ONNX Runtime 平均比 torch.compile 快 6.6 倍,最高可达 18.55 倍。它也比 Llama.cpp 快 高达 1.64 倍。


ORT 在 int4 下的性能提升
ORT 支持 int4 量化。使用 int4 量化的 ORT 与 PyTorch 相比,性能可提升 高达 20.48 倍。与 Llama.cpp 相比,平均提升 3.9 倍,对于大序列长度,快 高达 13.42 倍。由于采用了特殊的 GemV 内核,使用 int4 量化的 ONNX Runtime 通常在批量大小为 1 时性能最佳。


- Phi-2 基准测试在 1 个 A100 GPU (SKU: Standard_ND96amsr_A100_v4) 上完成。软件包:torch: 2.3.0. dev20231221+cu121; pytorch-triton: 2.2.0+e28a256d71;ort-nightly-gpu: 1.17.0.dev20240118001;deepspeed: 0.12
- 批量(Batch)是指一组长度不同的输入句子;提示长度(Prompt length)是指输入文本的大小或长度。
这是一个使用 Olive 对 Phi-2 进行优化的示例,该示例利用本博客中强调的 ONNX Runtime 优化,并使用易于使用的硬件感知模型优化工具 Olive。
训练
除了推理,ONNX Runtime 还为 Phi-2 和其他 LLM 提供了训练加速。ORT 训练是 PyTorch 生态系统的一部分,可通过 torch-ort python 包作为 Azure Container for PyTorch (ACPT) 的一部分获取。它提供灵活且可扩展的硬件支持,相同的模型和 API 可在 NVIDIA 和 AMD GPU 上工作。ORT 通过优化的内核和内存优化加速训练,在减少大型模型训练的端到端训练时间方面显示出显著的提升。这仅需修改模型中的几行代码,使用 ORTModule API 进行封装即可。它还可以与 DeepSpeed 和 Megatron 等流行的加速库组合使用,以实现更快、更高效的训练。
Open AI 的 Triton 是一种领域特定语言和编译器,用于编写高效的自定义深度学习原语。ORT 支持 Open AI Triton 集成(ORT+Triton),将所有逐元素运算符转换为 Triton op,并在 Triton 中创建自定义融合内核。
ORT 还执行稀疏性优化,评估输入数据的稀疏性并利用这种稀疏性进行图优化。这减少了计算 FLOP 需求并提高了性能。
基于低秩适配器 (LoRA) 的微调通过仅训练少量附加参数(即适配器)同时冻结原始模型权重,使得训练更加高效。这些适配器使模型适应特定任务。量化和 LoRA (QLoRA) 将量化与 LoRA 结合,其中权重使用更少的位表示,同时保持模型的性能和质量。ONNX Runtime 训练可以与 LoRA 和 QLoRA 结合,从而提高 LLM 的内存效率和训练时间。LoRA 和 QLoRA 技术使得 LLM 等非常大的模型能够适应 GPU 内存,从而高效地完成训练。
使用 ORT 训练的 Phi-2 模型与 PyTorch Eager 模式和 torch.compile 相比显示出性能提升。Phi-2 使用合成数据集和网络数据集的混合进行训练。我们衡量了 ORT 和 ORT+Triton 模式下的性能提升,且提升随着批量大小的增加而增加。该模型使用 DeepSpeed Stage-2 在 wikitext 数据集上训练了 5 个 epoch,批量大小逐渐增加。V100 和 A100 上的提升总结在下面的图表中。
训练基准测试在 8 个 V100 上运行,衡量吞吐量(迭代/秒,越高越好)

以下训练基准测试在 2 个 A100 上运行,衡量吞吐量(迭代/秒,越高越好)

Mistral
推理
Mistral7B 是一个预训练的生成式文本 LLM,拥有 70 亿参数。ONNX Runtime 对 float16 和 int4 模型的 Mistral 均显著提高了推理性能。对于 float16,ONNX Runtime 比 Llama.cpp 快 高达 9.46 倍。使用 int4 量化时,批量大小为 1 的令牌生成吞吐量显著提高,比 PyTorch Eager 快 高达 18.25 倍。




您现在可以在 Huggingface 上获取优化后的 Mistral 模型,点击此处。
训练
与 Phi-2 类似,Mistral 也受益于使用 ORT 进行训练加速。我们使用以下配置训练了 Mistral-7B,以观察使用 ORT 时的性能提升,包括与 LoRA 和 QLoRA 组合时。该模型使用 DeepSpeed Stage-2 在 wikitext 数据集上训练了 5 个 epoch,批量大小为 1。

CodeLlama
Codellama-70B 是一个基于 Llama-2 平台开发的专注于编程的模型。该模型可以生成代码,并用自然语言讨论代码。由于 CodeLlama-70B 是一个经过微调的 Llama 模型,现有的优化可以直接应用。我们将一个 4 位量化的 ONNX 模型与 PyTorch Eager 和 Llama.cpp 进行了比较。对于提示吞吐量,ONNX Runtime 在所有批量大小下都比 PyTorch Eager 快至少 1.4 倍。ONNX Runtime 生成令牌的平均速度比任何批量大小的 PyTorch Eager 高 3.4 倍,比批量大小为 1 的 Llama.cpp 高 1.5 倍。


SD-Turbo 和 SDXL-Turbo
ONNX Runtime 在与 SD Turbo 和 SDXL Turbo 一起使用时,提供了推理性能优势,并且使得这些模型可以通过 Python 以外的语言(如 C# 和 Java)进行访问。在所有评估的(批量大小,步数)组合下,ONNX Runtime 的吞吐量均高于 PyTorch,其中 SDXL Turbo 模型吞吐量提升 高达 229%,SD Turbo 模型提升 120%。ONNX Runtime CUDA 特别擅长处理动态形状,但在静态形状方面也比 PyTorch 具有显著优势。

要了解更多关于使用 ONNX Runtime 加速 SD-Turbo 和 SDXL-Turbo 推理的信息,请查看我们最近与 Hugging Face 发布的博客,点击此处。
Llama-2
我们在此处发布了一篇独立的博客,介绍使用 ORT 对 Llama-2 推理进行的改进。点击此处。此外,Llama-2-7B 和 Llama-2-13B 在使用 ORT 进行训练时显示出良好的提升,尤其是与 LoRA 和 QLoRA 结合时。这些脚本 可以作为使用 Optimum 和 ORT 微调 Llama-2 的示例。下面的数字是使用 DeepSpeed Stage-2 在 wikitext 数据集上训练了 5 个 epoch、批量大小为 1 的 Llama-2 模型在使用 ORT 进行训练时的结果。

Orca-2
推理
Orca-2 是一个仅用于研究的系统,可在用户提供数据推理、理解文本、解决数学问题和文本摘要等任务中提供一次性答案。Orca-2 有两个版本(70 亿和 130 亿参数);它们都是通过在定制的高质量人工数据上微调相应的 Llama-2 基础模型而构建的。ONNX Runtime 通过使用图形融合和内核优化(类似于 Llama-2 的优化)帮助优化 Orca-2 的推理。
ORT 在 int4 下的性能提升
Orca-2-7B int4 量化性能比较显示,提示吞吐量性能提升 高达 26 倍,令牌生成吞吐量比 PyTorch 提升 高达 16.5 倍。与 Llama.cpp 相比,提示吞吐量也提升了 4.75 倍以上,令牌生成吞吐量提升了 3.64 倍以上。




使用 ONNX runtime float16 的 Orca-2 7b 性能比较也显示出在提示和令牌生成吞吐量方面的显著提升。




Orca-2 基准测试在 1 个 A100 GPU (SKU: Standard_ND96amsr_A100_v4) 上完成,软件包:torch 2.2.0, triton 2.2.0, onnxruntime-gpu 1.17.0, deepspeed 0.13.2, llama.cpp - commit 594fca3fefe27b8e95cfb1656eb0e160ad15a793, transformers 4.37.2
训练
Orca-2-7B 也受益于使用 ORT 进行训练加速。我们使用 LoRA 并启用了稀疏性优化,对 Orca-2-7B 模型以 512 的序列长度进行了训练,看到了良好的性能提升。下面的数字是使用 DeepSpeed Stage-2 在 wikitext 数据集上训练了 5 个 epoch、批量大小为 1 的 Orca-2-7B 模型在使用 ORT 进行训练时的结果。

Gemma
Gemma 是一个轻量级、开放的模型系列,基于 Google 用于创建 Gemini 模型的研究和技术构建而成。它提供两种尺寸:2B 和 7B。每种尺寸都发布了预训练和指令微调版本。ONNX Runtime 可用于优化并高效运行任何开源模型。我们针对 Gemma-2B 模型进行了基准测试,使用 float16 的 ONNX Runtime 比 PyTorch Compile 快 高达 7.47 倍,比 Llama.cpp 快 高达 3.47 倍。使用 int4 量化的 ORT 比 PyTorch Eager 快 高达 19.81 倍,比 Llama.cpp 快 2.62 倍。


结论
总之,ONNX Runtime (ORT) 为包括 Phi-2、Mistral、CodeLlama、SDXL-Turbo、Llama-2、Orca-2 和 Gemma 在内的多种模型提供了显著的性能提升。ORT 提供最先进的融合和内核优化,包括支持 float16 和 int4 量化,从而实现更快的推理速度和更低的成本。在提示和令牌生成吞吐量方面,ORT 优于 PyTorch 和 Llama.cpp 等其他框架。ORT 在训练 LLM 方面也显示出显著优势,批量大小越大提升越大,并且与最先进的技术很好地结合,实现高效的大型模型训练。