在边缘设备上运行 PyTorch 模型

由: Natalie KershawPrasanth Pulavarthi

2023 年 10 月 12 日

大多数现代 ML 模型都是使用 PyTorch 开发的。PyTorch 在创建和训练模型方面提供的敏捷性和灵活性使其成为当今最流行的深度学习框架。典型的工作流程是在云端训练这些模型,并也从云端运行它们。然而,许多场景的出现使得在设备本地运行模型更具吸引力 - 或者在某些情况下是必需的。这些场景包括:

  • 避免与云端的网络往返(例如在音频和视频处理中)
  • 将用户数据保留在设备上(为了隐私保护或法规要求)
  • 云资源成本高昂(尤其是在设备功能未得到充分利用时)
  • 应用程序需要在没有互联网连接的情况下运行
Diagram showing the PyTorch logo representing a PyTorch model, fanning out to icons for web, mobile and browser devices running ONNX Runtime

在本文中,我们将揭秘在边缘设备上运行 PyTorch 模型。我们将“边缘”定义为云端之外的任何地方,从大型、资源充足的个人计算机到小型设备(如手机)。过去完成这项任务一直具有挑战性,但模型优化和软件(如 ONNX Runtime)的新进展使其变得更加可行 - 即使对于新的生成式 AI 和大型语言模型(如 Stable Diffusion、Whisper 和 Llama2)也是如此。

在边缘设备上运行 PyTorch 模型的注意事项

在考虑在边缘设备上运行 PyTorch 模型时,需要记住几个因素:

  • 大小:现代模型可能达到几个 GB(因此得名大型语言模型!)。在云端,大小通常不是一个考虑因素,除非它变得太大而无法安装在单个 GPU 上。届时,将有各种众所周知的解决方案用于跨多个 GPU 运行。对于边缘设备,我们需要找到可以适应设备约束的模型。这有时需要与质量进行权衡。大多数现代模型都有几种尺寸(10 亿参数、130 亿参数、700 亿参数等),因此您可以选择适合您设备的变体。通常应用量化等技术来减少表示参数的位数,从而进一步减小模型大小。应用程序的大小也受到应用商店的限制,因此在边缘设备上引入数 GB 的库是行不通的。
  • 应用程序集成的 API:在云端,模型通常打包为 Docker 容器,这些容器公开一个由应用程序或服务调用的端点。在边缘设备上,Docker 容器可能占用过多资源,甚至可能不受支持。通过使用优化的引擎(如 ONNX Runtime),可以消除对 Python 和 Docker 容器的依赖。ONNX Runtime 还提供多种语言的 API,包括 C、C++、C#、Rust、Java、JavaScript、Objective-C 和 Swift,以便更轻松地与宿主应用程序进行本地集成。
  • 性能:凭借大量的内存、无功耗限制和强大的计算能力,在云端运行未经优化的模型是可行的。在边缘设备上,这些奢侈条件不存在,优化至关重要。例如,ONNX Runtime 优化了内存分配、融合了模型运算符、减少了内核启动时间、最大限度地减少了处理单元之间的张量传输,并应用了优化的矩阵数学算法。它还能够利用设备特定的编译器和引擎,为您的应用程序提供通用接口,同时利用每种设备上的最佳方法。
  • 可维护性:在云端,更新模型就像部署新的容器镜像并增加流量一样简单。在边缘设备上,您需要考虑如何分发模型更新。有时这涉及到向应用商店发布更新,有时可能可以在您的应用程序中实施数据更新机制并下载新的模型文件,甚至可能是增量更新。有很多可能的途径,因此我们不会在本文中深入探讨这个主题,但这是您在规划生产时需要记住的一个方面。
  • 混合:您可以选择同时使用云端和设备,而不是在云端与设备之间进行选择。如今,Office 等应用程序在生产中使用了几种混合模式。一种模式是根据网络条件或输入特征动态决定是在设备上还是在云端运行。另一种模式是在设备上运行模型管道的一部分,在云端运行一部分。这对于具有单独的编码器和解码器阶段的现代模型管道尤其有用。使用像 ONNX Runtime 这样在云端和设备上都适用的引擎可以简化开发。我们将在即将发表的文章中更详细地讨论混合场景。
  • 个性化:在许多情况下,PyTorch 模型只是在设备上运行。但是,您也可能遇到需要在设备上个性化模型而无需将数据发送到云端的场景。推荐和内容定向是可以通过根据设备上的活动更新模型来提高其质量的示例场景。在设备上使用 PyTorch 进行微调和训练可能不可行(由于性能和大小方面的考虑),但使用像 ONNX Runtime 这样的引擎可以本地更新和个性化 PyTorch 模型。相同的机制也支持联邦学习,这有助于减少用户数据泄露。

