在设备上训练模型#

一旦生成训练工件,就可以使用 onnxruntime training python API 在设备上训练模型。

预期的训练工件是

  1. 训练 onnx 模型

  2. 检查点状态

  3. 优化器 onnx 模型

  4. 评估 onnx 模型 (可选)

示例用法

from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Load the checkpoint state
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)

# Create the module
module = Module(path_to_the_training_model,
                state,
                path_to_the_eval_model,
                device="cpu")

optimizer = Optimizer(path_to_the_optimizer_model, module)

# Training loop
for ...:
    module.train()
    training_loss = module(...)
    optimizer.step()
    module.lazy_reset_grad()

# Eval
module.eval()
eval_loss = module(...)

# Save the checkpoint
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)
class onnxruntime.training.api.checkpoint_state.Parameter(parameter: Parameter, state: CheckpointState)[source]#

基类: object

表示模型参数的类

此类表示模型参数,并提供对其数据、梯度和其他属性的访问。预计不会直接实例化此类。相反,它由 CheckpointState 对象返回。

参数:
  • parameter – 保存底层参数数据的 C.Parameter 对象。

  • state – 保存底层会话状态的 C.CheckpointState 对象。

property name: str#

参数的名称

property data: ndarray#

参数的数据

property grad: ndarray#

参数的梯度

property requires_grad: bool#

参数是否需要计算梯度

__repr__() str[source]#

返回参数的字符串表示形式

class onnxruntime.training.api.checkpoint_state.Parameters(state: CheckpointState)[source]#

基类: object

保存所有模型参数的类

此类保存所有模型参数,并提供对它们的访问。预计不会直接实例化此类。相反,它由 CheckpointState 的 parameters 属性返回。此类行为类似于字典,并提供按名称访问参数的方式。

参数:

state – 保存底层会话状态的 C.CheckpointState 对象。

__getitem__(name: str) Parameter[source]#

获取与给定名称关联的参数

在检查点状态的参数中搜索名称。

参数:

name – 参数的名称

返回:

参数的值

引发:

KeyError – 如果未找到参数

__setitem__(name: str, value: ndarray) None[source]#

设置给定名称的参数值

在检查点状态的参数中搜索名称。如果参数中找到该名称,则更新值。

参数:
  • name – 参数的名称

  • value – 参数的值,作为 numpy 数组

引发:

KeyError – 如果未找到参数

__contains__(name: str) bool[source]#

检查参数是否存在于状态中

参数:

name – 参数的名称

返回:

如果名称是参数,则为 True,否则为 False

__iter__()[source]#

返回属性的迭代器

__repr__() str[source]#

返回参数的字符串表示形式

__len__() int[source]#

返回参数的数量

class onnxruntime.training.api.checkpoint_state.Properties(state: CheckpointState)[source]#

基类: object

__getitem__(name: str) int | float | str[source]#

获取与给定名称关联的属性

在检查点状态的属性中搜索名称。

参数:

name – 属性的名称

返回:

属性的值

引发:

KeyError – 如果未找到属性

__setitem__(name: str, value: int | float | str) None[source]#

设置给定名称的属性值

在检查点状态的属性中搜索名称。该值将被添加或更新到属性中。

参数:
  • name – 属性的名称

  • value – 属性的值。Properties 仅支持 int、float 和 str 值。

__contains__(name: str) bool[source]#

检查属性是否存在于状态中

参数:

name – 属性的名称

返回:

如果名称是属性,则为 True,否则为 False

__iter__()[source]#

返回属性的迭代器

__repr__() str[source]#

返回属性的字符串表示形式

__len__() int[source]#

返回属性的数量

class onnxruntime.training.api.CheckpointState(state: CheckpointState)[source]#

基类: object

保存训练会话状态的类

此类保存训练会话的所有状态信息,例如模型参数、其梯度、优化器状态和用户定义的属性。

要创建 CheckpointState,请使用 CheckpointState.load_checkpoint 方法。

参数:

state – 保存底层会话状态的 C.Checkpoint 状态对象。

classmethod load_checkpoint(checkpoint_uri: str | os.PathLike) CheckpointState[source]#

从检查点文件加载检查点状态

检查点文件可以是完整检查点或名义检查点。

参数:

checkpoint_uri – 检查点文件的路径。

返回:

检查点状态对象。

返回类型:

CheckpointState

classmethod save_checkpoint(state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False) None[source]#

将检查点状态保存到检查点文件

参数:
  • state – 检查点状态对象。

  • checkpoint_uri – 检查点文件的路径。

  • include_optimizer_state – 如果为 True,优化器状态也会保存到检查点文件。

property parameters: Parameters#

从检查点状态返回模型参数

property properties: Properties#

从检查点状态返回属性

