在设备上训练模型#

生成训练工件后,可以使用 onnxruntime 训练 Python API 在设备上训练模型。

预期的训练工件包括

  1. 训练 ONNX 模型

  2. 检查点状态

  3. 优化器 ONNX 模型

  4. 评估 ONNX 模型(可选)

示例用法

from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Load the checkpoint state
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)

# Create the module
module = Module(path_to_the_training_model,
                state,
                path_to_the_eval_model,
                device="cpu")

optimizer = Optimizer(path_to_the_optimizer_model, module)

# Training loop
for ...:
    module.train()
    training_loss = module(...)
    optimizer.step()
    module.lazy_reset_grad()

# Eval
module.eval()
eval_loss = module(...)

# Save the checkpoint
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)
class onnxruntime.training.api.checkpoint_state.Parameter(parameter: Parameter, state: CheckpointState)[source]#

基类:object

表示模型参数的类

此类表示模型参数,并提供对其数据、梯度及其他属性的访问。此类不应直接实例化。相反,它由 CheckpointState 对象返回。

参数:
  • parameter – 持有底层参数数据的 C.Parameter 对象。

  • state – 持有底层会话状态的 C.CheckpointState 对象。

property name: str#

参数的名称

property data: ndarray#

参数的数据

property grad: ndarray#

参数的梯度

property requires_grad: bool#

参数是否需要计算其梯度

__repr__() str[source]#

返回参数的字符串表示

class onnxruntime.training.api.checkpoint_state.Parameters(state: CheckpointState)[source]#

基类:object

包含所有模型参数的类

此类包含所有模型参数并提供对其的访问。此类不应直接实例化。相反,它由 CheckpointState 的 parameters 属性返回。此类的行为类似于字典,并按名称提供对参数的访问。

参数:

state – 持有底层会话状态的 C.CheckpointState 对象。

__getitem__(name: str) Parameter[source]#

获取与给定名称关联的参数

在检查点状态的参数中搜索该名称。

参数:

name – 参数的名称

返回:

参数的值

抛出:

KeyError – 如果未找到参数

__setitem__(name: str, value: ndarray) None[source]#

设置给定名称的参数值

在检查点状态的参数中搜索该名称。如果找到该名称,则更新其值。

参数:
  • name – 参数的名称

  • value – 作为 numpy 数组的参数值

抛出:

KeyError – 如果未找到参数

__contains__(name: str) bool[source]#

检查参数是否存在于状态中

参数:

name – 参数的名称

返回:

如果名称是参数,则为 True,否则为 False

__iter__()[source]#

返回属性的迭代器

__repr__() str[source]#

返回参数的字符串表示

__len__() int[source]#

返回参数的数量

class onnxruntime.training.api.checkpoint_state.Properties(state: CheckpointState)[source]#

基类:object

__getitem__(name: str) int | float | str[source]#

获取与给定名称关联的属性

在检查点状态的属性中搜索该名称。

参数:

name – 属性的名称

返回:

属性的值

抛出:

KeyError – 如果未找到属性

__setitem__(name: str, value: int | float | str) None[source]#

设置给定名称的属性值

在检查点状态的属性中搜索该名称。该值将添加到属性中或在属性中更新。

参数:
  • name – 属性的名称

  • value – 属性值。属性仅支持 int、float 和 str 值。

__contains__(name: str) bool[source]#

检查属性是否存在于状态中

参数:

name – 属性的名称

返回:

如果名称是属性,则为 True,否则为 False

__iter__()[source]#

返回属性的迭代器

__repr__() str[source]#

返回属性的字符串表示

__len__() int[source]#

返回属性的数量

class onnxruntime.training.api.CheckpointState(state: CheckpointState)[source]#

基类:object

包含训练会话状态的类

此类包含训练会话的所有状态信息,例如模型参数、其梯度、优化器状态和用户定义的属性。

要创建 CheckpointState,请使用 CheckpointState.load_checkpoint 方法。

参数:

state – 包含底层会话状态的 C.Checkpoint state 对象。

classmethod load_checkpoint(checkpoint_uri: str | os.PathLike) CheckpointState[source]#

从检查点文件加载检查点状态

检查点文件可以是完整的检查点或名义检查点。

参数:

checkpoint_uri – 检查点文件的路径。

返回:

检查点状态对象。

返回类型:

CheckpointState

classmethod save_checkpoint(state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False) None[source]#

将检查点状态保存到检查点文件

参数:
  • state – 检查点状态对象。

  • checkpoint_uri – 检查点文件的路径。

  • include_optimizer_state – 如果为 True,优化器状态也将保存到检查点文件。

property parameters: Parameters#

从检查点状态返回模型参数

property properties: Properties#

从检查点状态返回属性

