使用 ORTModule 开始大型模型训练
ONNX Runtime Training
的 ORTModule
为使用 PyTorch
前端定义的模型提供高性能训练引擎。ORTModule
旨在加速大型模型的训练,无需更改模型定义,只需对整个训练脚本进行一行代码更改(即 ORTModule
封装)。
使用 ORTModule 类封装器,ONNX Runtime 通过优化过的自动导出的 ONNX 计算图运行训练脚本的前向和后向传播。
ORT 训练示例
在此示例中,我们将介绍如何使用 ORT 训练 PyTorch 模型。
# Installs the torch_ort and onnxruntime-training Python packages
pip install torch-ort
# Configures onnxruntime-training to work with user's PyTorch installation
python -m torch_ort.configure
注意: 这会安装默认版本的 torch-ort
和 onnxruntime-training
包,这些包映射到特定版本的 CUDA 库。请参阅 onnxruntime.ai 中的安装选项。
- 在
train.py
中添加 ORTModule
+ from torch_ort import ORTModule
.
.
.
- model = build_model() # Users PyTorch model
+ model = ORTModule(build_model())