借助 ONNX Runtime 加速 LLaMA-2 推理

由: Kunal VaishnaviParinita Rahi

2023年11月14日 (更新于11月22日)

有兴趣更快地运行 Llama2 吗?让我们探讨一下 ONNX Runtime 如何推动您的 Llama2 变体实现更快的推理!

借助 ONNX Runtime 最先进的融合和内核优化,您现在可以体验到 7B、13B 和 70B 模型显著的推理性能提升,最高可达 3.8 倍。本博客详细介绍了性能增强功能,深入探讨了 ONNX Runtime 融合优化、多 GPU 推理支持,并指导您如何利用 ONNX Runtime 的跨平台能力实现跨平台的无缝推理。这是即将发布的系列博客中的第一篇,后续将涵盖 ONNX Runtime 量化更新带来的高效内存使用以及跨平台使用场景的其他方面。

背景:Llama2 和 Microsoft

Llama2 是 Meta 开源的一种最先进的 LLM,其规模从 7B 到 70B 参数不等(7B、13B、70B)。Microsoft 和 Meta 于 2023 年 7 月宣布了他们在 Azure 和 Windows 上的人工智能合作。作为该公告的一部分,Llama2 被添加到 Azure AI 模型目录中,该目录是基础模型的中心,使开发人员和机器学习 (ML) 专业人员能够轻松发现、评估、自定义和大规模部署预构建的大型 AI 模型。

ONNX Runtime 允许用户通过改进的优化轻松地将这种生成式 AI 模型的功能集成到您的应用程序和服务中,从而提高推理速度并降低成本。

借助新的 ONNX Runtime 优化实现更快的推理

作为新的 1.16.2 版本的一部分,ONNX Runtime 现在为 Llama2 提供了多项内置优化,包括图融合和内核优化。与 PyTorch 编译模式下针对 CUDA FP16 提示延迟的 Hugging Face (HF) Llama2 变体相比,推理速度提升如下所述。下文所示的端到端吞吐量或实际运行时间吞吐量定义为 批量大小 * (提示长度 + 令牌生成长度) / 实际运行时间延迟,其中实际运行时间延迟 = 端到端运行产生的延迟,令牌生成长度 = 256 个生成的令牌。与 PyTorch 编译模式相比,端到端吞吐量对于 13B 模型提高了 2.4 倍,对于 7B 模型提高了 1.8 倍。对于更高的批量大小、序列长度对,例如 (16, 2048),PyTorch eager 会超时,而 ORT 则显示出比编译模式更好的性能。

E2E Throughput Comparisons - Llama-2-7b
E2E Throughput Comparisons - Llama-2-13b
图 1:端到端吞吐量比较

延迟和吞吐量

下面的图表显示了 ONNX Runtime 和 PyTorch 版本的 Llama2 7B 模型在 CUDA FP16 上的延迟比较。这里的延迟定义为模型完成一次前向传播以生成 logits 并同步输出所需的时间。

Prompt Latency Comparisons - Llama-2-7b
Prompt Latency Comparisons - Llama-2-13b
图 2:提示延迟比较

下面的令牌生成吞吐量是生成的前 256 个令牌的平均吞吐量。与 PyTorch 编译模式相比,我们看到令牌生成吞吐量最高可达 ~1.3 倍(7B)和 ~1.5 倍(13B)的提升。

Tokens Generated Throughput Comparisons - Llama-2-7b
Tokens Generated Throughput Comparisons - Llama-2-13b
图 3:生成的令牌吞吐量比较

这些指标的更多详细信息可以在此处找到。

带有多 GPU 推理的 ONNX Runtime

ONNX Runtime 支持多 GPU 推理,以支持大型模型的服务。即使在 FP16 精度下,LLaMA-2 70B 模型也需要 140GB 显存。加载模型需要多个 GPU 进行推理,即使使用强大的 NVIDIA A100 80GB GPU 也是如此。

ONNX Runtime 对 70B 模型应用了 Megatron-LM 张量并行性,将原始模型权重分割到不同的 GPU 上。对 70B 模型进行的 Megatron 分片将 FP16 精度的 PyTorch 模型分割成 4 个分区,将每个分区转换为 ONNX 格式,然后对转换后的 ONNX 模型应用新的 ONNX Runtime 图融合。通过这些优化,70B 模型在批量大小为 1 时,令牌生成的吞吐量约为每秒 30 个令牌,并且对于较小的序列长度,端到端吞吐量从 30 tps 开始。您可以在此处找到更多示例脚本。

