使用 ONNX Runtime 加速 Phi-2、CodeLlama、Gemma 和其他 Gen AI 模型
作者:
Parinita Rahi, Sunghoon Choi, Yufeng Li, Kshama Pawar, Ashwini Khade, Ye Wang2024 年 2 月 26 日
在速度和效率至上的快节奏环境中,ONNX Runtime (ORT) 使用户能够轻松地将生成式 AI 模型的强大功能集成到他们的应用程序和服务中,并通过改进的优化来提高推理速度并有效降低成本。这些优化包括最先进的融合和内核优化,以帮助提高模型性能。ONNX Runtime 1.17 最新版本改进了多种 Gen AI 模型的推理性能,包括 Phi-2、Mistral、CodeLlama、Orca-2 等。ONNX Runtime 是小型语言模型 (SLM) 从训练到推理的完整解决方案,与其他框架相比,显示出显著的加速。凭借对 float32、float16 和 int4 的支持,ONNX Runtime 的推理增强功能提供了最大的灵活性和性能。
在本博客中,我们将介绍最新的 GenAI 模型(如 Phi-2、Mistral、CodeLlama、SD-Turbo、SDXL-Turbo、Llama2 和 Orca-2)在训练和推理方面的显著优化加速。对于这些模型架构,与 PyTorch 和 Llama.cpp 等其他框架相比,ONNX Runtime 在各种批大小和提示长度下显著提高了性能。现在也可以使用 Olive 来实现使用 ONNX Runtime 的这些优化。
快速链接
Phi-2
Phi-2 是 Microsoft 开发的 27 亿参数 Transformer 模型。它是一个 SLM,具有出色的推理和语言理解能力。凭借其小巧的尺寸,Phi-2 为研究人员提供了一个绝佳的平台,他们可以探索各个方面,例如机制可解释性、安全改进以及针对不同任务的微调实验。
ONNX Runtime 1.17 引入了支持 Phi-2 模型的内核更改,包括针对 Phi-2 的 Attention、Multi-Head Attention、Grouped-Query Attention 和 RotaryEmbedding 的优化。具体来说,已添加对以下内容的支持:
- Multi-Head Attention CPU 内核中的因果掩码
- Attention 和 Rotary Embedding 内核中的 rotary_embedding_dim
- Grouped-Query Attention 内核中的 bfloat16
支持基于 TorchDynamo 的 Phi-2 ONNX 导出,并且优化脚本构建在其之上。
对于 Phi-2 推理,对于所有提示长度,使用 float16 和 int4 量化的 ORT 比使用 float32、PyTorch 和 Llama.cpp 的 ORT 性能更好。
推理
ORT 使用 float16 的增益
优化的 CUDA 性能,用于提示吞吐量(即模型根据输入提示处理和生成响应的速率)比 PyTorch Compile 快高达 7.39 倍。我们还观察到,与 Llama.cpp 相比,ONNX Runtime 在更大的批大小和提示长度下明显更快。例如,对于批大小 = 16,提示长度 = 2048,速度快高达 13.08 倍。
令牌生成吞吐量是生成的前 256 个令牌的平均吞吐量。使用 float16 的 ONNX Runtime 平均比 torch.compile 快6.6 倍,最高可达 18.55 倍。它的性能也比 Llama.cpp 快高达 1.64 倍。


ORT 使用 int4 的增益
ORT 提供对 int4 量化的支持。使用 int4 量化的 ORT 与 PyTorch 相比,性能可以提高高达 20.48 倍。平均而言,它比 Llama.cpp 好 3.9 倍,对于大序列长度,速度快高达 13.42 倍。由于 GemV 的特殊内核,使用 int4 量化的 ONNX Runtime 通常在批大小为 1 时性能最佳。


