ORTTrainingSession

Objective-C

@interface ORTTrainingSession : NSObject

Swift

class ORTTrainingSession : NSObject

训练器类,提供用于训练、评估和优化 ONNX 模型的方法。

训练会话需要四种训练工件

  1. 训练 ONNX 模型
  2. 评估 ONNX 模型(可选)
  3. 优化器 ONNX 模型
  4. 检查点目录

onnxruntime-training python 工具可用于生成上述训练工件。

自 1.16 版本可用。

注意

此类别仅在启用训练 API 时可用。
  • 不可用

    声明

    Objective-C

    - (instancetype)init NS_UNAVAILABLE;
  • 从训练工件创建训练会话,可用于开始或恢复训练。

    此初始化器根据提供的环境和会话选项实例化训练会话,可用于从给定的检查点状态开始或恢复训练。检查点状态表示训练会话的参数,如果需要,这些参数将被移动到会话选项中指定的设备。

    注意

    请注意,使用检查点状态创建的训练会话将此状态用于存储整个训练状态(包括模型参数、其梯度、优化器状态和属性)。训练会话会持有检查点状态的强(拥有)指针。

    声明

    Objective-C

    - (nullable instancetype)initWithEnv:(nonnull ORTEnv *)env
                          sessionOptions:
                              (nullable ORTSessionOptions *)sessionOptions
                              checkpoint:(nonnull ORTCheckpoint *)checkpoint
                          trainModelPath:(nonnull NSString *)trainModelPath
                           evalModelPath:(nullable NSString *)evalModelPath
                      optimizerModelPath:(nullable NSString *)optimizerModelPath
                                   error:(NSError *_Nullable *_Nullable)error;

    Swift

    init(env: ORTEnv, sessionOptions: ORTSessionOptions?, checkpoint: ORTCheckpoint, trainModelPath: String, evalModelPath: String?, optimizerModelPath: String?) throws

    参数

    env

    用于训练会话的 ORTEnv 实例。

    sessionOptions

    用于训练会话的可选 ORTSessionOptions

    checkpoint

    用作训练起点的训练状态。

    trainModelPath

    训练 ONNX 模型的路径。

    evalModelPath

    评估 ONNX 模型的路径。

    optimizerModelPath

    用于执行梯度下降的优化器 ONNX 模型的路径。

    error

    如果发生错误,设置可选的错误信息。

    返回值

    实例,如果发生错误则为 nil。

  • 执行一个训练步骤,相当于一个步骤中的前向和后向传播。

    训练步骤计算训练模型的输出和给定输入值下可训练参数的梯度。训练步骤是根据提供给训练会话的训练模型执行的。它等同于在一个步骤中运行前向和后向传播。计算出的梯度存储在训练会话状态中,以便后续可由 optimizerStep 消耗。可以通过调用 lazyResetGrad 方法延迟重置梯度。

    声明

    Objective-C

    - (nullable NSArray<ORTValue *> *)
        trainStepWithInputValues:(nonnull NSArray<ORTValue *> *)inputs
                           error:(NSError *_Nullable *_Nullable)error;

    Swift

    func trainStep(withInputValues inputs: [ORTValue]) throws -> [ORTValue]

    参数

    inputs

    训练模型的输入值。

    error

    如果发生错误,设置可选的错误信息。

    返回值

    训练模型的输出值。

  • 执行一个评估步骤,计算给定输入下评估模型的输出。评估步骤是根据提供给训练会话的评估模型执行的。

    声明

    Objective-C

    - (nullable NSArray<ORTValue *> *)
        evalStepWithInputValues:(nonnull NSArray<ORTValue *> *)inputs
                          error:(NSError *_Nullable *_Nullable)error;

    Swift

    func evalStep(withInputValues inputs: [ORTValue]) throws -> [ORTValue]

    参数

    inputs

    评估模型的输入值。

    error

    如果发生错误,设置可选的错误信息。

    返回值

    评估模型的输出值。

  • 延迟将所有可训练参数的梯度重置为零。

    调用此方法会设置训练会话的内部状态,以便在下次调用 trainStep 方法计算新梯度之前,将 ORTCheckpoint 中可训练参数的梯度安排重置。

    声明

    Objective-C

    - (BOOL)lazyResetGradWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func lazyResetGrad() throws

    参数

    error

    如果发生错误,设置可选的错误信息。

    返回值

    如果梯度成功重置则为 YES,否则为 NO。

  • 使用优化器模型对可训练参数执行权重更新。优化器步骤是根据提供给训练会话的优化器模型执行的。更新后的参数存储在训练状态中,以便下次调用 trainStep 方法时使用。

    声明

    Objective-C

    - (BOOL)optimizerStepWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func optimizerStep() throws

    参数

    error

    如果发生错误,设置可选的错误信息。

    返回值

    如果优化器步骤成功执行则为 YES,否则为 NO。

  • 返回训练模型的用户输入名称,这些名称可与提供给 trainStepORTValue 相关联。

    声明

    Objective-C

    - (nullable NSArray<NSString *> *)getTrainInputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getTrainInputNames() throws -> [String]

    参数

    error

    如果发生错误,设置可选的错误信息。

    返回值

    训练模型的用户输入名称。

  • 返回评估模型的用户输入名称,这些名称可与提供给 evalStepORTValue 相关联。

    声明

    Objective-C

    - (nullable NSArray<NSString *> *)getEvalInputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getEvalInputNames() throws -> [String]

    参数

    error

    如果发生错误,设置可选的错误信息。

    返回值

    评估模型的用户输入名称。

  • 返回训练模型的用户输出名称,这些名称可与 trainStep 返回的 ORTValue 相关联。

    声明

    Objective-C

    - (nullable NSArray<NSString *> *)getTrainOutputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getTrainOutputNames() throws -> [String]

    参数

    error

    如果发生错误,设置可选的错误信息。

    返回值

    训练模型的用户输出名称。

  • 返回评估模型的用户输出名称,这些名称可与 evalStep 返回的 ORTValue 相关联。

    声明

    Objective-C

    - (nullable NSArray<NSString *> *)getEvalOutputNamesWithError:
        (NSError *_Nullable *_Nullable)error;

    Swift

    func getEvalOutputNames() throws -> [String]

    参数

    error

    如果发生错误,设置可选的错误信息。

    返回值

    评估模型的用户输出名称。

  • 为训练会话注册一个线性学习率调度器。

    调度器在训练过程中将学习率从初始值逐渐降低到零。降低是通过将当前学习率乘以一个线性更新因子来执行的。在降低之前,学习率在预热阶段从零逐渐增加到初始值。

    声明

    Objective-C

    - (BOOL)
        registerLinearLRSchedulerWithWarmupStepCount:(int64_t)warmupStepCount
                                      totalStepCount:(int64_t)totalStepCount
                                           initialLr:(float)initialLr
                                               error:(NSError *_Nullable *_Nullable)
                                                         error;

    Swift

    func registerLinearLRScheduler(withWarmupStepCount warmupStepCount: Int64, totalStepCount: Int64, initialLr: Float) throws

    参数

    warmupStepCount

    执行线性预热的步数。

    totalStepCount

    执行线性衰减的总步数。

    initialLr

    初始学习率。

    error

    如果发生错误,设置可选的错误信息。

    返回值

    如果调度器成功注册则为 YES,否则为 NO。

  • 根据已注册的学习率调度器更新学习率。

    执行一个调度器步骤,更新训练会话正在使用的学习率。此函数通常应在每轮调用优化器步骤之前调用,或根据需要更新训练会话正在使用的学习率。

    注意

    必须首先注册一个有效的预定义学习率调度器才能调用此方法。

    声明

    Objective-C

    - (BOOL)schedulerStepWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func schedulerStep() throws

    参数

    error

    如果发生错误,设置可选的错误信息。

    返回值

    如果调度器步骤成功执行则为 YES,否则为 NO。

  • 返回训练会话当前使用的学习率。

    声明

    Objective-C

    - (float)getLearningRateWithError:(NSError *_Nullable *_Nullable)error;

    Swift

    func getLearningRate() throws -> Float

    参数

    error

    如果发生错误,设置可选的错误信息。

    返回值

    当前学习率,如果发生错误则为 0.0f。

  • 设置训练会话正在使用的学习率。

    当前学习率由训练会话维护,并可通过调用此方法并传入所需学习率来覆盖。当注册了有效的学习率调度器时,不应使用此函数。它应仅用于设置自定义学习率调度器派生的学习率,或设置在整个训练会话中使用的恒定学习率。

    注意

    它不设置预定义学习率调度器可能需要的初始学习率。要为学习率调度器设置初始学习率,请使用 registerLinearLRScheduler 方法。

    声明

    Objective-C

    - (BOOL)setLearningRate:(float)lr error:(NSError *_Nullable *_Nullable)error;

    Swift

    func setLearningRate(_ lr: Float) throws

    参数

    lr

    训练会话将使用的学习率。

    error

    如果发生错误,设置可选的错误信息。

    返回值

    如果学习率成功设置则为 YES,否则为 NO。

  • 从连续缓冲区加载训练会话模型参数。

    声明

    Objective-C

    - (BOOL)fromBufferWithValue:(nonnull ORTValue *)buffer
                          error:(NSError *_Nullable *_Nullable)error;

    Swift

    func fromBuffer(with buffer: ORTValue) throws

    参数

    buffer

    用于加载参数的连续缓冲区。

    error

    如果发生错误,设置可选的错误信息。

    返回值

    如果参数成功加载则为 YES,否则为 NO。

  • 返回一个包含所有训练状态参数副本的连续缓冲区。

    声明

    Objective-C

    - (nullable ORTValue *)toBufferWithTrainable:(BOOL)onlyTrainable
                                           error:
                                               (NSError *_Nullable *_Nullable)error;

    Swift

    func toBuffer(withTrainable onlyTrainable: Bool) throws -> ORTValue

    参数

    onlyTrainable

    如果为 YES,则返回一个仅包含可训练参数的缓冲区;否则,返回一个包含所有参数的缓冲区。

    error

    如果发生错误,设置可选的错误信息。

    返回值

    包含所有训练状态参数副本的连续缓冲区。

  • 导出可用于推理的训练会话模型。

    如果训练会话提供了评估模型,并且已知推理图输出,则训练会话可以生成推理模型。输入的推理图输出用于修剪评估模型,以便推理模型的输出与提供的输出对齐。导出的模型保存在提供的路径中,并可与 ORTSession 一起用于推理。

    注意

    此方法从提供给初始化器的路径重新加载评估模型,并要求此路径有效。

    声明

    Objective-C

    - (BOOL)
        exportModelForInferenceWithOutputPath:(nonnull NSString *)inferenceModelPath
                             graphOutputNames:
                                 (nonnull NSArray<NSString *> *)graphOutputNames
                                        error:(NSError *_Nullable *_Nullable)error;

    Swift

    func exportModelForInference(withOutputPath inferenceModelPath: String, graphOutputNames: [String]) throws

    参数

    inferenceModelPath

    推理模型的序列化路径。

    graphOutputNames

    推理模型中所需的输出名称。

    error

    如果发生错误,设置可选的错误信息。

    返回值

    如果推理模型成功导出则为 YES,否则为 NO。