PyTorch 模型推理
了解 PyTorch 以及如何使用 PyTorch 模型执行推理。
PyTorch 以其易于理解和灵活的 API 领跑深度学习领域; 拥有大量现成的模型,尤其是在自然语言处理 (NLP) 领域; 以及其特定领域的库。
越来越多的开发者和应用程序希望使用使用 PyTorch 构建的模型,本文将快速介绍 PyTorch 模型的推理。 有多种不同的方法可以执行 PyTorch 模型的推理; 这些方法在下面列出。
本文假设您正在寻找有关使用 PyTorch 模型执行推理的信息,而不是如何训练 PyTorch 模型的信息。
目录
PyTorch 概述
PyTorch 的核心是 nn.Module
,一个类,代表整个深度学习模型或单个层。 模块可以组合或扩展以构建模型。 要编写自己的模块,您需要实现一个 forward 函数,该函数根据模型输入和模型的训练权重计算输出。 如果您正在编写自己的 PyTorch 模型,那么您可能也在训练它。 或者,您可以使用来自 PyTorch 本身或来自其他库(如 HuggingFace)的预训练模型。
使用 PyTorch 本身编写图像处理模型代码
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.models import resnet18, ResNet18_Weights
class Predictor(nn.Module):
def __init__(self):
super().__init__()
weights = ResNet18_Weights.DEFAULT
self.resnet18 = resnet18(weights=weights, progress=False).eval()
self.transforms = weights.transforms()
def forward(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
x = self.transforms(x)
y_pred = self.resnet18(x)
return y_pred.argmax(dim=1)
要使用 HuggingFace 库创建语言模型,您可以
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = transformers.BertTokenizer.from_pretrained(model_name)
model = transformers.BertForQuestionAnswering.from_pretrained(model_name)
一旦您创建或导入了训练好的模型,您如何运行它来执行推理? 下面我们描述了几种可用于在 PyTorch 中执行推理的方法。
推理选项
使用原生 PyTorch 进行推理
如果您对性能或大小不敏感,并且在运行环境包含 Python 可执行文件和库,则可以在原生 PyTorch 中运行您的应用程序。
一旦您有了训练好的模型,您(或您的数据科学团队)可以使用两种方法来保存和加载模型以进行推理
-
保存和加载整个模型
# Save the entire model to PATH torch.save(model, PATH) # Load the model from PATH and set eval mode for inference model = torch.load(PATH) model.eval()
-
保存模型的参数,重新声明模型,并加载参数
# Save the model parameters torch.save(model.state_dict(), PATH) # Redeclare the model and load the saved parameters model = TheModel(...) model.load_state_dict(torch.load(PATH)) model.eval()
您使用哪种方法取决于您的配置。 保存和加载整个模型意味着您不必重新声明模型,甚至不必访问模型代码本身。 但缺点是,保存环境和加载环境在可用类、方法和参数方面必须匹配(因为这些是直接序列化和反序列化的)。
只要您可以访问原始模型代码,保存模型的训练参数(状态字典或 state_dict)比第一种方法更灵活。
您可能不想使用原生 PyTorch 对模型执行推理的主要原因有两个。 第一个原因是您必须在包含 Python 运行时以及 PyTorch 库和相关依赖项的环境中运行 - 这些加起来有几个 GB 的文件。 如果您想在移动电话、Web 浏览器或专用硬件等环境中运行,使用原生 PyTorch 运行 PyTorch 推理将不起作用。 第二个原因是性能。 开箱即用,PyTorch 模型可能无法提供您的应用程序所需的性能。
使用 TorchScript 进行推理
如果您在更受限制的环境中运行,无法安装 PyTorch 或其他 Python 库,您可以选择使用已转换为 TorchScript 的 PyTorch 模型执行推理。 TorchScript 是 Python 的一个子集,允许您创建可序列化的模型,这些模型可以在非 Python 环境中加载和执行。
# Export to TorchScript
script = torch.jit.script(model, example)
# Save scripted model
script.save(PATH)
# Load scripted model
model = torch.jit.load(PATH)
model.eval()
#include <torch/script.h>
...
torch::jit::script::Module module;
try {
// Deserialize the ScriptModule
module = torch::jit::load(PATH);
}
catch (const c10::Error& e) {
...
}
...
虽然您不需要在环境中使用 Python 运行时即可使用 TorchScript 方法对 PyTorch 模型执行推理,但您确实需要安装 libtorch 二进制文件,这些文件可能对您的环境来说太大。 您也可能无法获得应用程序所需的性能。
使用 ONNXRuntime 进行推理
当性能和可移植性至关重要时,您可以使用 ONNXRuntime 对 PyTorch 模型执行推理。 使用 ONNXRuntime,您可以减少延迟和内存占用,并提高吞吐量。 您还可以使用 ONNXRuntime 提供的语言绑定和库,在云、边缘设备、Web 或移动端上运行模型。
第一步是使用 PyTorch ONNX 导出器将您的 PyTorch 模型导出为 ONNX 格式。
# Specify example data
example = ...
# Export model to ONNX format
torch.onnx.export(model, PATH, example)
导出为 ONNX 格式后,您可以选择在 Netron 查看器中查看模型,以了解模型图以及输入和输出节点名称和形状,以及哪些节点的输入和输出大小可变(动态轴)。
然后,您可以在您选择的环境中运行 ONNX 模型。 ONNXRuntime 引擎用 C++ 实现,并具有 C++、Python、C#、Java、Javascript、Julia 和 Ruby 的 API。 ONNXRuntime 可以在 Linux、Mac、Windows、iOS 和 Android 上运行您的模型。 例如,以下代码片段显示了 C++ 推理应用程序的框架。
// Allocate ONNXRuntime session
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::Env env;
Ort::Session session{env, ORT_TSTR("model.onnx"), Ort::SessionOptions{nullptr}};
// Allocate model inputs: fill in shape and size
std::array<float, ...> input{};
std::array<int64_t, ...> input_shape{...};
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input.data(), input.size(), input_shape.data(), input_shape.size());
const char* input_names[] = {...};
// Allocate model outputs: fill in shape and size
std::array<float, ...> output{};
std::array<int64_t, ...> output_shape{...};
Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output.data(), output.size(), output_shape.data(), output_shape.size());
const char* output_names[] = {...};
// Run the model
session_.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, &output_tensor, 1);
开箱即用,ONNXRuntime 对 ONNX 图应用一系列优化,尽可能合并节点并分解出常量值(常量折叠)。 ONNXRuntime 还通过其执行提供程序接口与许多硬件加速器集成,包括 CUDA、TensorRT、OpenVINO、CoreML 和 NNAPI,具体取决于您要面向的硬件平台。
您可以进一步量化 ONNX 模型,以提高其性能。
如果应用程序在受限环境中运行,例如移动端和边缘设备,您可以基于应用程序运行的模型或模型集构建缩减大小的运行时。
要开始使用您选择的语言和环境,请参阅 ONNX Runtime 入门