概述#

onnxruntime-trainingORTModule 为使用 PyTorch 前端定义的模型提供了高性能训练引擎。ORTModule 旨在加速大型模型的训练,而无需更改模型定义或训练代码。

ORTModule 的目标是为用户 PyTorch 程序中的一个或多个 torch.nn.Module 对象提供即插即用的替代方案,并使用 ORT 执行这些模块的前向和后向传递。

因此,用户将能够使用 ORT 加速他们的训练脚本,而无需修改他们的训练循环。

用户将能够使用标准的 PyTorch 调试技术来解决收敛问题,例如,通过探测模型参数上计算出的梯度。

以下代码示例说明了如何在用户的训练脚本中使用 ORTModule,在整个模型可以卸载到 ONNX Runtime 的简单情况下。

from onnxruntime.training import ORTModule

# Original PyTorch model
class NeuralNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        ...
    def forward(self, x):
        ...

model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
model = ORTModule(model) # The only change to the original PyTorch script
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

# Training Loop is unchanged
for data, target in data_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()