- Phi-2 基准测试在 1 个 A100 GPU 上完成(SKU:Standard_ND96amsr_A100_v4)。软件包:torch:2.3.0. dev20231221+cu121;pytorch-triton:2.2.0+e28a256d71;ort-nightly-gpu:1.17.0.dev20240118001;deepspeed:0.12
- 批次是一组不同长度的输入句子;提示长度是指输入文本的大小或长度。
这是一个 使用 Olive 进行 Phi-2 优化的示例,它利用了本博客中强调的 ONNX Runtime 优化,使用了易于使用的硬件感知模型优化工具 Olive。
训练
除了推理之外,ONNX Runtime 还为 Phi-2 和其他 LLM 提供训练加速。ORT 训练是 PyTorch 生态系统的一部分,可通过 torch-ort python 包作为 Azure Container for PyTorch (ACPT) 的一部分提供。它提供灵活且可扩展的硬件支持,相同的模型和 API 可与 NVIDIA 和 AMD GPU 一起使用。ORT 通过优化的内核和内存优化来加速训练,这在减少大型模型训练的端到端训练时间方面显示出显著的收益。这涉及到在模型中更改几行代码以使用 ORTModule API 对其进行包装。它还可以与 DeepSpeed 和 Megatron 等流行的加速库组合使用,以实现更快、更高效的训练。
Open AI 的 Triton 是一种特定领域的语言和编译器,用于编写高效的自定义深度学习原语。ORT 支持 Open AI Triton 集成 (ORT+Triton),其中所有逐元素运算符都转换为 Triton ops,并且 ORT 在 Triton 中创建自定义融合内核。
ORT 还执行稀疏性优化,以评估输入数据稀疏性并执行利用此稀疏性的图优化。这降低了计算 FLOP 要求并提高了性能。
基于低秩适配器 (LoRA) 的微调通过仅训练少量附加参数(适配器)同时冻结原始模型的权重来提高训练效率。这些适配器使模型适应特定任务。量化和 LoRA (QLoRA) 将量化与 LoRA 相结合,其中权重使用更少的位来表示,同时保持模型的性能和质量。ONNX Runtime 训练与 LoRA 和 QLoRA 结合使用,以提高 LLM 的内存效率和训练时间加速。LoRA 和 QLoRA 技术使 LLM 等超大型模型能够适应 GPU 内存,从而高效地完成训练。
使用 ORT 训练的 Phi-2 模型显示出优于 PyTorch Eager 模式和 torch.compile 的性能增益。Phi-2 是使用合成数据集和 Web 数据集的混合进行训练的。我们衡量了针对 ORT 和 ORT+Triton 模式的增益,并且增益随着批大小的增加而增加。该模型使用 DeepSpeed Stage-2 训练了 5 个 epoch,在 wikitext 数据集上增加了批大小。下图中总结了 V100 和 A100 的增益。
训练基准测试在 8 个 V100 上运行,并以迭代/秒为单位衡量吞吐量(越高越好)

以下训练基准测试在 2 个 A100 上运行,并以迭代/秒为单位衡量吞吐量(越高越好)

Mistral
推理
Mistral7B 是一个预训练的生成文本 LLM,具有 70 亿个参数。ONNX Runtime 显著提高了 Mistral 在 float16 和 int4 模型下的推理性能。使用 float16,ONNX Runtime 比 Llama.cpp 快高达 9.46 倍。对于批大小为 1 的情况,令牌生成吞吐量通过 int4 量化得到了显著提高,比 PyTorch Eager 快高达 18.25 倍。




您现在可以在 Huggingface 上访问优化的 Mistral 模型,点击此处。
训练
与 Phi-2 类似,Mistral 也受益于使用 ORT 进行的训练加速。我们使用以下配置训练了 Mistral-7B,以查看 ORT 的增益,包括与 LoRA 和 QLoRA 组合使用时的增益。该模型使用 DeepSpeed Stage-2 训练了 5 个 epoch,在 wikitext 数据集上批大小为 1。

CodeLlama
Codellama-70B 是一个基于 Llama-2 平台开发的面向编程的模型。此模型可以生成代码并以自然语言生成围绕代码的讨论。由于 CodeLlama-70B 是一个微调的 Llama 模型,因此可以直接应用现有的优化。我们将 4 位量化的 ONNX 模型与 PyTorch Eager 和 Llama.cpp 进行了比较。对于提示吞吐量,对于所有批大小,ONNX Runtime 比 PyTorch Eager 快至少 1.4 倍。对于任何批大小,ONNX Runtime 生成令牌的平均速度比 PyTorch Eager 高 3.4 倍,对于批大小为 1,比 Llama.cpp 高 1.5 倍。


