ORTModule 大型模型训练入门

ONNX Runtime TrainingORTModule 为使用 PyTorch 前端定义的模型提供高性能训练引擎。ORTModule 旨在加速大型模型的训练,无需更改模型定义,只需对整个训练脚本进行单行代码更改(ORTModule 包装器)。

使用 ORTModule 类包装器,ONNX Runtime 使用优化的自动导出的 ONNX 计算图运行训练脚本的前向和后向传递。

ORT 训练示例

在此示例中,我们将介绍如何使用 ORT 和 PyTorch 训练模型。

# Installs the torch_ort and onnxruntime-training Python packages
pip install torch-ort
# Configures onnxruntime-training to work with user's PyTorch installation
python -m torch_ort.configure

注意: 这会安装默认版本的 torch-ortonnxruntime-training 包,这些包映射到特定版本的 CUDA 库。请参阅 onnxruntime.ai 中的安装选项。

  • train.py 中添加 ORTModule
+  from torch_ort import ORTModule
   .
   .
   .
-  model = build_model() # Users PyTorch model
+  model = ORTModule(build_model())

示例

ONNX Runtime 训练示例