类 OrtTrainingSession

  • 所有已实现的接口
    java.lang.AutoCloseable

    public final class OrtTrainingSession
    extends java.lang.Object
    implements java.lang.AutoCloseable
    封装 ONNX 训练模型并允许训练和推理调用。

    允许检查模型的输入和输出节点。由 OrtEnvironment 生成。

    如果会话已关闭并调用了方法,大多数实例方法会抛出 IllegalStateException

    • 方法详情

      • getTrainInputNames

        public java.util.Set<java.lang.String> getTrainInputNames()
        返回训练模型的输入名称的有序集合。
        返回值
        训练输入。
      • getTrainOutputNames

        public java.util.Set<java.lang.String> getTrainOutputNames()
        返回训练模型的输出名称的有序集合。
        返回值
        训练输出。
      • getEvalInputNames

        public java.util.Set<java.lang.String> getEvalInputNames()
        返回评估模型的输入名称的有序集合。
        返回值
        评估输入。
      • getEvalOutputNames

        public java.util.Set<java.lang.String> getEvalOutputNames()
        返回评估模型的输出名称的有序集合。
        返回值
        评估输出。
      • addProperty

        public void addProperty​(java.lang.String name,
                                float value)
                         throws OrtException
        向此训练会话检查点添加一个浮点属性。
        参数
        name - 属性名称。
        value - 属性值。
        抛出
        OrtException - 如果调用失败。
      • addProperty

        public void addProperty​(java.lang.String name,
                                int value)
                         throws OrtException
        向此训练会话检查点添加一个整型属性。
        参数
        name - 属性名称。
        value - 属性值。
        抛出
        OrtException - 如果调用失败。
      • addProperty

        public void addProperty​(java.lang.String name,
                                java.lang.String value)
                         throws OrtException
        向此训练会话检查点添加一个字符串属性。
        参数
        name - 属性名称。
        value - 属性值。
        抛出
        OrtException - 如果调用失败。
      • getFloatProperty

        public float getFloatProperty​(java.lang.String name)
                               throws OrtException
        从此训练会话检查点获取一个浮点属性。
        参数
        name - 属性名称。
        返回值
        属性值。
        抛出
        OrtException - 如果属性不存在或类型错误。
      • getIntProperty

        public int getIntProperty​(java.lang.String name)
                           throws OrtException
        从此训练会话检查点获取一个整型属性。
        参数
        name - 属性名称。
        返回值
        属性值。
        抛出
        OrtException - 如果属性不存在或类型错误。
      • getStringProperty

        public java.lang.String getStringProperty​(java.lang.String name)
                                           throws OrtException
        从此训练会话检查点获取一个字符串属性。
        参数
        name - 属性名称。
        返回值
        属性值。
        抛出
        OrtException - 如果属性不存在或类型错误。
      • close

        public void close()
        指定者
        close 在接口 java.lang.AutoCloseable 中
      • saveCheckpoint

        public void saveCheckpoint​(java.nio.file.Path outputPath,
                                   boolean saveOptimizer)
                            throws OrtException
        将训练会话状态保存到提供的检查点目录中。
        参数
        outputPath - 检查点目录的路径。
        saveOptimizer - 是否应保存优化器状态。
        抛出
        OrtException - 如果原生调用失败。
      • setSeed

        public static void setSeed​(long seed)
                            throws OrtException
        设置 ONNX Runtime 使用的 RNG 种子。

        注意,此设置在所有 OrtTrainingSession 实例中是全局的。

        参数
        seed - RNG 种子。
        抛出
        OrtException - 如果原生调用失败。
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs)
                                    throws OrtException
        执行单个训练步骤,累积梯度。
        参数
        inputs - 输入(必须同时包含特征和目标)。
        返回值
        训练步骤产生的所有输出。
        抛出
        OrtException - 如果原生调用失败。
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                           OrtSession.RunOptions runOptions)
                                    throws OrtException
        执行单个训练步骤,累积梯度。
        参数
        inputs - 输入(必须同时包含特征和目标)。
        runOptions - 控制此特定调用的运行选项。
        返回值
        训练步骤产生的所有输出。
        抛出
        OrtException - 如果原生调用失败。
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                           java.util.Set<java.lang.String> requestedOutputs)
                                    throws OrtException
        执行单个训练步骤,累积梯度。
        参数
        inputs - 输入(必须同时包含特征和目标)。
        requestedOutputs - 请求的输出。
        返回值
        训练步骤产生的请求输出。
        抛出
        OrtException - 如果原生调用失败。
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                           java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs)
                                    throws OrtException
        执行单个训练步骤,累积梯度。

        输出根据提供的映射遍历顺序排序。

        注意:固定的输出不属于 OrtSession.Result 对象,并且在结果对象关闭时不会关闭。

        参数
        inputs - 输入(必须同时包含特征和目标)。
        pinnedOutputs - 用户已分配的请求输出。
        返回值
        训练步骤产生的请求输出。
        抛出
        OrtException - 如果原生调用失败。
      • trainStep

        public OrtSession.Result trainStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                           java.util.Set<java.lang.String> requestedOutputs,
                                           java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs,
                                           OrtSession.RunOptions runOptions)
                                    throws OrtException
        执行单个训练步骤,累积梯度。

        输出根据提供的集合遍历顺序排序,固定的输出在前,然后是请求的输出。如果请求的输出和固定的输出中出现相同的输出名称,则抛出 IllegalArgumentException

        注意:固定的输出不属于 OrtSession.Result 对象,并且在结果对象关闭时不会关闭。

        参数
        inputs - 输入(必须同时包含特征和目标)。
        requestedOutputs - ORT 将分配的请求输出。
        pinnedOutputs - 用户已分配的请求输出。
        runOptions - 控制此特定调用的运行选项。
        返回值
        训练步骤产生的请求输出。
        抛出
        OrtException - 如果原生调用失败。
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs)
                                   throws OrtException
        使用提供的输入执行单个评估步骤。
        参数
        inputs - 模型输入。
        返回值
        所有模型输出。
        抛出
        OrtException - 如果原生调用失败。
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                          OrtSession.RunOptions runOptions)
                                   throws OrtException
        使用提供的输入执行单个评估步骤。
        参数
        inputs - 模型输入。
        runOptions - 控制此特定调用的运行选项。
        返回值
        所有模型输出。
        抛出
        OrtException - 如果原生调用失败。
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                          java.util.Set<java.lang.String> requestedOutputs)
                                   throws OrtException
        使用提供的输入执行单个评估步骤。
        参数
        inputs - 模型输入。
        requestedOutputs - 请求的输出名称。
        返回值
        请求的输出。
        抛出
        OrtException - 如果原生调用失败。
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                          java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs)
                                   throws OrtException
        使用提供的输入执行单个评估步骤。

        输出根据提供的映射遍历顺序排序。

        注意:固定的输出不属于 OrtSession.Result 对象,并且在结果对象关闭时不会关闭。

        参数
        inputs - 用于评分的输入。
        pinnedOutputs - 用户已分配的请求输出。
        返回值
        请求的输出。
        抛出
        OrtException - 如果原生调用失败。
      • evalStep

        public OrtSession.Result evalStep​(java.util.Map<java.lang.String,​? extends OnnxTensorLike> inputs,
                                          java.util.Set<java.lang.String> requestedOutputs,
                                          java.util.Map<java.lang.String,​? extends OnnxValue> pinnedOutputs,
                                          OrtSession.RunOptions runOptions)
                                   throws OrtException
        使用提供的输入执行单个评估步骤。

        输出根据提供的集合遍历顺序排序,固定的输出在前,然后是请求的输出。如果请求的输出和固定的输出中出现相同的输出名称,则抛出 IllegalArgumentException

        注意:固定的输出不属于 OrtSession.Result 对象,并且在结果对象关闭时不会关闭。

        参数
        inputs - 用于评分的输入。
        requestedOutputs - ORT 将分配的请求输出。
        pinnedOutputs - 用户已分配的请求输出。
        runOptions - 控制此特定调用的运行选项。
        返回值
        请求的输出。
        抛出
        OrtException - 如果原生调用失败。
      • setLearningRate

        public void setLearningRate​(float learningRate)
                             throws OrtException
        设置训练会话的学习率。

        仅当会话中没有学习率调度器时才应使用。不用于设置学习率调度器的初始学习率。

        参数
        learningRate - 学习率。
        抛出
        OrtException - 如果调用失败。
      • getLearningRate

        public float getLearningRate()
                              throws OrtException
        获取此训练会话的当前学习率。
        返回值
        当前学习率。
        抛出
        OrtException - 如果调用失败。
      • optimizerStep

        public void optimizerStep()
                           throws OrtException
        使用优化器模型将梯度更新应用于可训练参数。
        抛出
        OrtException - 如果原生调用失败。
      • optimizerStep

        public void optimizerStep​(OrtSession.RunOptions runOptions)
                           throws OrtException
        使用优化器模型将梯度更新应用于可训练参数。

        运行选项可用于控制日志记录和提前终止调用。

        参数
        runOptions - 控制模型执行的选项。
        抛出
        OrtException - 如果原生调用失败。
      • registerLinearLRScheduler

        public void registerLinearLRScheduler​(long warmupSteps,
                                              long totalSteps,
                                              float initialLearningRate)
                                       throws OrtException
        注册一个带有线性预热的线性学习率调度器。
        参数
        warmupSteps - 将学习率从零增加到 initialLearningRate 所需的步数。
        totalSteps - 此调度器操作的总步数。
        initialLearningRate - 最大学习率。
        抛出
        OrtException - 如果原生调用失败。
      • schedulerStep

        public void schedulerStep()
                           throws OrtException
        根据注册的学习率调度器更新学习率。
        抛出
        OrtException - 如果原生调用失败。
      • exportModelForInference

        public void exportModelForInference​(java.nio.file.Path outputPath,
                                            java.lang.String[] outputNames)
                                     throws OrtException
        将评估模型导出为适用于推理的模型,并将所需节点设置为输出节点。

        注意,此方法从提供给训练会话的路径重新加载评估模型,并且此路径必须仍然有效。

        参数
        outputPath - 写入推理模型的路径。
        outputNames - 输出节点的名称。
        抛出
        OrtException - 如果原生调用失败。