SD-Turbo 和 SDXL-Turbo
ONNX Runtime 在与 SD Turbo 和 SDXL Turbo 一起使用时,可提供推理性能优势,并且还使模型可以在 Python 以外的语言(如 C# 和 Java)中访问。对于评估的所有(批大小,步数)组合,ONNX Runtime 实现了比 PyTorch 更高的吞吐量,SDXL Turbo 模型的吞吐量提高了高达 229%,SD Turbo 模型的吞吐量提高了 120%。ONNX Runtime CUDA 尤其擅长处理动态形状,但对于静态形状,它也显示出比 PyTorch 显著的优势。

要详细了解如何使用 ONNX Runtime 加速 SD-Turbo 和 SDXL-Turbo 推理,请查看我们最近与 Hugging Face 合作撰写的博客。
Llama-2
我们发布了一篇单独的博客,介绍了 ORT 在 Llama-2 推理方面的改进,点击此处查看。此外,Llama-2-7B 和 Llama-2-13B 在 ORT 训练方面表现出良好的增益,尤其是在与 LoRA 和 QLoRA 结合使用时。这些脚本可以用作示例,以使用 Optimum 和 ORT 微调 Llama-2。以下数字是 Llama-2 模型使用 ORT 和 DeepSpeed Stage-2 训练 5 个 epoch,在 wikitext 数据集上批大小为 1 的结果。

Orca-2
推理
Orca-2 是一个仅供研究的系统,它在诸如推理用户提供的数据、理解文本、解决数学问题和总结文本等任务中给出一次性答案。Orca-2 有两个版本(70 亿和 130 亿参数);它们都是通过在定制的高质量人工数据上微调相应的 Llama-2 基础模型而制成的。ONNX Runtime 有助于优化 Orca-2 推理,以使用图形融合和内核优化,例如 Llama-2 的优化。
ORT 使用 int4 的增益
Orca-2-7B int4 量化性能比较表明,提示吞吐量性能提高了高达 26 倍,令牌生成吞吐量提高了高达 16.5 倍(与 PyTorch 相比)。与 Llama.cpp 相比,提示吞吐量提高了 4.75 倍以上,令牌生成吞吐量提高了 3.64 倍。




Orca-2 7b 与 ONNX Runtime float16 的性能比较也显示出提示和令牌生成吞吐量的显著增益。




Orca-2 基准测试在 1 个 A100 GPU 上完成,SKU:Standard_ND96amsr_A100_v4,软件包 torch 2.2.0、triton 2.2.0、onnxruntime-gpu 1.17.0、deepspeed 0.13.2、llama.cpp - commit 594fca3fefe27b8e95cfb1656eb0e160ad15a793、transformers 4.37.2
训练
Orca-2-7B 也受益于使用 ORT 进行的训练加速。我们针对 512 的序列长度训练了 Orca-2-7B 模型,启用了 LoRA 和稀疏性优化,并看到了良好的性能增益。以下数字是 Orca-2-7B 模型使用 ORT 和 DeepSpeed Stage-2 训练 5 个 epoch,在 wikitext 数据集上批大小为 1 的结果。

Gemma
Gemma 是一个轻量级、开放模型的系列,它基于 Google 用于创建 Gemini 模型的研究和技术构建。它有两种尺寸:2B 和 7B。每种尺寸都发布了预训练和指令调优变体。ONNX Runtime 可用于优化和高效运行任何开源模型。我们针对 Gemma-2B 模型进行了基准测试,使用 float16 的 ONNX Runtime 比 PyTorch Compile 快高达 7.47 倍,比 Llama.cpp 快高达 3.47 倍。使用 int4 量化的 ORT 比 PyTorch Eager 快高达 19.81 倍,比 Llama.cpp 快 2.62 倍。


结论
总之,ONNX Runtime (ORT) 为多种模型提供了显著的性能改进,包括 Phi-2、Mistral、CodeLlama、SDXL-Turbo、Llama-2、Orca-2 和 Gemma。ORT 提供最先进的融合和内核优化,包括对 float16 和 int4 量化的支持,从而实现更快的推理速度和更低的成本。在提示和令牌生成吞吐量方面,ORT 优于 PyTorch 和 Llama.cpp 等其他框架。ORT 在 LLM 训练方面也显示出显著优势,批大小越大,增益越大,并且与最先进的技术很好地结合,从而实现高效的大型模型训练。