使用 ORTModule 开始大型模型训练
ONNX Runtime Training
的 ORTModule
为使用 PyTorch
前端定义的模型提供了高性能训练引擎。ORTModule
的设计旨在加速大型模型的训练,而无需更改模型定义,并且只需对整个训练脚本进行一行代码更改(即 ORTModule
封装)。
使用 ORTModule 类封装器,ONNX Runtime 利用经过优化的自动导出的 ONNX 计算图来运行训练脚本的正向和反向传播。
ORT 训练示例
在本示例中,我们将介绍如何使用 ORT 训练 PyTorch 模型。
# Installs the torch_ort and onnxruntime-training Python packages
pip install torch-ort
# Configures onnxruntime-training to work with user's PyTorch installation
python -m torch_ort.configure
注意:这将安装 torch-ort
和 onnxruntime-training
包的默认版本,这些版本映射到特定版本的 CUDA 库。请参考 onnxruntime.ai 中的安装选项。
- 在
train.py
中添加 ORTModule
+ from torch_ort import ORTModule
.
.
.
- model = build_model() # Users PyTorch model
+ model = ORTModule(build_model())