在边缘运行 PyTorch 模型
作者:Natalie Kershaw 和 Prasanth Pulavarthi
2023年10月12日
大多数现代机器学习模型都使用 PyTorch 开发。PyTorch 在创建和训练模型方面提供的敏捷性和灵活性使其成为当今最流行的深度学习框架。典型的工作流程是在云端训练这些模型,并也在云端运行它们。然而,许多场景正在出现,使得在设备本地运行模型更具吸引力,甚至在某些情况下成为必需。这些场景包括:
- 避免与云端的网络往返(例如在音频和视频处理中)
- 将用户数据保留在设备上(用于隐私保护或满足法规要求)
- 云资源成本高昂(尤其是在设备能力未得到充分利用时)
- 应用程序需要在没有互联网连接的情况下运行

在本文中,我们将揭秘如何在边缘运行 PyTorch 模型。我们将“边缘”定义为云端之外的任何地方,从资源充足的大型个人电脑到手机等小型设备。过去,完成这项任务一直充满挑战,但模型优化和 ONNX Runtime 等软件的新进展使其变得更加可行——即使是对于 Stable Diffusion、Whisper 和 Llama2 等新型生成式 AI 和大型语言模型。
在边缘运行 PyTorch 模型的考量因素
在考虑在边缘运行 PyTorch 模型时,有几个因素需要牢记。
- 大小:现代模型可以达到几千兆字节(因此得名大型语言模型!)。在云端,模型大小通常不是一个问题,除非它变得太大而无法容纳在单个 GPU 上。那时,有各种成熟的解决方案可以在多个 GPU 上运行。对于边缘设备,我们需要找到能够适应设备限制的模型。这有时需要权衡模型质量。大多数现代模型都有多种尺寸(例如10亿参数、130亿参数、700亿参数等),因此您可以选择适合您设备的变体。通常会应用量化等技术来减少表示参数的位数,从而进一步减小模型大小。应用程序的大小也受到应用商店的限制,因此引入几千兆字节的库在边缘设备上是不可行的。
- 应用程序集成 API:在云端,模型通常被打包成 Docker 容器,这些容器公开一个供应用程序或服务调用的端点。在边缘设备上,Docker 容器可能会占用过多资源,甚至可能不受支持。通过使用像 ONNX Runtime 这样的优化引擎,可以消除对 Python 和 Docker 容器的依赖。ONNX Runtime 还提供多种语言的 API,包括 C、C++、C#、Rust、Java、JavaScript、Objective-C 和 Swift,以便更轻松地与宿主应用程序进行原生集成。
- 性能:在云端,凭借大量内存、无功耗限制和强大的计算能力,运行未优化的模型是可能的。在边缘设备上,这些“奢侈品”不存在,因此优化至关重要。例如,ONNX Runtime 优化内存分配、融合模型操作符、减少内核启动时间、最小化处理单元之间的张量传输,并应用经过调优的矩阵数学算法。它还能够利用设备特定的编译器和引擎,为您的应用程序提供通用接口,同时在每个设备上发挥最佳性能。
- 可维护性:在云端,更新模型就像部署新的容器镜像和增加流量一样简单。在边缘端,您需要考虑如何分发模型更新。有时这涉及向应用商店发布更新,有时可能需要在您的应用程序中实现数据更新机制并下载新的模型文件,甚至只下载增量更新。有许多可能的途径,因此本文不会深入探讨此主题,但这是您在规划生产时需要牢记的一个方面。
- 混合模式:您可以选择同时利用云端和设备端,而非仅选择其一。如今,Office 等应用程序在生产中使用了多种混合模式。一种模式是根据网络条件或输入特性动态决定在设备上还是在云端运行。另一种模式是在设备上运行模型管道的一部分,在云端运行另一部分。这对于具有独立编码器和解码器阶段的现代模型管道尤其有用。使用像 ONNX Runtime 这样同时支持云端和设备端的引擎可以简化开发。我们将在后续文章中更详细地讨论混合场景。
- 个性化:在许多情况下,PyTorch 模型只是简单地在设备上运行。然而,您也可能遇到需要在设备上个性化模型而无需将数据发送到云端的场景。推荐和内容定向就是可以通过根据设备上的活动更新模型来提高质量的示例场景。在设备上使用 PyTorch 进行微调和训练可能不可行(由于性能和大小问题),但使用像 ONNX Runtime 这样的引擎可以允许 PyTorch 模型在本地进行更新和个性化。相同的机制还支持联邦学习,这有助于减轻用户数据暴露的风险。
在边缘运行 PyTorch 模型的工具
我们前面多次提到 ONNX Runtime。ONNX Runtime 是一个紧凑、基于标准的引擎,与 PyTorch 深度集成。通过使用 PyTorch 的 ONNX API,您的 PyTorch 模型可以在各种边缘设备上使用 ONNX Runtime 运行。
在边缘运行 PyTorch 模型的第一步是将其转换为轻量级格式,使其不再需要 PyTorch 框架及其数千兆字节的依赖。PyTorch 已经考虑到了这一点,并提供了一个专门实现此功能的 API——torch.onnx。ONNX 是一个开放标准,定义了构成模型的运算符。PyTorch ONNX API 将 Python 风格的 PyTorch 代码转换为功能图,该图捕获了无需 Python 即可运行模型所需的运算符。像机器学习中的所有事物一样,也存在一些需要注意的限制。某些 PyTorch 模型无法表示为单个图——在这种情况下,您可能需要输出多个图并在自己的管道中将它们拼接起来。
流行的 Hugging Face 库也提供了基于 torch.onnx 功能的 API,用于将模型导出为 ONNX 格式。超过 130,000 个模型受支持,这意味着您关心的模型很可能就在其中。
在本文中,我们将通过多种语言(从 C# 到 JavaScript 再到 Swift),向您展示在流行设备(如 Windows 笔记本电脑、手机和网页浏览器)上运行最先进的 PyTorch 模型(如 Whisper 和 Stable Diffusion)的几个示例。
在边缘运行 PyTorch 模型的示例
在 Windows 上运行 Stable Diffusion
Stable Diffusion 管道由五个 PyTorch 模型组成,这些模型根据文本描述生成图像。扩散过程会迭代随机像素,直到输出图像与描述匹配。
为了在边缘运行,其中四个模型可以从 HuggingFace 导出为 ONNX 格式。
from optimum.onnxruntime import ORTStableDiffusionPipeline
pipeline = ORTStableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", export=True)
pipeline.save_pretrained("./onnx-stable-diffusion")
您无需导出第五个模型 ClipTokenizer,因为它在 ONNX Runtime 扩展中可用,该库用于 PyTorch 模型的预处理和后处理。
为了将这一系列模型作为 .NET 应用程序运行,我们使用 C# 构建了管道代码。如果您的机器上有可用的 CPU、GPU 或 NPU,此代码可以使用 ONNX Runtime 的设备专用硬件加速器在其上运行。这通过下面的 ExecutionProviderTarget
进行配置。
static void Main(string[] args)
{
var prompt = "Two golden retriever puppies playing in the grass.";
var config = new StableDiffusionConfig
{
NumInferenceSteps = 50,
GuidanceScale = 7.5,
ExecutionProviderTarget = StableDiffusionConfig.ExecutionProvider.Cpu,
DeviceId = 0,
TokenizerOnnxPath = ".\models\tokenizer\model.onnx",
TextEncoderOnnxPath = ".\models\text_encoder\model.onnx",
UnetOnnxPath = ".\models\unet\model.onnx",
VaeDecoderOnnxPath = ".\models\vae_decoder\model.onnx",
SafetyModelPath = ".\models\safety_checker\model.onnx",
};
var image = UNet.Inference(prompt, config);
if (image == null)
{
Console.WriteLine("Unable to create image, please try again.");
}
}
这是模型管道的输出,以 50 次推理迭代运行

您可以按照此 教程中显示的详细步骤,在 Windows 上构建并运行该应用程序。
浏览器中的文本生成
使用 transformers.js 库,在浏览器中本地运行 PyTorch 模型不仅可能,而且非常简单。Transformers.js 使用 ONNX Runtime Web 作为其后端。许多模型已经转换为 ONNX 格式并通过 transformers.js CDN 提供服务,使得在浏览器中进行推理只需编写几行 HTML 代码即可。
<html>
<body>
<h1>Enter starting text …</h1>
<form id="form">
<input type="text" id="inputText">
<button type="submit" id="submitButton">Submit</button>
</form>
<div id="output"></div>
<script type="module">
import { pipeline } from 'https://cdn.jsdelivr.net.cn/npm/@xenova/transformers@2.6.2';
let inputText = document.getElementById('inputText');
let outputDiv = document.getElementById('output');
let submitButton = document.getElementById('submitButton');
submitButton.addEventListener('click', async (e) => {
e.preventDefault();
let generator = await pipeline('text-generation', 'Xenova/LaMini-Neo-125M');
let result = await generator(inputText.value,
{ max_new_tokens: 200,
temperature: 2,
repetition_penalty: 1.5,
no_repeat_ngram_size: 2,
num_beams: 2,
num_return_sequences: 1,
});
outputDiv.innerHTML = result[0].generated_text;
});
</script>
</body>
</html>
您还可以使用纯 JavaScript 或在 React 或 Next.js 等 Web 应用程序中嵌入对 transformers 管道的调用,或者编写浏览器扩展。
ONNX Runtime Web 目前使用 WebAssembly 在 CPU 上执行模型。这对于许多模型来说已经足够,但如果设备上存在 GPU,利用 GPU 可以改善用户体验。ONNX Runtime Web 对 WebGPU 的支持即将推出,这将使您能够利用 GPU,同时使用相同的推理 API。

在移动设备上使用 Whisper 进行语音识别
OpenAI 的 Whisper 是一个 PyTorch 语音识别模型。Whisper 有多种尺寸变体——最小的 Whisper Tiny 适合在移动设备上运行。使用 Olive 框架可以将 Whisper Tiny 模型的所有组件(音频解码器、编码器、解码器和文本序列生成)组合并导出为单个 ONNX 模型。要将此模型作为移动应用程序的一部分运行,您可以使用 ONNX Runtime Mobile,它支持 Android、iOS、React Native 和 MAUI/Xamarin。
ONNX Runtime Mobile 通过 NNAPI(在 Android 上)、CoreML(在 iOS 上)和 XNNPACK(在 iOS 和 Android 上)支持硬件加速。
下面显示了一个在短音频样本上执行语音转录的 Android 移动应用示例的相关代码片段。
init {
val env = OrtEnvironment.getEnvironment()
val sessionOptions = OrtSession.SessionOptions()
sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath())
session = env.createSession(modelBytes, sessionOptions)
val nMels: Long = 80
val nFrames: Long = 3000
baseInputs = mapOf(
"min_length" to createIntTensor(env, intArrayOf(1), tensorShape(1)),
"max_length" to createIntTensor(env, intArrayOf(200), tensorShape(1)),
"num_beams" to createIntTensor(env, intArrayOf(1), tensorShape(1)),
"num_return_sequences" to createIntTensor(env, intArrayOf(1), tensorShape(1)),
"length_penalty" to createFloatTensor(env, floatArrayOf(1.0f), tensorShape(1)),
"repetition_penalty" to createFloatTensor(env, floatArrayOf(1.0f), tensorShape(1)),
)
}
data class Result(val text: String, val inferenceTimeInMs: Long)
fun run(audioTensor: OnnxTensor): Result {
val inputs = mutableMapOf()
baseInputs.toMap(inputs)
inputs["audio_pcm"] = audioTensor
val startTimeInMs = SystemClock.elapsedRealtime()
val outputs = session.run(inputs)
val elapsedTimeInMs = SystemClock.elapsedRealtime() - startTimeInMs
val recognizedText = outputs.use {
@Suppress("UNCHECKED_CAST")
(outputs[0].value as Array>)[0][0]
}
return Result(recognizedText, elapsedTimeInMs)
}
您可以录制一小段音频片段进行转录。

