使用 ORTModule 开始大型模型训练

ONNX Runtime TrainingORTModule 为使用 PyTorch 前端定义的模型提供了高性能训练引擎。ORTModule 的设计旨在加速大型模型的训练,而无需更改模型定义,并且只需对整个训练脚本进行一行代码更改(即 ORTModule 封装)。

使用 ORTModule 类封装器,ONNX Runtime 利用经过优化的自动导出的 ONNX 计算图来运行训练脚本的正向和反向传播。

ORT 训练示例

在本示例中,我们将介绍如何使用 ORT 训练 PyTorch 模型。

# Installs the torch_ort and onnxruntime-training Python packages
pip install torch-ort
# Configures onnxruntime-training to work with user's PyTorch installation
python -m torch_ort.configure

注意:这将安装 torch-ortonnxruntime-training 包的默认版本,这些版本映射到特定版本的 CUDA 库。请参考 onnxruntime.ai 中的安装选项。

  • train.py 中添加 ORTModule
+  from torch_ort import ORTModule
   .
   .
   .
-  model = build_model() # Users PyTorch model
+  model = ORTModule(build_model())

示例

ONNX Runtime 训练示例