用于边缘设备上 PyTorch 模型的工具

上面我们多次提到了 ONNX Runtime。ONNX Runtime 是一种紧凑的、基于标准的引擎,它与 PyTorch 深度集成。通过使用 PyTorch 的 ONNX API,您的 PyTorch 模型可以在各种边缘设备上使用 ONNX Runtime 运行。

在边缘设备上运行 PyTorch 模型的第一步是将其转换为轻量级格式,这种格式不需要 PyTorch 框架及其数 GB 的依赖项。PyTorch 已经考虑到了这一点,并包含了一个 API 来实现这一点 - torch.onnxONNX 是一种开放标准,它定义了构成模型的运算符。PyTorch ONNX API 将 Pythonic PyTorch 代码转换为功能图,该功能图捕获运行模型而无需 Python 所需的运算符。与机器学习中的所有事物一样,有一些限制需要注意。某些 PyTorch 模型无法表示为单个图 - 在这种情况下,您可能需要输出多个图并在您自己的管道中将它们拼接在一起。

流行的 Hugging Face 库也具有构建在 torch.onnx 功能之上的 API,用于将模型导出为 ONNX 格式。超过 130,000 个模型 受到支持,这使得您关心的模型很可能是其中之一。

在本文中,我们将向您展示几个示例,涉及最先进的 PyTorch 模型(如 Whisper 和 Stable Diffusion)在流行的设备(如 Windows 笔记本电脑、手机和 Web 浏览器)上通过各种语言(从 C# 到 JavaScript 再到 Swift)运行。

边缘设备上 PyTorch 模型的示例

Windows 上的 Stable Diffusion

Stable Diffusion 管道由五个 PyTorch 模型组成,这些模型从文本描述构建图像。扩散过程在随机像素上迭代,直到输出图像与描述匹配。

为了在边缘设备上运行,可以将其中四个模型从 HuggingFace 导出为 ONNX 格式。

from optimum.onnxruntime import ORTStableDiffusionPipeline
pipeline = ORTStableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", export=True)
pipeline.save_pretrained("./onnx-stable-diffusion")

您不必导出第五个模型 ClipTokenizer,因为它在 ONNX Runtime 扩展(用于 PyTorch 模型预处理和后处理的库)中可用。

为了将这个模型管道作为 .NET 应用程序运行,我们在 C# 中构建管道代码。如果您的机器上可用,则可以使用 ONNX Runtime 的设备特定硬件加速器在 CPU、GPU 或 NPU 上运行此代码。这通过下面的 `ExecutionProviderTarget` 配置。

static void Main(string[] args)
{
    var prompt = "Two golden retriever puppies playing in the grass.";
    var config = new StableDiffusionConfig
    {
        NumInferenceSteps = 50,
        GuidanceScale = 7.5,
        ExecutionProviderTarget = StableDiffusionConfig.ExecutionProvider.Cpu,
        DeviceId = 0,
        TokenizerOnnxPath = ".\models\tokenizer\model.onnx",
        TextEncoderOnnxPath = ".\models\text_encoder\model.onnx",
        UnetOnnxPath = ".\models\unet\model.onnx",
        VaeDecoderOnnxPath = ".\models\vae_decoder\model.onnx",
        SafetyModelPath = ".\models\safety_checker\model.onnx",
    };

    var image = UNet.Inference(prompt, config);

    if (image == null)
    {
        Console.WriteLine("Unable to create image, please try again.");
    }
}

这是模型管道的输出,运行了 50 次推理迭代

Two golden retriever puppies playing in the grass

您可以使用此教程中显示的详细步骤在 Windows 上构建和运行应用程序。

浏览器中的文本生成

使用 transformers.js 库,在浏览器中本地运行 PyTorch 模型不仅是可能的,而且非常简单。Transformers.js 使用 ONNX Runtime Web 作为其后端。许多模型已经转换为 ONNX 并由 tranformers.js CDN 提供服务,这使得在浏览器中进行推理只需编写几行 HTML 代码即可。

<html>
    <body>
        <h1>Enter starting text …</h1>
    
        <form id="form">
            <input type="text" id="inputText">
            <button type="submit" id="submitButton">Submit</button>
        </form>
    
        <div id="output"></div>
    
        <script type="module">
    
            import { pipeline } from 'https://cdn.jsdelivr.net.cn/npm/@xenova/transformers@2.6.2';
    
            let inputText = document.getElementById('inputText');
            let outputDiv = document.getElementById('output');
            let submitButton = document.getElementById('submitButton');
    
            submitButton.addEventListener('click', async (e) => {
    
                e.preventDefault();
    
                let generator = await pipeline('text-generation', 'Xenova/LaMini-Neo-125M');
    
                let result = await generator(inputText.value,
                    { max_new_tokens: 200,
                        temperature: 2,
                        repetition_penalty: 1.5,
                        no_repeat_ngram_size: 2,
                        num_beams: 2,
                        num_return_sequences: 1,
                        });
    
                outputDiv.innerHTML = result[0].generated_text;
            });
        </script>    
    </body>
</html>

您还可以使用纯 JavaScript 或在 Web 应用程序中使用 React 或 Next.js 嵌入对 transformers 管道的调用,或者编写浏览器扩展。

ONNX Runtime Web 目前使用 Web Assembly 在 CPU 上执行模型。这对于许多模型来说已经足够好了,但如果设备上存在 GPU,则利用 GPU 可以改善用户体验。ONNX Runtime Web 对 WebGPU 的支持即将推出,使您可以在使用相同推理 API 的同时利用 GPU。

Text generation in the browser using transformers.js. The prompt is Two golden retriever puppies are playing in the grass, and the response is playing in the grasslands. They are known for their playful nature and they have a playful face.

在移动设备上使用 Whisper 进行语音识别

来自 OpenAI 的 Whisper 是一个 PyTorch 语音识别模型。Whisper 有多种不同大小的变体 - 最小的 Whisper Tiny 适合在移动设备上运行。Whisper Tiny 模型的所有组件(音频解码器、编码器、解码器和文本序列生成)都可以使用 Olive 框架组合并导出到单个 ONNX 模型。为了将此模型作为移动应用程序的一部分运行,您可以使用 ONNX Runtime Mobile,它支持 Android、iOS、react-native 和 MAUI/Xamarin。

ONNX Runtime Mobile 通过 NNAPI(在 Android 上)、CoreML(在 iOS 上)和 XNNPACK(在 iOS 和 Android 上)支持硬件加速。

下面显示了一个示例 Android 移动应用程序的相关代码片段,该应用程序对音频的短样本执行语音转录

init {
    val env = OrtEnvironment.getEnvironment()
    val sessionOptions = OrtSession.SessionOptions()
    sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath())

    session = env.createSession(modelBytes, sessionOptions)

    val nMels: Long = 80
    val nFrames: Long = 3000

    baseInputs = mapOf(
        "min_length" to createIntTensor(env, intArrayOf(1), tensorShape(1)),
        "max_length" to createIntTensor(env, intArrayOf(200), tensorShape(1)),
        "num_beams" to createIntTensor(env, intArrayOf(1), tensorShape(1)),
        "num_return_sequences" to createIntTensor(env, intArrayOf(1), tensorShape(1)),
        "length_penalty" to createFloatTensor(env, floatArrayOf(1.0f), tensorShape(1)),
        "repetition_penalty" to createFloatTensor(env, floatArrayOf(1.0f), tensorShape(1)),
    )
}

