ORTModule 大型模型训练入门
ONNX Runtime Training
的 ORTModule
为使用 PyTorch
前端定义的模型提供高性能训练引擎。ORTModule
旨在加速大型模型的训练,无需更改模型定义,只需对整个训练脚本进行单行代码更改(ORTModule
包装器)。
使用 ORTModule 类包装器,ONNX Runtime 使用优化的自动导出的 ONNX 计算图运行训练脚本的前向和后向传递。
ORT 训练示例
在此示例中,我们将介绍如何使用 ORT 和 PyTorch 训练模型。
# Installs the torch_ort and onnxruntime-training Python packages
pip install torch-ort
# Configures onnxruntime-training to work with user's PyTorch installation
python -m torch_ort.configure
注意: 这会安装默认版本的 torch-ort
和 onnxruntime-training
包,这些包映射到特定版本的 CUDA 库。请参阅 onnxruntime.ai 中的安装选项。
- 在
train.py
中添加 ORTModule
+ from torch_ort import ORTModule
.
.
.
- model = build_model() # Users PyTorch model
+ model = ORTModule(build_model())