在移动设备上训练模型以识别您的声音
ONNX Runtime 还可以接受预训练模型并使其适应新数据。它可以在边缘端做到这一点——特别是在移动设备上,在那里很容易录制您的声音、访问您的照片和其他个性化数据。重要的是,您的数据在训练期间不会离开设备。
例如,您可以训练一个 PyTorch 模型,在您的手机上仅识别您自己的声音,用于身份验证场景。
PyTorch 模型是在您的开发环境中从 HuggingFace 获取的,并添加了额外的层来执行说话人分类。
from transformers import Wav2Vec2ForSequenceClassification, AutoConfig
import torch
config = AutoConfig.from_pretrained("superb/wav2vec2-base-superb-sid")
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid")
model.classifier = torch.nn.Linear(256, 2)
训练所需的模型和其他组件(用于衡量模型质量的损失函数和用于指导训练期间权重调整的优化器)均使用 ONNX Runtime Training 导出。
artifacts.generate_artifacts(
onnx_model,
requires_grad=requires_grad,
frozen_params=frozen_params,
loss=CustomCELoss(),
optimizer=artifacts.OptimType.AdamW,
artifact_directory="MyVoice/artifacts",
)
这组工件现在已准备好由移动应用程序加载,这里以 iOS Swift 代码的形式显示。该应用程序会要求用户提供语音样本,模型将使用这些样本进行训练。
func trainStep(inputData: [Data], labels: [Int64]) throws {
let inputs = [try getORTValue(dataList: inputData), try getORTValue(labels: labels)]
try trainingSession.trainStep(withInputValues: inputs)
try trainingSession.optimizerStep()
try trainingSession.lazyResetGrad()
}
模型训练完成后,您可以运行它来验证语音样本是否是您本人!

您可以阅读完整的 说话人验证教程,并从源代码构建并运行该应用程序。
下一步是什么?
在本文中,我们展示了为何要在边缘运行 PyTorch 模型以及需要考虑的方面。我们还分享了几个包含代码的示例,您可以用于使用 ONNX Runtime 在边缘运行最先进的 PyTorch 模型。我们还展示了 ONNX Runtime 如何为性能和跨平台执行而构建,使其成为在边缘运行 PyTorch 模型的理想方式。使用 ONNX Runtime 在边缘运行 PyTorch 模型,尽情享受吧!
您可能已经注意到,尽管 ONNX Runtime 经过优化可以运行 Llama2,但我们并未包含 Llama2 的示例。那是因为出色的 Llama2 模型值得单独撰写一篇文章,敬请期待!