在设备上训练模型#
生成训练工件后,可以使用 onnxruntime 训练 Python API 在设备上训练模型。
预期的训练工件包括
训练 ONNX 模型
检查点状态
优化器 ONNX 模型
评估 ONNX 模型(可选)
示例用法
from onnxruntime.training.api import CheckpointState, Module, Optimizer
# Load the checkpoint state
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)
# Create the module
module = Module(path_to_the_training_model,
state,
path_to_the_eval_model,
device="cpu")
optimizer = Optimizer(path_to_the_optimizer_model, module)
# Training loop
for ...:
module.train()
training_loss = module(...)
optimizer.step()
module.lazy_reset_grad()
# Eval
module.eval()
eval_loss = module(...)
# Save the checkpoint
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)
- class onnxruntime.training.api.checkpoint_state.Parameter(parameter: Parameter, state: CheckpointState)[source]#
基类:
object
表示模型参数的类
此类表示模型参数,并提供对其数据、梯度及其他属性的访问。此类不应直接实例化。相反,它由 CheckpointState 对象返回。
- 参数:
parameter – 持有底层参数数据的 C.Parameter 对象。
state – 持有底层会话状态的 C.CheckpointState 对象。
- class onnxruntime.training.api.checkpoint_state.Parameters(state: CheckpointState)[source]#
基类:
object
包含所有模型参数的类
此类包含所有模型参数并提供对其的访问。此类不应直接实例化。相反,它由 CheckpointState 的 parameters 属性返回。此类的行为类似于字典,并按名称提供对参数的访问。
- 参数:
state – 持有底层会话状态的 C.CheckpointState 对象。
- __getitem__(name: str) Parameter [source]#
获取与给定名称关联的参数
在检查点状态的参数中搜索该名称。
- 参数:
name – 参数的名称
- 返回:
参数的值
- 抛出:
KeyError – 如果未找到参数
- class onnxruntime.training.api.checkpoint_state.Properties(state: CheckpointState)[source]#
基类:
object
- __getitem__(name: str) int | float | str [source]#
获取与给定名称关联的属性
在检查点状态的属性中搜索该名称。
- 参数:
name – 属性的名称
- 返回:
属性的值
- 抛出:
KeyError – 如果未找到属性
- class onnxruntime.training.api.CheckpointState(state: CheckpointState)[source]#
基类:
object
包含训练会话状态的类
此类包含训练会话的所有状态信息,例如模型参数、其梯度、优化器状态和用户定义的属性。
要创建 CheckpointState,请使用 CheckpointState.load_checkpoint 方法。
- 参数:
state – 包含底层会话状态的 C.Checkpoint state 对象。
- classmethod load_checkpoint(checkpoint_uri: str | os.PathLike) CheckpointState [source]#
从检查点文件加载检查点状态
检查点文件可以是完整的检查点或名义检查点。
- 参数:
checkpoint_uri – 检查点文件的路径。
- 返回:
检查点状态对象。
- 返回类型:
- classmethod save_checkpoint(state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False) None [source]#
将检查点状态保存到检查点文件
- 参数:
state – 检查点状态对象。
checkpoint_uri – 检查点文件的路径。
include_optimizer_state – 如果为 True,优化器状态也将保存到检查点文件。
- property parameters: Parameters#
从检查点状态返回模型参数
- property properties: Properties#
从检查点状态返回属性
- class onnxruntime.training.api.Module(train_model_uri: PathLike, state: CheckpointState, eval_model_uri: Optional[PathLike] = None, device: str = 'cpu', session_options: Optional[SessionOptions] = None)[source]#
基类:
object
提供 ONNX 模型训练和评估方法的训练器类。
在实例化 Module 类之前,应已使用 onnxruntime.training.artifacts.generate_artifacts 工具生成了训练工件。
- 训练工件包括
训练模型
评估模型(可选)
优化器模型(可选)
检查点文件
- 参数:
train_model_uri – 训练模型的路径。
state – 检查点状态对象。
eval_model_uri – 评估模型的路径。
device – 运行模型的设备。默认为“cpu”。
session_options – 模型使用的会话选项。
- __call__(*user_inputs) tuple[numpy.ndarray, ...] | numpy.ndarray | tuple[onnxruntime.capi.onnxruntime_inference_collection.OrtValue, ...] | onnxruntime.capi.onnxruntime_inference_collection.OrtValue [source]#
调用模型的训练或评估步骤。
- 参数:
*user_inputs – 模型的输入。用户输入可以是 numpy 数组或 OrtValue。
- 返回:
模型的输出。
- train(mode: bool = True) Module [source]#
将模块设置为训练模式。
- 参数:
mode – 是否将模型设置为训练模式 (True) 或评估模式 (False)。默认值:True。
- 返回:
self
- get_contiguous_parameters(trainable_only: bool = False) OrtValue [source]#
创建训练会话参数的连续缓冲区
- 参数:
trainable_only – 如果为 True,则只考虑可训练参数。否则,考虑所有参数。
- 返回:
训练会话参数的连续缓冲区。
- get_parameters_size(trainable_only: bool = True) int [source]#
返回参数的大小。
- 参数:
trainable_only – 如果为 True,则只考虑可训练参数。否则,考虑所有参数。
- 返回:
参数中原始(例如浮点)元素的数量。
- copy_buffer_to_parameters(buffer: OrtValue, trainable_only: bool = True) None [source]#
将 OrtValue 缓冲区复制到训练会话参数。
如果模块是从名义检查点加载的,则需要调用此函数将更新的参数加载到检查点以完成它。
- 参数:
buffer – 要复制到训练会话参数的 OrtValue 缓冲区。
- class onnxruntime.training.api.Optimizer(optimizer_uri: str | os.PathLike, module: Module)[source]#
基类:
object
提供根据计算出的梯度更新模型参数方法的类。
- 参数:
optimizer_uri – 优化器模型的路径。
model – 要训练的模块。
- class onnxruntime.training.api.LinearLRScheduler(optimizer: Optimizer, warmup_step_count: int, total_step_count: int, initial_lr: float)[source]#
基类:
object
线性更新优化器中的学习率
线性学习率调度器通过线性更新的乘法因子将训练会话中设置的初始学习率衰减到 0。衰减在初始预热阶段之后执行,在该阶段中学习率从 0 线性增加到提供的初始学习率。
- 参数:
optimizer – 用户的 onnxruntime 训练优化器
warmup_step_count – 预热阶段的步数。
total_step_count – 总训练步数。
initial_lr – 初始学习率。