class onnxruntime.training.api.Module(train_model_uri: PathLike, state: CheckpointState, eval_model_uri: Optional[PathLike] = None, device: str = 'cpu', session_options: Optional[SessionOptions] = None)[source]#

基类:object

提供 ONNX 模型训练和评估方法的训练器类。

在实例化 Module 类之前,应已使用 onnxruntime.training.artifacts.generate_artifacts 工具生成了训练工件。

训练工件包括
  • 训练模型

  • 评估模型(可选)

  • 优化器模型(可选)

  • 检查点文件

training#

如果模型处于训练模式则为 True,如果处于评估模式则为 False。

类型:

bool

参数:
  • train_model_uri – 训练模型的路径。

  • state – 检查点状态对象。

  • eval_model_uri – 评估模型的路径。

  • device – 运行模型的设备。默认为“cpu”。

  • session_options – 模型使用的会话选项。

__call__(*user_inputs) tuple[numpy.ndarray, ...] | numpy.ndarray | tuple[onnxruntime.capi.onnxruntime_inference_collection.OrtValue, ...] | onnxruntime.capi.onnxruntime_inference_collection.OrtValue[source]#

调用模型的训练或评估步骤。

参数:

*user_inputs – 模型的输入。用户输入可以是 numpy 数组或 OrtValue。

返回:

模型的输出。

train(mode: bool = True) Module[source]#

将模块设置为训练模式。

参数:

mode – 是否将模型设置为训练模式 (True) 或评估模式 (False)。默认值:True。

返回:

self

eval() Module[source]#

将模块设置为评估模式。

返回:

self

lazy_reset_grad()[source]#

惰性重置训练梯度。

此函数设置模块的内部状态,以便模块梯度将在下次调用 train() 计算新梯度之前被调度重置。

get_contiguous_parameters(trainable_only: bool = False) OrtValue[source]#

创建训练会话参数的连续缓冲区

参数:

trainable_only – 如果为 True,则只考虑可训练参数。否则,考虑所有参数。

返回:

训练会话参数的连续缓冲区。

get_parameters_size(trainable_only: bool = True) int[source]#

返回参数的大小。

参数:

trainable_only – 如果为 True,则只考虑可训练参数。否则,考虑所有参数。

返回:

参数中原始(例如浮点)元素的数量。

copy_buffer_to_parameters(buffer: OrtValue, trainable_only: bool = True) None[source]#

将 OrtValue 缓冲区复制到训练会话参数。

如果模块是从名义检查点加载的,则需要调用此函数将更新的参数加载到检查点以完成它。

参数:

buffer – 要复制到训练会话参数的 OrtValue 缓冲区。

export_model_for_inferencing(inference_model_uri: str | os.PathLike, graph_output_names: list[str]) None[source]#

导出模型用于推理。

训练完成后,此函数可用于删除 ONNX 模型中训练特定的节点。具体来说,此函数执行以下操作

  • 解析训练图并识别生成给定输出名称的节点。

  • 删除图中所有后续节点,因为它们与推理图无关。

参数:
  • inference_model_uri – 推理模型的路径。

  • graph_output_names – 推理所需的输出名称列表。

input_names() list[str][source]#

返回训练模型或评估模型的输入名称。

output_names() list[str][source]#

返回训练模型或评估模型的输出名称。

class onnxruntime.training.api.Optimizer(optimizer_uri: str | os.PathLike, module: Module)[source]#

基类:object

提供根据计算出的梯度更新模型参数方法的类。

参数:
  • optimizer_uri – 优化器模型的路径。

  • model – 要训练的模块。

step() None[source]#

根据计算出的梯度更新模型参数。

此方法通过在计算梯度的方向上迈出一步来更新模型参数。所使用的优化器取决于所提供的优化器模型。

set_learning_rate(learning_rate: float) None[source]#

设置优化器的学习率。

参数:

learning_rate – 要设置的学习率。

get_learning_rate() float[source]#

获取优化器当前的学习率。

返回:

当前的学习率。

class onnxruntime.training.api.LinearLRScheduler(optimizer: Optimizer, warmup_step_count: int, total_step_count: int, initial_lr: float)[source]#

基类:object

线性更新优化器中的学习率

线性学习率调度器通过线性更新的乘法因子将训练会话中设置的初始学习率衰减到 0。衰减在初始预热阶段之后执行,在该阶段中学习率从 0 线性增加到提供的初始学习率。

参数:
  • optimizer – 用户的 onnxruntime 训练优化器

  • warmup_step_count – 预热阶段的步数。

  • total_step_count – 总训练步数。

  • initial_lr – 初始学习率。

step() None[source]#

线性更新优化器的学习率。

在训练的每一步都应调用此方法,以确保正确调整学习率。