class onnxruntime.training.api.Module(train_model_uri: PathLike, state: CheckpointState, eval_model_uri: Optional[PathLike] = None, device: str = 'cpu', session_options: Optional[SessionOptions] = None)[source]#

基类: object

Trainer 类,为 ONNX 模型提供训练和评估方法。

在实例化 Module 类之前,预计已使用 onnxruntime.training.artifacts.generate_artifacts 实用程序生成了训练工件。

训练工件包括
  • 训练模型

  • 评估模型(可选)

  • 优化器模型(可选)

  • 检查点文件

training#

如果模型处于训练模式,则为 True;如果模型处于评估模式,则为 False。

类型:

bool

参数:
  • train_model_uri – 训练模型的路径。

  • state – 检查点状态对象。

  • eval_model_uri – 评估模型的路径。

  • device – 运行模型的设备。默认为 “cpu”。

  • session_options – 用于模型的会话选项。

__call__(*user_inputs) tuple[numpy.ndarray, ...] | numpy.ndarray | tuple[onnxruntime.capi.onnxruntime_inference_collection.OrtValue, ...] | onnxruntime.capi.onnxruntime_inference_collection.OrtValue[source]#

调用模型的训练或评估步骤。

参数:

*user_inputs – 模型的输入。用户输入可以是 numpy 数组或 OrtValue。

返回:

模型的输出。

train(mode: bool = True) Module[source]#

将 Module 设置为训练模式。

参数:

mode – 是否将模型设置为训练模式 (True) 或评估模式 (False)。默认值: True。

返回:

self

eval() Module[source]#

将 Module 设置为评估模式。

返回:

self

lazy_reset_grad()[source]#

延迟重置训练梯度。

此函数设置模块的内部状态,以便在下次调用 train() 时计算新梯度之前,计划重置模块梯度。

get_contiguous_parameters(trainable_only: bool = False) OrtValue[source]#

创建训练会话参数的连续缓冲区

参数:

trainable_only – 如果为 True,则仅考虑可训练参数。否则,将考虑所有参数。

返回:

训练会话参数的连续缓冲区。

get_parameters_size(trainable_only: bool = True) int[source]#

返回参数的大小。

参数:

trainable_only – 如果为 True,则仅考虑可训练参数。否则,将考虑所有参数。

返回:

参数中原始元素(例如浮点)的数量。

copy_buffer_to_parameters(buffer: OrtValue, trainable_only: bool = True) None[source]#

将 OrtValue 缓冲区复制到训练会话参数。

如果模块是从名义检查点加载的,则需要调用此函数以将更新的参数加载到检查点以完成它。

参数:

buffer – 要复制到训练会话参数的 OrtValue 缓冲区。

export_model_for_inferencing(inference_model_uri: str | os.PathLike, graph_output_names: list[str]) None[source]#

导出模型以进行推理。

训练完成后,可以使用此功能删除 onnx 模型中特定于训练的节点。特别是,此函数执行以下操作

  • 解析训练图并识别生成给定输出名称的节点。

  • 删除图中所有后续节点,因为它们与推理图无关。

参数:
  • inference_model_uri – 推理模型的路径。

  • graph_output_names – 推理所需的输出名称列表。

input_names() list[str][source]#

返回训练或评估模型的输入名称。

output_names() list[str][source]#

返回训练或评估模型的输出名称。

class onnxruntime.training.api.Optimizer(optimizer_uri: str | os.PathLike, module: Module)[source]#

基类: object

提供基于计算出的梯度更新模型参数的方法的类。

参数:
  • optimizer_uri – 优化器模型的路径。

  • model – 要训练的模块。

step() None[source]#

基于计算出的梯度更新模型参数。

此方法通过沿计算出的梯度方向迈进一步来更新模型参数。使用的优化器取决于提供的优化器模型。

set_learning_rate(learning_rate: float) None[source]#

设置优化器的学习率。

参数:

learning_rate – 要设置的学习率。

get_learning_rate() float[source]#

获取优化器当前的学习率。

返回:

当前的学习率。

class onnxruntime.training.api.LinearLRScheduler(optimizer: Optimizer, warmup_step_count: int, total_step_count: int, initial_lr: float)[source]#

基类: object

线性地更新优化器中的学习率

线性学习率调度器通过线性更新的乘法因子,将学习率从训练会话中设置的初始学习率衰减到0。衰减在初始预热阶段之后执行,在预热阶段中,学习率从0线性增加到提供的初始学习率。

参数:
  • optimizer – 用户的 onnxruntime 训练优化器

  • warmup_step_count – 预热阶段的步数。

  • total_step_count – 训练步骤的总数。

  • initial_lr – 初始学习率。

step() None[source]#

线性地更新优化器的学习率。

应该在训练的每个步骤中调用此方法,以确保正确调整学习率。