data class Result(val text: String, val inferenceTimeInMs: Long)

fun run(audioTensor: OnnxTensor): Result {
    val inputs = mutableMapOf()
    baseInputs.toMap(inputs)
    inputs["audio_pcm"] = audioTensor
    val startTimeInMs = SystemClock.elapsedRealtime()
    val outputs = session.run(inputs)
    val elapsedTimeInMs = SystemClock.elapsedRealtime() - startTimeInMs
    val recognizedText = outputs.use {
        @Suppress("UNCHECKED_CAST")
        (outputs[0].value as Array>)[0][0]
    }
    return Result(recognizedText, elapsedTimeInMs)
}

您可以录制一个短音频片段进行转录。

Screenshot of an Android app to perform speech recognition using ONNX Runtime, running a PyTorch Whisper model

在移动设备上训练模型以识别您的声音

ONNX Runtime 还可以获取预训练模型并使其适应新数据。它可以在边缘设备上执行此操作 - 特别是在移动设备上,在移动设备上,可以轻松录制您的声音、访问您的照片和其他个性化数据。重要的是,您的数据在训练期间不会离开设备。

例如,您可以训练 PyTorch 模型以仅在您的手机上识别您自己的声音,用于身份验证场景。

PyTorch 模型是从您的开发环境中的 HuggingFace 获得的,并添加了额外的层来执行说话人分类

