在设备上训练模型#
一旦生成训练工件,就可以使用 onnxruntime training python API 在设备上训练模型。
预期的训练工件是
训练 onnx 模型
检查点状态
优化器 onnx 模型
评估 onnx 模型 (可选)
示例用法
from onnxruntime.training.api import CheckpointState, Module, Optimizer
# Load the checkpoint state
state = CheckpointState.load_checkpoint(path_to_the_checkpoint_artifact)
# Create the module
module = Module(path_to_the_training_model,
state,
path_to_the_eval_model,
device="cpu")
optimizer = Optimizer(path_to_the_optimizer_model, module)
# Training loop
for ...:
module.train()
training_loss = module(...)
optimizer.step()
module.lazy_reset_grad()
# Eval
module.eval()
eval_loss = module(...)
# Save the checkpoint
CheckpointState.save_checkpoint(state, path_to_the_checkpoint_artifact)
- class onnxruntime.training.api.checkpoint_state.Parameter(parameter: Parameter, state: CheckpointState)[source]#
基类:
object
表示模型参数的类
此类表示模型参数,并提供对其数据、梯度和其他属性的访问。预计不会直接实例化此类。相反,它由 CheckpointState 对象返回。
- 参数:
parameter – 保存底层参数数据的 C.Parameter 对象。
state – 保存底层会话状态的 C.CheckpointState 对象。
- class onnxruntime.training.api.checkpoint_state.Parameters(state: CheckpointState)[source]#
基类:
object
保存所有模型参数的类
此类保存所有模型参数,并提供对它们的访问。预计不会直接实例化此类。相反,它由 CheckpointState 的 parameters 属性返回。此类行为类似于字典,并提供按名称访问参数的方式。
- 参数:
state – 保存底层会话状态的 C.CheckpointState 对象。
- __getitem__(name: str) Parameter [source]#
获取与给定名称关联的参数
在检查点状态的参数中搜索名称。
- 参数:
name – 参数的名称
- 返回:
参数的值
- 引发:
KeyError – 如果未找到参数
- class onnxruntime.training.api.checkpoint_state.Properties(state: CheckpointState)[source]#
基类:
object
- __getitem__(name: str) int | float | str [source]#
获取与给定名称关联的属性
在检查点状态的属性中搜索名称。
- 参数:
name – 属性的名称
- 返回:
属性的值
- 引发:
KeyError – 如果未找到属性
- class onnxruntime.training.api.CheckpointState(state: CheckpointState)[source]#
基类:
object
保存训练会话状态的类
此类保存训练会话的所有状态信息,例如模型参数、其梯度、优化器状态和用户定义的属性。
要创建 CheckpointState,请使用 CheckpointState.load_checkpoint 方法。
- 参数:
state – 保存底层会话状态的 C.Checkpoint 状态对象。
- classmethod load_checkpoint(checkpoint_uri: str | os.PathLike) CheckpointState [source]#
从检查点文件加载检查点状态
检查点文件可以是完整检查点或名义检查点。
- 参数:
checkpoint_uri – 检查点文件的路径。
- 返回:
检查点状态对象。
- 返回类型:
- classmethod save_checkpoint(state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False) None [source]#
将检查点状态保存到检查点文件
- 参数:
state – 检查点状态对象。
checkpoint_uri – 检查点文件的路径。
include_optimizer_state – 如果为 True,优化器状态也会保存到检查点文件。
- property parameters: Parameters#
从检查点状态返回模型参数
- property properties: Properties#
从检查点状态返回属性
- class onnxruntime.training.api.Module(train_model_uri: PathLike, state: CheckpointState, eval_model_uri: Optional[PathLike] = None, device: str = 'cpu', session_options: Optional[SessionOptions] = None)[source]#
基类:
object
Trainer 类,为 ONNX 模型提供训练和评估方法。
在实例化 Module 类之前,预计已使用 onnxruntime.training.artifacts.generate_artifacts 实用程序生成了训练工件。
- 训练工件包括
训练模型
评估模型(可选)
优化器模型(可选)
检查点文件
- 参数:
train_model_uri – 训练模型的路径。
state – 检查点状态对象。
eval_model_uri – 评估模型的路径。
device – 运行模型的设备。默认为 “cpu”。
session_options – 用于模型的会话选项。
- __call__(*user_inputs) tuple[numpy.ndarray, ...] | numpy.ndarray | tuple[onnxruntime.capi.onnxruntime_inference_collection.OrtValue, ...] | onnxruntime.capi.onnxruntime_inference_collection.OrtValue [source]#
调用模型的训练或评估步骤。
- 参数:
*user_inputs – 模型的输入。用户输入可以是 numpy 数组或 OrtValue。
- 返回:
模型的输出。
- train(mode: bool = True) Module [source]#
将 Module 设置为训练模式。
- 参数:
mode – 是否将模型设置为训练模式 (True) 或评估模式 (False)。默认值: True。
- 返回:
self
- get_contiguous_parameters(trainable_only: bool = False) OrtValue [source]#
创建训练会话参数的连续缓冲区
- 参数:
trainable_only – 如果为 True,则仅考虑可训练参数。否则,将考虑所有参数。
- 返回:
训练会话参数的连续缓冲区。
- get_parameters_size(trainable_only: bool = True) int [source]#
返回参数的大小。
- 参数:
trainable_only – 如果为 True,则仅考虑可训练参数。否则,将考虑所有参数。
- 返回:
参数中原始元素(例如浮点)的数量。
- copy_buffer_to_parameters(buffer: OrtValue, trainable_only: bool = True) None [source]#
将 OrtValue 缓冲区复制到训练会话参数。
如果模块是从名义检查点加载的,则需要调用此函数以将更新的参数加载到检查点以完成它。
- 参数:
buffer – 要复制到训练会话参数的 OrtValue 缓冲区。
- class onnxruntime.training.api.Optimizer(optimizer_uri: str | os.PathLike, module: Module)[source]#
基类:
object
提供基于计算出的梯度更新模型参数的方法的类。
- 参数:
optimizer_uri – 优化器模型的路径。
model – 要训练的模块。
- class onnxruntime.training.api.LinearLRScheduler(optimizer: Optimizer, warmup_step_count: int, total_step_count: int, initial_lr: float)[source]#
基类:
object
线性地更新优化器中的学习率
线性学习率调度器通过线性更新的乘法因子,将学习率从训练会话中设置的初始学习率衰减到0。衰减在初始预热阶段之后执行,在预热阶段中,学习率从0线性增加到提供的初始学习率。
- 参数:
optimizer – 用户的 onnxruntime 训练优化器
warmup_step_count – 预热阶段的步数。
total_step_count – 训练步骤的总数。
initial_lr – 初始学习率。