类 OrtTrainingSession
- java.lang.Object
-
- ai.onnxruntime.OrtTrainingSession
-
- 所有已实现的接口
java.lang.AutoCloseable
public final class OrtTrainingSession extends java.lang.Object implements java.lang.AutoCloseable封装 ONNX 训练模型并允许训练和推理调用。允许检查模型的输入和输出节点。由
OrtEnvironment生成。如果会话已关闭并调用了方法,大多数实例方法会抛出
IllegalStateException。
-
-
方法摘要
所有方法 静态方法 实例方法 具体方法 修饰符和类型 方法 描述 voidaddProperty(java.lang.String name, float value)向此训练会话检查点添加一个浮点属性。voidaddProperty(java.lang.String name, int value)向此训练会话检查点添加一个整型属性。voidaddProperty(java.lang.String name, java.lang.String value)向此训练会话检查点添加一个字符串属性。voidclose()OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs)使用提供的输入执行单个评估步骤。OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions)使用提供的输入执行单个评估步骤。OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs)使用提供的输入执行单个评估步骤。OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs)使用提供的输入执行单个评估步骤。OrtSession.ResultevalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions)使用提供的输入执行单个评估步骤。voidexportModelForInference(java.nio.file.Path outputPath, java.lang.String[] outputNames)将评估模型导出为适用于推理的模型,并将所需节点设置为输出节点。java.util.Set<java.lang.String>getEvalInputNames()返回评估模型的输入名称的有序集合。java.util.Set<java.lang.String>getEvalOutputNames()返回评估模型的输出名称的有序集合。floatgetFloatProperty(java.lang.String name)从此训练会话检查点获取一个浮点属性。intgetIntProperty(java.lang.String name)从此训练会话检查点获取一个整型属性。floatgetLearningRate()获取此训练会话的当前学习率。java.lang.StringgetStringProperty(java.lang.String name)从此训练会话检查点获取一个字符串属性。java.util.Set<java.lang.String>getTrainInputNames()返回训练模型的输入名称的有序集合。java.util.Set<java.lang.String>getTrainOutputNames()返回训练模型的输出名称的有序集合。voidlazyResetGrad()voidoptimizerStep()使用优化器模型将梯度更新应用于可训练参数。voidoptimizerStep(OrtSession.RunOptions runOptions)使用优化器模型将梯度更新应用于可训练参数。voidregisterLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate)注册一个带有线性预热的线性学习率调度器。voidsaveCheckpoint(java.nio.file.Path outputPath, boolean saveOptimizer)将训练会话状态保存到提供的检查点目录中。voidschedulerStep()根据注册的学习率调度器更新学习率。voidsetLearningRate(float learningRate)设置训练会话的学习率。static voidsetSeed(long seed)设置 ONNX Runtime 使用的 RNG 种子。OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs)执行单个训练步骤,累积梯度。OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions)执行单个训练步骤,累积梯度。OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs)执行单个训练步骤,累积梯度。OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs)执行单个训练步骤,累积梯度。OrtSession.ResulttrainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions)执行单个训练步骤,累积梯度。
-
-
-
方法详情
-
getTrainInputNames
public java.util.Set<java.lang.String> getTrainInputNames()
返回训练模型的输入名称的有序集合。- 返回值
- 训练输入。
-
getTrainOutputNames
public java.util.Set<java.lang.String> getTrainOutputNames()
返回训练模型的输出名称的有序集合。- 返回值
- 训练输出。
-
getEvalInputNames
public java.util.Set<java.lang.String> getEvalInputNames()
返回评估模型的输入名称的有序集合。- 返回值
- 评估输入。
-
getEvalOutputNames
public java.util.Set<java.lang.String> getEvalOutputNames()
返回评估模型的输出名称的有序集合。- 返回值
- 评估输出。
-
addProperty
public void addProperty(java.lang.String name, float value) throws OrtException向此训练会话检查点添加一个浮点属性。- 参数
name- 属性名称。value- 属性值。- 抛出
OrtException- 如果调用失败。
-
addProperty
public void addProperty(java.lang.String name, int value) throws OrtException向此训练会话检查点添加一个整型属性。- 参数
name- 属性名称。value- 属性值。- 抛出
OrtException- 如果调用失败。
-
addProperty
public void addProperty(java.lang.String name, java.lang.String value) throws OrtException向此训练会话检查点添加一个字符串属性。- 参数
name- 属性名称。value- 属性值。- 抛出
OrtException- 如果调用失败。
-
getFloatProperty
public float getFloatProperty(java.lang.String name) throws OrtException从此训练会话检查点获取一个浮点属性。- 参数
name- 属性名称。- 返回值
- 属性值。
- 抛出
OrtException- 如果属性不存在或类型错误。
-
getIntProperty
public int getIntProperty(java.lang.String name) throws OrtException从此训练会话检查点获取一个整型属性。- 参数
name- 属性名称。- 返回值
- 属性值。
- 抛出
OrtException- 如果属性不存在或类型错误。
-
getStringProperty
public java.lang.String getStringProperty(java.lang.String name) throws OrtException从此训练会话检查点获取一个字符串属性。- 参数
name- 属性名称。- 返回值
- 属性值。
- 抛出
OrtException- 如果属性不存在或类型错误。
-
close
public void close()
- 指定者
close在接口java.lang.AutoCloseable中
-
saveCheckpoint
public void saveCheckpoint(java.nio.file.Path outputPath, boolean saveOptimizer) throws OrtException将训练会话状态保存到提供的检查点目录中。- 参数
outputPath- 检查点目录的路径。saveOptimizer- 是否应保存优化器状态。- 抛出
OrtException- 如果原生调用失败。
-
lazyResetGrad
public void lazyResetGrad() throws OrtException确保在下次调用trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>)之前,梯度被重置为零。注意,这是一个延迟调用,梯度是在运行下一个
trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>)时清除的,而不是在此之前。- 抛出
OrtException- 如果原生调用失败。
-
setSeed
public static void setSeed(long seed) throws OrtException设置 ONNX Runtime 使用的 RNG 种子。注意,此设置在所有 OrtTrainingSession 实例中是全局的。
- 参数
seed- RNG 种子。- 抛出
OrtException- 如果原生调用失败。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs) throws OrtException
执行单个训练步骤,累积梯度。- 参数
inputs- 输入(必须同时包含特征和目标)。- 返回值
- 训练步骤产生的所有输出。
- 抛出
OrtException- 如果原生调用失败。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
执行单个训练步骤,累积梯度。- 参数
inputs- 输入(必须同时包含特征和目标)。runOptions- 控制此特定调用的运行选项。- 返回值
- 训练步骤产生的所有输出。
- 抛出
OrtException- 如果原生调用失败。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs) throws OrtException
执行单个训练步骤,累积梯度。- 参数
inputs- 输入(必须同时包含特征和目标)。requestedOutputs- 请求的输出。- 返回值
- 训练步骤产生的请求输出。
- 抛出
OrtException- 如果原生调用失败。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs) throws OrtException
执行单个训练步骤,累积梯度。输出根据提供的映射遍历顺序排序。
注意:固定的输出不属于
OrtSession.Result对象,并且在结果对象关闭时不会关闭。- 参数
inputs- 输入(必须同时包含特征和目标)。pinnedOutputs- 用户已分配的请求输出。- 返回值
- 训练步骤产生的请求输出。
- 抛出
OrtException- 如果原生调用失败。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
执行单个训练步骤,累积梯度。输出根据提供的集合遍历顺序排序,固定的输出在前,然后是请求的输出。如果请求的输出和固定的输出中出现相同的输出名称,则抛出
IllegalArgumentException。注意:固定的输出不属于
OrtSession.Result对象,并且在结果对象关闭时不会关闭。- 参数
inputs- 输入(必须同时包含特征和目标)。requestedOutputs- ORT 将分配的请求输出。pinnedOutputs- 用户已分配的请求输出。runOptions- 控制此特定调用的运行选项。- 返回值
- 训练步骤产生的请求输出。
- 抛出
OrtException- 如果原生调用失败。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs) throws OrtException
使用提供的输入执行单个评估步骤。- 参数
inputs- 模型输入。- 返回值
- 所有模型输出。
- 抛出
OrtException- 如果原生调用失败。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
使用提供的输入执行单个评估步骤。- 参数
inputs- 模型输入。runOptions- 控制此特定调用的运行选项。- 返回值
- 所有模型输出。
- 抛出
OrtException- 如果原生调用失败。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs) throws OrtException
使用提供的输入执行单个评估步骤。- 参数
inputs- 模型输入。requestedOutputs- 请求的输出名称。- 返回值
- 请求的输出。
- 抛出
OrtException- 如果原生调用失败。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs) throws OrtException
使用提供的输入执行单个评估步骤。输出根据提供的映射遍历顺序排序。
注意:固定的输出不属于
OrtSession.Result对象,并且在结果对象关闭时不会关闭。- 参数
inputs- 用于评分的输入。pinnedOutputs- 用户已分配的请求输出。- 返回值
- 请求的输出。
- 抛出
OrtException- 如果原生调用失败。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
使用提供的输入执行单个评估步骤。输出根据提供的集合遍历顺序排序,固定的输出在前,然后是请求的输出。如果请求的输出和固定的输出中出现相同的输出名称,则抛出
IllegalArgumentException。注意:固定的输出不属于
OrtSession.Result对象,并且在结果对象关闭时不会关闭。- 参数
inputs- 用于评分的输入。requestedOutputs- ORT 将分配的请求输出。pinnedOutputs- 用户已分配的请求输出。runOptions- 控制此特定调用的运行选项。- 返回值
- 请求的输出。
- 抛出
OrtException- 如果原生调用失败。
-
setLearningRate
public void setLearningRate(float learningRate) throws OrtException设置训练会话的学习率。仅当会话中没有学习率调度器时才应使用。不用于设置学习率调度器的初始学习率。
- 参数
learningRate- 学习率。- 抛出
OrtException- 如果调用失败。
-
getLearningRate
public float getLearningRate() throws OrtException获取此训练会话的当前学习率。- 返回值
- 当前学习率。
- 抛出
OrtException- 如果调用失败。
-
optimizerStep
public void optimizerStep() throws OrtException使用优化器模型将梯度更新应用于可训练参数。- 抛出
OrtException- 如果原生调用失败。
-
optimizerStep
public void optimizerStep(OrtSession.RunOptions runOptions) throws OrtException
使用优化器模型将梯度更新应用于可训练参数。运行选项可用于控制日志记录和提前终止调用。
- 参数
runOptions- 控制模型执行的选项。- 抛出
OrtException- 如果原生调用失败。
-
registerLinearLRScheduler
public void registerLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate) throws OrtException注册一个带有线性预热的线性学习率调度器。- 参数
warmupSteps- 将学习率从零增加到initialLearningRate所需的步数。totalSteps- 此调度器操作的总步数。initialLearningRate- 最大学习率。- 抛出
OrtException- 如果原生调用失败。
-
schedulerStep
public void schedulerStep() throws OrtException根据注册的学习率调度器更新学习率。- 抛出
OrtException- 如果原生调用失败。
-
exportModelForInference
public void exportModelForInference(java.nio.file.Path outputPath, java.lang.String[] outputNames) throws OrtException将评估模型导出为适用于推理的模型,并将所需节点设置为输出节点。注意,此方法从提供给训练会话的路径重新加载评估模型,并且此路径必须仍然有效。
- 参数
outputPath- 写入推理模型的路径。outputNames- 输出节点的名称。- 抛出
OrtException- 如果原生调用失败。
-
-