ORTTrainingSession

Objective-C

@interface ORTTrainingSession : NSObject

Swift

class ORTTrainingSession : NSObject

提供用于训练、评估和优化 ONNX 模型的方法的训练器类。

训练会话需要四个训练工件

  1. 训练 onnx 模型
  2. 评估 onnx 模型 (可选)
  3. 优化器 onnx 模型
  4. 检查点目录

onnxruntime-training python 工具可用于生成上述训练工件。

自 1.16 版本起可用。

注意

此类仅在启用训练 API 时可用。
  • 不可用

    声明

    Objective-C

    - (instancetype)init NS_UNAVAILABLE;
  • 根据可用于开始或恢复训练的训练工件创建训练会话。

    此初始化程序根据提供的 env 和 session options 实例化训练会话,可用于从给定的检查点状态开始或恢复训练。检查点状态代表训练会话的参数,这些参数在需要时将被移动到 session option 中指定的设备。

    注意

    请注意,使用检查点状态创建的训练会话使用此状态存储整个训练状态(包括模型参数、其梯度、优化器状态和属性)。训练会话对检查点状态保持强(拥有)指针。

    声明

    Objective-C

    - (nullable instancetype)initWithEnv:(nonnull ORTEnv *)env
                          sessionOptions:
                              (nullable ORTSessionOptions *)sessionOptions
                              checkpoint:(nonnull ORTCheckpoint *)checkpoint
                          trainModelPath:(nonnull NSString *)trainModelPath
                           evalModelPath:(nullable NSString *)evalModelPath
                      optimizerModelPath:(nullable NSString *)optimizerModelPath
                                   error:(NSError *_Nullable *_Nullable)error;

    Swift

    init(env: ORTEnv, sessionOptions: ORTSessionOptions?, checkpoint: ORTCheckpoint, trainModelPath: String, evalModelPath: String?, optimizerModelPath: String?) throws

    参数

    env

    用于训练会话的 ORTEnv 实例。

    sessionOptions

    用于训练会话的可选 ORTSessionOptions

    checkpoint

    用作训练起点的训练状态。

    trainModelPath

    训练 onnx 模型的路径。

    evalModelPath

    评估 onnx 模型的路径。

    optimizerModelPath

    用于执行梯度下降的优化器 onnx 模型的路径。

    error

    发生错误时设置的可选错误信息。

    返回值

    实例,如果发生错误则为 nil。

  • 执行一个训练步骤,这相当于在一个步骤中完成前向和后向传播。

    训练步骤计算训练模型的输出以及给定输入值的可训练参数的梯度。训练步骤是基于提供给训练会话的训练模型执行的。这相当于在一个步骤中运行前向和后向传播。计算出的梯度存储在训练会话状态中,以便后续可被 optimizerStep 消耗。可以通过调用 lazyResetGrad 方法来延迟重置梯度。

    声明

    Objective-C

    - (nullable NSArray<ORTValue *> *)
        trainStepWithInputValues:(nonnull NSArray<ORTValue *> *)inputs
                           error:(NSError *_Nullable *_Nullable)error;

    Swift

    func trainStep(withInputValues inputs: [ORTValue]) throws -> [ORTValue]

    参数

    inputs

    训练模型的输入值。

    error

    发生错误时设置的可选错误信息。

    返回值

    训练模型的输出值。

  • 执行一个评估步骤,该步骤计算给定输入的评估模型的输出。评估步骤是基于提供给训练会话的评估模型执行的。

    声明

    Objective-C

    - (nullable NSArray<ORTValue *> *)
        evalStepWithInputValues:(nonnull NSArray<ORTValue *> *)inputs
                          error:(NSError *_Nullable *_Nullable)error;

    Swift

    func evalStep(withInputValues inputs: [ORTValue]) throws -> [ORTValue]

    参数

    inputs

    评估模型的输入值。

    error

    发生错误时设置的可选错误信息。

    返回值

    评估模型的输出值。

  • 延迟地将所有可训练参数的梯度重置为零。

    调用此方法会设置训练会话的内部状态,使得 ORTCheckpoint 中可训练参数的梯度将在下一次 trainStep 方法调用中计算新梯度之前被安排重置。

    声明

    Objective-C

    - (BOOL)lazyResetGradWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func lazyResetGrad() throws

    参数

    error

    发生错误时设置的可选错误信息。

    返回值

    如果梯度成功设置为重置,则为 YES,否则为 NO。

  • 使用优化器模型对可训练参数执行权重更新。优化器步骤是基于提供给训练会话的优化器模型执行的。更新后的参数存储在训练状态中,以便下次 trainStep 方法调用时使用。

    声明

    Objective-C

    - (BOOL)optimizerStepWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func optimizerStep() throws

    参数

    error

    发生错误时设置的可选错误信息。

    返回值

    如果优化器步骤成功执行,则为 YES,否则为 NO。

  • 返回可与提供给 trainStepORTValue 相关联的训练模型的用户输入名称。

    声明

    Objective-C

    - (nullable NSArray<NSString *> *)getTrainInputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getTrainInputNames() throws -> [String]

    参数

    error

    发生错误时设置的可选错误信息。

    返回值

    训练模型的用户输入名称。

  • 返回可与提供给 evalStepORTValue 相关联的评估模型的用户输入名称。

    声明

    Objective-C

    - (nullable NSArray<NSString *> *)getEvalInputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getEvalInputNames() throws -> [String]

    参数

    error

    发生错误时设置的可选错误信息。

    返回值

    评估模型的用户输入名称。

  • 返回可与 trainStep 返回的 ORTValue 相关联的训练模型的用户输出名称。

    声明

    Objective-C

    - (nullable NSArray<NSString *> *)getTrainOutputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getTrainOutputNames() throws -> [String]

    参数

    error

    发生错误时设置的可选错误信息。

    返回值

    训练模型的用户输出名称。

  • 返回可与 evalStep 返回的 ORTValue 相关联的评估模型的用户输出名称。

    声明

    Objective-C

    - (nullable NSArray<NSString *> *)getEvalOutputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getEvalOutputNames() throws -> [String]

    参数

    error

    发生错误时设置的可选错误信息。

    返回值

    评估模型的用户输出名称。

  • 为训练会话注册一个线性学习率调度器。

    调度器在训练过程中将学习率从初始值逐渐降低到零。降低通过将当前学习率乘以线性更新因子来实现。在降低之前,学习率在预热阶段从零逐渐增加到初始值。

    声明

    Objective-C

    - (BOOL)
        registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount
                                      totalStepCount:(int64_t)totalStepCount
                                           initialLr:(float)initialLr
                                               error:(NSError *_Nullable *_Nullable)
                                                         error;

    Swift

    func registerLinearLRScheduler(withWarmupStepCount warmupStepCount: Int64, totalStepCount: Int64, initialLr: Float) throws

    参数

    warmupStepCount

    执行线性预热的步数。

    totalStepCount

    执行线性衰减的总步数。

    initialLr

    初始学习率。

    error

    发生错误时设置的可选错误信息。

    返回值

    如果调度器成功注册,则为 YES,否则为 NO。

  • 根据已注册的学习率调度器更新学习率。

    执行一个调度器步骤,更新训练会话正在使用的学习率。通常应在每轮调用优化器步骤之前或在需要更新训练会话正在使用的学习率时调用此函数。

    注意

    必须先注册一个有效的预定义学习率调度器才能调用此方法。

    声明

    Objective-C

    - (BOOL)schedulerStepWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func schedulerStep() throws

    参数

    error

    发生错误时设置的可选错误信息。

    返回值

    如果调度器步骤成功执行,则为 YES,否则为 NO。

  • 返回训练会话当前使用的学习率。

    声明

    Objective-C

    - (float)getLearningRateWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func getLearningRate() throws -> Float

    参数

    error

    发生错误时设置的可选错误信息。

    返回值

    当前学习率,如果发生错误则为 0.0f。

  • 设置训练会话正在使用的学习率。

    当前学习率由训练会话维护,并可通过调用此方法并提供所需学习率来覆盖。当已注册有效的学习率调度器时,不应使用此函数。它应仅用于设置来自自定义学习率调度器的学习率,或设置在整个训练会话期间使用的恒定学习率。

    注意

    它不会设置预定义学习率调度器可能需要的初始学习率。要为学习率调度器设置初始学习率,请使用 registerLinearLRScheduler 方法。

    声明

    Objective-C

    - (BOOL)setLearningRate:(float)lr error:(NSError *_Nullable *_Nullable)error;

    Swift

    func setLearningRate(_ lr: Float) throws

    参数

    lr

    训练会话将使用的学习率。

    error

    发生错误时设置的可选错误信息。

    返回值

    如果学习率成功设置,则为 YES,否则为 NO。

  • 从连续缓冲区加载训练会话模型参数。

    声明

    Objective-C

    - (BOOL)fromBufferWithValue:(nonnull ORTValue *)buffer
                          error:(NSError *_Nullable *_Nullable)error;

    Swift

    func fromBuffer(with buffer: ORTValue) throws

    参数

    buffer

    从中加载参数的连续缓冲区。

    error

    发生错误时设置的可选错误信息。

    返回值

    如果参数成功加载,则为 YES,否则为 NO。

  • 返回一个连续缓冲区,其中包含所有训练状态参数的副本。

    声明

    Objective-C

    - (nullable ORTValue *)toBufferWithTrainable:(BOOL)onlyTrainable
                                           error:
                                               (NSError *_Nullable *_Nullable)error;

    Swift

    func toBuffer(withTrainable onlyTrainable: Bool) throws -> ORTValue

    参数

    onlyTrainable

    如果为 YES,则返回一个仅包含可训练参数的缓冲区,否则返回包含所有参数的缓冲区。

    error

    发生错误时设置的可选错误信息。

    返回值

    一个包含所有训练状态参数副本的连续缓冲区。

  • 导出可用于推理的训练会话模型。

    如果训练会话提供了评估模型,则训练会话可以在知道推理图输出的情况下生成推理模型。输入的推理图输出用于修剪评估模型,以便推理模型的输出与提供的输出对齐。导出的模型保存在提供的路径中,并可与 ORTSession 一起用于推理。

    注意

    此方法从提供给初始化程序的路径重新加载评估模型,并期望此路径有效。

    声明

    Objective-C

    - (BOOL)
        exportModelForInferenceWithOutputPath:(nonnull NSString *)inferenceModelPath
                             graphOutputNames:
                                 (nonnull NSArray<NSString *> *)graphOutputNames
                                        error:(NSError *_Nullable *_Nullable)error;

    Swift

    func exportModelForInference(withOutputPath inferenceModelPath: String, graphOutputNames: [String]) throws

    参数

    inferenceModelPath

    序列化推理模型的路径。

    graphOutputNames

    推理模型中所需的输出名称。

    error

    发生错误时设置的可选错误信息。

    返回值

    如果推理模型成功导出,则为 YES,否则为 NO。