概述#
onnxruntime-training 的 ORTModule 为使用 PyTorch 前端定义的模型提供了高性能训练引擎。ORTModule 旨在加速大型模型的训练,而无需更改模型定义或训练代码。
ORTModule 的目标是为用户 PyTorch 程序中的一个或多个 torch.nn.Module 对象提供即插即用的替代方案,并使用 ORT 执行这些模块的前向和后向传递。
因此,用户将能够使用 ORT 加速他们的训练脚本,而无需修改他们的训练循环。
用户将能够使用标准的 PyTorch 调试技术来解决收敛问题,例如,通过探测模型参数上计算出的梯度。
以下代码示例说明了如何在用户的训练脚本中使用 ORTModule,在整个模型可以卸载到 ONNX Runtime 的简单情况下。
from onnxruntime.training import ORTModule
# Original PyTorch model
class NeuralNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
...
def forward(self, x):
...
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
model = ORTModule(model) # The only change to the original PyTorch script
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
# Training Loop is unchanged
for data, target in data_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()