推理 PyTorch 模型
了解 PyTorch 以及如何使用 PyTorch 模型进行推理。
PyTorch 以其易于理解和灵活的 API;大量可用的现成模型,特别是在自然语言处理 (NLP) 领域;以及其领域特定的库,引领着深度学习领域。
越来越多的开发者和应用程序希望使用 PyTorch 构建的模型,本文快速介绍了 PyTorch 模型的推理。PyTorch 模型有多种不同的推理方式;下面将列举这些方式。
本文假定您正在寻找有关如何使用 PyTorch 模型进行推理的信息,而不是如何训练 PyTorch 模型。
目录
PyTorch 概述
PyTorch 的核心是 nn.Module
,它是一个代表整个深度学习模型或单个层的类。模块可以组合或扩展以构建模型。要编写自己的模块,您需要实现一个根据模型输入和模型训练权重计算输出的前向函数。如果您正在编写自己的 PyTorch 模型,那么您很可能也在训练它。或者,您可以使用 PyTorch 本身或其他库(例如 HuggingFace)的预训练模型。
使用 PyTorch 本身编写图像处理模型
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.models import resnet18, ResNet18_Weights
class Predictor(nn.Module):
def __init__(self):
super().__init__()
weights = ResNet18_Weights.DEFAULT
self.resnet18 = resnet18(weights=weights, progress=False).eval()
self.transforms = weights.transforms()
def forward(self, x: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
x = self.transforms(x)
y_pred = self.resnet18(x)
return y_pred.argmax(dim=1)
要使用 HuggingFace 库创建语言模型,您可以
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = transformers.BertTokenizer.from_pretrained(model_name)
model = transformers.BertForQuestionAnswering.from_pretrained(model_name)
创建或导入训练好的模型后,如何运行它进行推理?下面我们介绍几种您可以在 PyTorch 中进行推理的方法。
推理选项
使用原生 PyTorch 进行推理
如果您对性能或大小不敏感,并且在包含 Python 可执行文件和库的环境中运行,您可以在原生 PyTorch 中运行您的应用程序。
获得训练好的模型后,您(或您的数据科学团队)可以使用两种方法来保存和加载模型以进行推理:
-
保存和加载整个模型
# Save the entire model to PATH torch.save(model, PATH) # Load the model from PATH and set eval mode for inference model = torch.load(PATH) model.eval()
-
保存模型参数,重新声明模型,然后加载参数
# Save the model parameters torch.save(model.state_dict(), PATH) # Redeclare the model and load the saved parameters model = TheModel(...) model.load_state_dict(torch.load(PATH)) model.eval()
您使用哪种方法取决于您的配置。保存和加载整个模型意味着您无需重新声明模型,甚至无需访问模型代码本身。但缺点是,保存环境和加载环境必须在可用的类、方法和参数方面匹配(因为这些是直接序列化和反序列化的)。
只要您可以访问原始模型代码,保存模型训练好的参数(状态字典或 state_dict)比第一种方法更灵活。
您可能不想使用原生 PyTorch 对模型进行推理的主要原因有两个。首先,您必须在包含 Python 运行时、PyTorch 库和相关依赖项的环境中运行——这些文件加起来有几千兆字节。如果您想在移动电话、Web 浏览器或专用硬件等环境中运行,使用原生 PyTorch 进行 PyTorch 推理将无法工作。第二个原因是性能。开箱即用的 PyTorch 模型可能无法提供您的应用程序所需的性能。
使用 TorchScript 进行推理
如果您在更受限制的环境中运行,无法安装 PyTorch 或其他 Python 库,则可以选择使用已转换为 TorchScript 的 PyTorch 模型进行推理。TorchScript 是 Python 的一个子集,它允许您创建可序列化的模型,这些模型可以在非 Python 环境中加载和执行。
# Export to TorchScript
script = torch.jit.script(model, example)
# Save scripted model
script.save(PATH)
# Load scripted model
model = torch.jit.load(PATH)
model.eval()
#include <torch/script.h>
...
torch::jit::script::Module module;
try {
// Deserialize the ScriptModule
module = torch::jit::load(PATH);
}
catch (const c10::Error& e) {
...
}
...
虽然您不需要在环境中拥有 Python 运行时即可使用 TorchScript 方法对 PyTorch 模型进行推理,但您确实需要安装 libtorch 二进制文件,这些文件可能对于您的环境来说太大了。您也可能无法获得应用程序所需的性能。
使用 ONNXRuntime 进行推理
当性能和可移植性至关重要时,您可以使用 ONNXRuntime 对 PyTorch 模型进行推理。通过 ONNXRuntime,您可以减少延迟和内存占用,并提高吞吐量。您还可以使用 ONNXRuntime 提供的语言绑定和库,在云、边缘、Web 或移动设备上运行模型。
第一步是使用 PyTorch ONNX 导出器将您的 PyTorch 模型导出为 ONNX 格式。
# Specify example data
example = ...
# Export model to ONNX format
torch.onnx.export(model, PATH, example)
导出为 ONNX 格式后,您可以选择在 Netron 查看器中查看模型,以了解模型图、输入和输出节点名称和形状,以及哪些节点具有可变大小的输入和输出(动态轴)。
然后您可以在您选择的环境中运行 ONNX 模型。ONNXRuntime 引擎用 C++ 实现,并提供 C++、Python、C#、Java、Javascript、Julia 和 Ruby 的 API。ONNXRuntime 可以在 Linux、Mac、Windows、iOS 和 Android 上运行您的模型。例如,以下代码片段展示了一个 C++ 推理应用程序的骨架。
// Allocate ONNXRuntime session
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::Env env;
Ort::Session session{env, ORT_TSTR("model.onnx"), Ort::SessionOptions{nullptr}};
// Allocate model inputs: fill in shape and size
std::array<float, ...> input{};
std::array<int64_t, ...> input_shape{...};
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input.data(), input.size(), input_shape.data(), input_shape.size());
const char* input_names[] = {...};
// Allocate model outputs: fill in shape and size
std::array<float, ...> output{};
std::array<int64_t, ...> output_shape{...};
Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output.data(), output.size(), output_shape.data(), output_shape.size());
const char* output_names[] = {...};
// Run the model
session_.Run(Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, &output_tensor, 1);
开箱即用,ONNXRuntime 会对 ONNX 图应用一系列优化,尽可能地组合节点并提取常量值(常量折叠)。ONNXRuntime 还通过其执行提供程序接口与多种硬件加速器集成,包括 CUDA、TensorRT、OpenVINO、CoreML 和 NNAPI,具体取决于您所针对的硬件平台。
您可以通过量化 ONNX 模型来进一步提高其性能。
如果应用程序在移动和边缘等受限环境中运行,您可以根据应用程序运行的模型或模型集构建一个精简版运行时。
要在您选择的语言和环境中开始使用,请参阅ONNX Runtime 快速入门