from transformers import Wav2Vec2ForSequenceClassification, AutoConfig
                    import torch

                    config = AutoConfig.from_pretrained("superb/wav2vec2-base-superb-sid")
                    model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid")

                    model.classifier = torch.nn.Linear(256, 2)

模型和训练所需的其他组件(用于衡量模型质量的损失函数和指示如何在训练期间调整权重的优化器)使用 ONNX Runtime Training 导出

artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=CustomCELoss(),
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="MyVoice/artifacts",
)

这组工件现在可以由移动应用程序加载,此处显示为 iOS Swift 代码。该应用程序要求用户提供他们的声音样本,并使用这些样本训练模型。

func trainStep(inputData: [Data], labels: [Int64]) throws  {

    let inputs = [try getORTValue(dataList: inputData), try getORTValue(labels: labels)]
    try trainingSession.trainStep(withInputValues: inputs)
    
    try trainingSession.optimizerStep()
    
    try trainingSession.lazyResetGrad()
}

模型训练完成后,您可以运行它来验证声音样本是否是您的声音!

A screenshot of an iPhone app to perform speaker verification by recording a number of speech samples of the speaker

您可以阅读完整的说话人验证教程,并从源代码构建和运行应用程序

下一步去哪里?

在本文中,我们展示了为什么要在边缘设备上运行 PyTorch 模型以及需要考虑哪些方面。我们还分享了几个示例,其中包含您可以用来在边缘设备上使用 ONNX Runtime 运行最先进的 PyTorch 模型的代码。我们还展示了 ONNX Runtime 如何为性能和跨平台执行而构建,使其成为在边缘设备上运行 PyTorch 模型的理想方式。祝您使用 ONNX Runtime 在边缘设备上运行 PyTorch 模型玩得开心!

您可能已经注意到,即使 ONNX Runtime 经过优化可以运行 Llama2,但我们没有包含 Llama2 示例。那是因为令人惊叹的 Llama2 模型值得单独写一篇文章,敬请期待!