70B Llama2 Model Throughput
图 4:70B Llama2 模型吞吐量

ONNX Runtime 优化

LLaMA-2 Optimization Diagram
图 5:LLaMA-2 优化图

ONNX Runtime 用于优化(例如图融合)的技术适用于最先进的模型。随着这些模型变得越来越复杂,用于应用图融合的技术也会进行调整以适应额外的复杂性。例如,ONNX Runtime 现在支持自动化模式匹配,而不是手动匹配图中的融合模式。不再需要手动检测大型子图并匹配它们形成的许多路径,而是可以通过将大型模块导出为函数,然后根据函数的规范进行模式匹配来识别融合机会。

Example of Rotary Embedding Function
图 6:旋转嵌入函数示例

举一个具体的例子,图 6 是组成旋转嵌入计算的节点示例。由于需要验证的路径数量众多,对此子图进行模式匹配非常麻烦。通过将其导出为函数,图的父视图将只显示输入和输出,并将所有这些节点表示为单个算子。

Example of Rotary Embedding Function in Parent Graph
图 7:父图中的旋转嵌入函数示例

这种方法使得维护和支持未来版本的旋转嵌入计算变得更加容易,因为模式匹配仅取决于算子的输入和输出,而不是其内部语义表示。它还允许在类似模型(如 GPT-NeoX、Falcon、Mistral、Zephyr 等)中旋转嵌入的其他现有实现进行模式匹配和融合,而无需进行或只需少量更改。

ONNX Runtime 还增加了对 GroupQueryAttention (GQA) 算子的支持,该算子利用新的 Flash Attention V2 算法及其优化的内核来高效地计算注意力。GQA 算子支持过去键/值缓存(past KV cache)和当前键/值缓存(present KV cache)之间的 past-present 缓冲区共享。通过将当前 KV 缓存绑定到过去 KV 缓存,无需为两个缓存分配独立的设备内存。相反,过去 KV 缓存可以预先分配足够的设备内存,以便在推理过程中无需请求新的设备内存。这减少了在计算密集型工作负载期间 KV 缓存变大时的内存使用,并通过消除设备内存分配请求来降低延迟。past-present 缓冲区共享可以在不更改 ONNX 模型的情况下启用或禁用,从而为最终用户提供了更大的灵活性来决定哪种方法最适合他们。

除了这些融合和内核优化之外,ONNX Runtime 还减少了模型的内存使用。除了量化改进(这将在未来的文章中介绍)之外,ONNX Runtime 还将每个旋转嵌入中使用的余弦和正弦缓存的大小压缩了 50%。ONNX Runtime 中运行旋转嵌入计算的计算内核可以识别这种格式,并使用它们的并行化实现以更少的内存使用更有效地计算旋转嵌入。旋转嵌入计算内核还支持交错和非交错格式,以便分别支持 Microsoft 版本的 LLaMA-2 和 Hugging Face 版本的 LLaMA-2,同时共享相同的计算。

这些优化适用于Hugging Face 版本(以 -hf 结尾的模型)和 Microsoft 版本。您可以从Microsoft 的 LLaMA-2 ONNX 仓库下载优化的 HF 版本。敬请期待即将推出的更新的 Microsoft 版本!

使用 Olive 优化您自己的模型

Olive 是一款硬件感知模型优化工具,它集成了模型压缩、优化和编译等先进技术。我们已通过 Olive 提供 ONNX Runtime 优化,因此您可以使用简单的体验为给定硬件简化整个优化过程。

此处提供了一个使用 Olive 优化 Llama2 的示例,该示例利用了本博客中重点介绍的 ONNX Runtime 优化。不同的优化流程可以满足各种需求。例如,您可以根据您的精度容忍度,灵活选择用于 CPU 和 GPU 推理量化的不同数据类型。此外,您可以使用 Olive-QLoRa 在客户端 GPU 上微调您自己的 Llama2 模型,并使用 ONNX Runtime 优化进行推理。

使用示例

此处是一个示例 notebook,它展示了如何在您的应用程序中使用上述 ONNX Runtime 优化的端到端示例。

结论

本博客中讨论的进步通过 ONNX Runtime 提供了更快的 Llama2 推理,为 AI 应用和研究提供了令人兴奋的可能性。随着性能和效率的提高,创新的视野广阔,我们热切期待由活跃的开发社区使用 Llama2 和 ONNX Runtime 构建的新应用。敬请关注更多更新!