ORTTrainingSession
Objective-C
@interface ORTTrainingSession : NSObject
Swift
class ORTTrainingSession : NSObject
提供用于训练、评估和优化 ONNX 模型的方法的训练器类。
训练会话需要四个训练工件
- 训练 onnx 模型
- 评估 onnx 模型 (可选)
- 优化器 onnx 模型
- 检查点目录
onnxruntime-training python 工具可用于生成上述训练工件。
自 1.16 版本起可用。
注意
此类仅在启用训练 API 时可用。-
不可用
声明
Objective-C
- (instancetype)init NS_UNAVAILABLE;
-
根据可用于开始或恢复训练的训练工件创建训练会话。
此初始化程序根据提供的 env 和 session options 实例化训练会话,可用于从给定的检查点状态开始或恢复训练。检查点状态代表训练会话的参数,这些参数在需要时将被移动到 session option 中指定的设备。
注意
请注意,使用检查点状态创建的训练会话使用此状态存储整个训练状态(包括模型参数、其梯度、优化器状态和属性)。训练会话对检查点状态保持强(拥有)指针。
声明
Objective-C
- (nullable instancetype)initWithEnv:(nonnull ORTEnv *)env sessionOptions: (nullable ORTSessionOptions *)sessionOptions checkpoint:(nonnull ORTCheckpoint *)checkpoint trainModelPath:(nonnull NSString *)trainModelPath evalModelPath:(nullable NSString *)evalModelPath optimizerModelPath:(nullable NSString *)optimizerModelPath error:(NSError *_Nullable *_Nullable)error;
Swift
init(env: ORTEnv, sessionOptions: ORTSessionOptions?, checkpoint: ORTCheckpoint, trainModelPath: String, evalModelPath: String?, optimizerModelPath: String?) throws
参数
env
用于训练会话的
ORTEnv
实例。sessionOptions
用于训练会话的可选
ORTSessionOptions
。checkpoint
用作训练起点的训练状态。
trainModelPath
训练 onnx 模型的路径。
evalModelPath
评估 onnx 模型的路径。
optimizerModelPath
用于执行梯度下降的优化器 onnx 模型的路径。
error
发生错误时设置的可选错误信息。
返回值
实例,如果发生错误则为 nil。
-
执行一个训练步骤,这相当于在一个步骤中完成前向和后向传播。
训练步骤计算训练模型的输出以及给定输入值的可训练参数的梯度。训练步骤是基于提供给训练会话的训练模型执行的。这相当于在一个步骤中运行前向和后向传播。计算出的梯度存储在训练会话状态中,以便后续可被
optimizerStep
消耗。可以通过调用lazyResetGrad
方法来延迟重置梯度。声明
参数
inputs
训练模型的输入值。
error
发生错误时设置的可选错误信息。
返回值
训练模型的输出值。
-
延迟地将所有可训练参数的梯度重置为零。
调用此方法会设置训练会话的内部状态,使得 ORTCheckpoint 中可训练参数的梯度将在下一次
trainStep
方法调用中计算新梯度之前被安排重置。声明
Objective-C
- (BOOL)lazyResetGradWithError:(NSError *_Nullable *_Nullable)error;
Swift
func lazyResetGrad() throws
参数
error
发生错误时设置的可选错误信息。
返回值
如果梯度成功设置为重置,则为 YES,否则为 NO。
-
使用优化器模型对可训练参数执行权重更新。优化器步骤是基于提供给训练会话的优化器模型执行的。更新后的参数存储在训练状态中,以便下次
trainStep
方法调用时使用。声明
Objective-C
- (BOOL)optimizerStepWithError:(NSError *_Nullable *_Nullable)error;
Swift
func optimizerStep() throws
参数
error
发生错误时设置的可选错误信息。
返回值
如果优化器步骤成功执行,则为 YES,否则为 NO。
-
返回可与提供给
trainStep
的ORTValue
相关联的训练模型的用户输入名称。声明
Objective-C
- (nullable NSArray<NSString *> *)getTrainInputNamesWithError: (NSError *_Nullable *_Nullable)error;
Swift
func getTrainInputNames() throws -> [String]
参数
error
发生错误时设置的可选错误信息。
返回值
训练模型的用户输入名称。
-
返回可与提供给
evalStep
的ORTValue
相关联的评估模型的用户输入名称。声明
Objective-C
- (nullable NSArray<NSString *> *)getEvalInputNamesWithError: (NSError *_Nullable *_Nullable)error;
Swift
func getEvalInputNames() throws -> [String]
参数
error
发生错误时设置的可选错误信息。
返回值
评估模型的用户输入名称。
-
返回可与
trainStep
返回的ORTValue
相关联的训练模型的用户输出名称。声明
Objective-C
- (nullable NSArray<NSString *> *)getTrainOutputNamesWithError: (NSError *_Nullable *_Nullable)error;
Swift
func getTrainOutputNames() throws -> [String]
参数
error
发生错误时设置的可选错误信息。
返回值
训练模型的用户输出名称。
-
返回可与
evalStep
返回的ORTValue
相关联的评估模型的用户输出名称。声明
Objective-C
- (nullable NSArray<NSString *> *)getEvalOutputNamesWithError: (NSError *_Nullable *_Nullable)error;
Swift
func getEvalOutputNames() throws -> [String]
参数
error
发生错误时设置的可选错误信息。
返回值
评估模型的用户输出名称。
-
为训练会话注册一个线性学习率调度器。
调度器在训练过程中将学习率从初始值逐渐降低到零。降低通过将当前学习率乘以线性更新因子来实现。在降低之前,学习率在预热阶段从零逐渐增加到初始值。
声明
Objective-C
- (BOOL) registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount totalStepCount:(int64_t)totalStepCount initialLr:(float)initialLr error:(NSError *_Nullable *_Nullable) error;
Swift
func registerLinearLRScheduler(withWarmupStepCount warmupStepCount: Int64, totalStepCount: Int64, initialLr: Float) throws
参数
warmupStepCount
执行线性预热的步数。
totalStepCount
执行线性衰减的总步数。
initialLr
初始学习率。
error
发生错误时设置的可选错误信息。
返回值
如果调度器成功注册,则为 YES,否则为 NO。
-
根据已注册的学习率调度器更新学习率。
执行一个调度器步骤,更新训练会话正在使用的学习率。通常应在每轮调用优化器步骤之前或在需要更新训练会话正在使用的学习率时调用此函数。
注意
必须先注册一个有效的预定义学习率调度器才能调用此方法。
声明
Objective-C
- (BOOL)schedulerStepWithError:(NSError *_Nullable *_Nullable)error;
Swift
func schedulerStep() throws
参数
error
发生错误时设置的可选错误信息。
返回值
如果调度器步骤成功执行,则为 YES,否则为 NO。
-
返回训练会话当前使用的学习率。
声明
Objective-C
- (float)getLearningRateWithError:(NSError *_Nullable *_Nullable)error;
Swift
func getLearningRate() throws -> Float
参数
error
发生错误时设置的可选错误信息。
返回值
当前学习率,如果发生错误则为 0.0f。
-
设置训练会话正在使用的学习率。
当前学习率由训练会话维护,并可通过调用此方法并提供所需学习率来覆盖。当已注册有效的学习率调度器时,不应使用此函数。它应仅用于设置来自自定义学习率调度器的学习率,或设置在整个训练会话期间使用的恒定学习率。
注意
它不会设置预定义学习率调度器可能需要的初始学习率。要为学习率调度器设置初始学习率,请使用
registerLinearLRScheduler
方法。声明
Objective-C
- (BOOL)setLearningRate:(float)lr error:(NSError *_Nullable *_Nullable)error;
Swift
func setLearningRate(_ lr: Float) throws
参数
lr
训练会话将使用的学习率。
error
发生错误时设置的可选错误信息。
返回值
如果学习率成功设置,则为 YES,否则为 NO。
-
返回一个连续缓冲区,其中包含所有训练状态参数的副本。
声明
Objective-C
- (nullable ORTValue *)toBufferWithTrainable:(BOOL)onlyTrainable error: (NSError *_Nullable *_Nullable)error;
Swift
func toBuffer(withTrainable onlyTrainable: Bool) throws -> ORTValue
参数
onlyTrainable
如果为 YES,则返回一个仅包含可训练参数的缓冲区,否则返回包含所有参数的缓冲区。
error
发生错误时设置的可选错误信息。
返回值
一个包含所有训练状态参数副本的连续缓冲区。
-
导出可用于推理的训练会话模型。
如果训练会话提供了评估模型,则训练会话可以在知道推理图输出的情况下生成推理模型。输入的推理图输出用于修剪评估模型,以便推理模型的输出与提供的输出对齐。导出的模型保存在提供的路径中,并可与
ORTSession
一起用于推理。注意
此方法从提供给初始化程序的路径重新加载评估模型,并期望此路径有效。
声明
Objective-C
- (BOOL) exportModelForInferenceWithOutputPath:(nonnull NSString *)inferenceModelPath graphOutputNames: (nonnull NSArray<NSString *> *)graphOutputNames error:(NSError *_Nullable *_Nullable)error;
Swift
func exportModelForInference(withOutputPath inferenceModelPath: String, graphOutputNames: [String]) throws
参数
inferenceModelPath
序列化推理模型的路径。
graphOutputNames
推理模型中所需的输出名称。
error
发生错误时设置的可选错误信息。
返回值
如果推理模型成功导出,则为 YES,否则为 NO。