PyTorch 模型推理
了解 PyTorch 以及如何使用 PyTorch 模型进行推理。
PyTorch 以其易于理解且灵活的 API,以及大量现成的模型(特别是在自然语言处理 (NLP) 领域)和领域特定库,引领着深度学习领域。
越来越多的开发者和应用程序希望使用使用 PyTorch 构建的模型,本文提供了 PyTorch 模型推理的快速导览。有许多不同的方法可以对 PyTorch 模型进行推理;这些方法在下面列出。
本文假设您正在寻找使用 PyTorch 模型进行推理的信息,而不是如何训练 PyTorch 模型。
目录
PyTorch 概述
PyTorch 的核心是 nn.Module
,它是一个表示整个深度学习模型或单个层的类。可以通过组合或扩展模块来构建模型。要编写自己的模块,需要实现一个 forward 函数,该函数根据模型输入和模型的训练权重计算输出。如果您正在编写自己的 PyTorch 模型,那么您也很可能正在训练它。另外,您也可以使用 PyTorch 本身或其他库(例如 HuggingFace)中的预训练模型。
使用 PyTorch 本身编写图像处理模型
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.models import resnet18, ResNet18_Weights
class Predictor(nn.Module):
def __init__(self):
super().__init__()
weights = ResNet18_Weights.DEFAULT
self.resnet18 = resnet18(weights=weights, progress=False).eval()
self.transforms = weights.transforms()
def forward(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
x = self.transforms(x)
y_pred = self.resnet18(x)
return y_pred.argmax(dim=1)
使用 HuggingFace 库创建语言模型时,您可以
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = transformers.BertTokenizer.from_pretrained(model_name)
model = transformers.BertForQuestionAnswering.from_pretrained(model_name)
创建或导入训练好的模型后,如何运行它进行推理?下面我们介绍几种可以在 PyTorch 中进行推理的方法。
推理选项
使用原生 PyTorch 进行推理
如果您对性能或大小不敏感,并且运行环境包含 Python 可执行文件和库,则可以在原生 PyTorch 中运行应用程序。
获得训练好的模型后,您(或您的数据科学团队)可以使用两种方法来保存和加载模型进行推理
-
保存和加载整个模型
# Save the entire model to PATH torch.save(model, PATH) # Load the model from PATH and set eval mode for inference model = torch.load(PATH) model.eval()
-
保存模型参数,重新声明模型,并加载参数
# Save the model parameters torch.save(model.state_dict(), PATH) # Redeclare the model and load the saved parameters model = TheModel(...) model.load_state_dict(torch.load(PATH)) model.eval()
使用哪种方法取决于您的配置。保存和加载整个模型意味着您无需重新声明模型,甚至无需访问模型代码本身。但缺点是保存环境和加载环境必须在可用类、方法和参数方面完全匹配(因为它们是直接序列化和反序列化的)。
只要您可以访问原始模型代码,保存模型的训练参数(状态字典,即 state_dict)比第一种方法更灵活。
有两个主要原因您可能不想使用原生 PyTorch 对模型进行推理。第一个是您必须在包含 Python 运行时以及 PyTorch 库和相关依赖项的环境中运行——这些文件加起来有几个 GB。如果想在移动手机、Web 浏览器或专用硬件等环境中运行,使用原生 PyTorch 进行 PyTorch 推理将无法工作。第二个原因是性能。开箱即用的 PyTorch 模型可能无法提供您的应用程序所需的性能。
使用 TorchScript 进行推理
如果您在无法安装 PyTorch 或其他 Python 库的受限环境中运行,则可以选择使用已转换为 TorchScript 的 PyTorch 模型进行推理。TorchScript 是 Python 的一个子集,允许您创建可序列化的模型,这些模型可以在非 Python 环境中加载和执行。
# Export to TorchScript
script = torch.jit.script(model, example)
# Save scripted model
script.save(PATH)
# Load scripted model
model = torch.jit.load(PATH)
model.eval()
#include <torch/script.h>
...
torch::jit::script::Module module;
try {
// Deserialize the ScriptModule
module = torch::jit::load(PATH);
}
catch (const c10::Error& e) {
...
}
...
虽然您不需要在环境中有 Python 运行时即可使用 TorchScript 方法对 PyTorch 模型进行推理,但确实需要安装 libtorch 二进制文件,这些文件可能对于您的环境来说太大。您的应用程序也可能无法获得所需的性能。
使用 ONNXRuntime 进行推理
当性能和可移植性至关重要时,您可以使用 ONNXRuntime 对 PyTorch 模型进行推理。使用 ONNXRuntime,可以降低延迟和内存消耗,提高吞吐量。还可以使用 ONNXRuntime 提供的语言绑定和库,在云端、边缘设备、Web 或移动设备上运行模型。
第一步是使用 PyTorch ONNX 导出器将您的 PyTorch 模型导出为 ONNX 格式。
# Specify example data
example = ...
# Export model to ONNX format
torch.onnx.export(model, PATH, example)
导出为 ONNX 格式后,您可以选择在 Netron 查看器中查看模型,以了解模型图、输入输出节点名称和形状,以及哪些节点具有可变大小的输入和输出(动态轴)。
然后您可以在您选择的环境中运行 ONNX 模型。ONNXRuntime 引擎由 C++ 实现,并提供 C++、Python、C#、Java、Javascript、Julia 和 Ruby 等语言的 API。ONNXRuntime 可以在 Linux、Mac、Windows、iOS 和 Android 上运行您的模型。例如,以下代码片段展示了一个 C++ 推理应用程序的骨架。
// Allocate ONNXRuntime session
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::Env env;
Ort::Session session{env, ORT_TSTR("model.onnx"), Ort::SessionOptions{nullptr}};
// Allocate model inputs: fill in shape and size
std::array<float, ...> input{};
std::array<int64_t, ...> input_shape{...};
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input.data(), input.size(), input_shape.data(), input_shape.size());
const char* input_names[] = {...};
// Allocate model outputs: fill in shape and size
std::array<float, ...> output{};
std::array<int64_t, ...> output_shape{...};
Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output.data(), output.size(), output_shape.data(), output_shape.size());
const char* output_names[] = {...};
// Run the model
session_.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, &output_tensor, 1);
开箱即用,ONNXRuntime 会对 ONNX 图应用一系列优化,尽可能合并节点并提取常量值(常量折叠)。ONNXRuntime 还通过其执行提供程序 (Execution Provider) 接口集成了多种硬件加速器,包括 CUDA、TensorRT、OpenVINO、CoreML 和 NNAPI,具体取决于您面向的硬件平台。
您可以通过量化 ONNX 模型来进一步提高其性能。
如果应用程序在受限环境(例如移动和边缘设备)中运行,您可以根据应用程序运行的一个或一组模型构建一个精简的运行时。
要开始使用您选择的语言和环境,请参阅ONNX Runtime 入门