类 OrtTrainingSession
- java.lang.Object
-
- ai.onnxruntime.OrtTrainingSession
-
- 所有已实现的接口
java.lang.AutoCloseable
public final class OrtTrainingSession extends java.lang.Object implements java.lang.AutoCloseable
封装 ONNX 训练模型并允许训练和推理调用。允许检查模型的输入和输出节点。由
OrtEnvironment
生成。如果会话已关闭并调用了方法,大多数实例方法会抛出
IllegalStateException
。
-
-
方法摘要
所有方法 静态方法 实例方法 具体方法 修饰符和类型 方法 描述 void
addProperty(java.lang.String name, float value)
向此训练会话检查点添加一个浮点属性。void
addProperty(java.lang.String name, int value)
向此训练会话检查点添加一个整型属性。void
addProperty(java.lang.String name, java.lang.String value)
向此训练会话检查点添加一个字符串属性。void
close()
OrtSession.Result
evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs)
使用提供的输入执行单个评估步骤。OrtSession.Result
evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions)
使用提供的输入执行单个评估步骤。OrtSession.Result
evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs)
使用提供的输入执行单个评估步骤。OrtSession.Result
evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs)
使用提供的输入执行单个评估步骤。OrtSession.Result
evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions)
使用提供的输入执行单个评估步骤。void
exportModelForInference(java.nio.file.Path outputPath, java.lang.String[] outputNames)
将评估模型导出为适用于推理的模型,并将所需节点设置为输出节点。java.util.Set<java.lang.String>
getEvalInputNames()
返回评估模型的输入名称的有序集合。java.util.Set<java.lang.String>
getEvalOutputNames()
返回评估模型的输出名称的有序集合。float
getFloatProperty(java.lang.String name)
从此训练会话检查点获取一个浮点属性。int
getIntProperty(java.lang.String name)
从此训练会话检查点获取一个整型属性。float
getLearningRate()
获取此训练会话的当前学习率。java.lang.String
getStringProperty(java.lang.String name)
从此训练会话检查点获取一个字符串属性。java.util.Set<java.lang.String>
getTrainInputNames()
返回训练模型的输入名称的有序集合。java.util.Set<java.lang.String>
getTrainOutputNames()
返回训练模型的输出名称的有序集合。void
lazyResetGrad()
void
optimizerStep()
使用优化器模型将梯度更新应用于可训练参数。void
optimizerStep(OrtSession.RunOptions runOptions)
使用优化器模型将梯度更新应用于可训练参数。void
registerLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate)
注册一个带有线性预热的线性学习率调度器。void
saveCheckpoint(java.nio.file.Path outputPath, boolean saveOptimizer)
将训练会话状态保存到提供的检查点目录中。void
schedulerStep()
根据注册的学习率调度器更新学习率。void
setLearningRate(float learningRate)
设置训练会话的学习率。static void
setSeed(long seed)
设置 ONNX Runtime 使用的 RNG 种子。OrtSession.Result
trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs)
执行单个训练步骤,累积梯度。OrtSession.Result
trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions)
执行单个训练步骤,累积梯度。OrtSession.Result
trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs)
执行单个训练步骤,累积梯度。OrtSession.Result
trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs)
执行单个训练步骤,累积梯度。OrtSession.Result
trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions)
执行单个训练步骤,累积梯度。
-
-
-
方法详情
-
getTrainInputNames
public java.util.Set<java.lang.String> getTrainInputNames()
返回训练模型的输入名称的有序集合。- 返回值
- 训练输入。
-
getTrainOutputNames
public java.util.Set<java.lang.String> getTrainOutputNames()
返回训练模型的输出名称的有序集合。- 返回值
- 训练输出。
-
getEvalInputNames
public java.util.Set<java.lang.String> getEvalInputNames()
返回评估模型的输入名称的有序集合。- 返回值
- 评估输入。
-
getEvalOutputNames
public java.util.Set<java.lang.String> getEvalOutputNames()
返回评估模型的输出名称的有序集合。- 返回值
- 评估输出。
-
addProperty
public void addProperty(java.lang.String name, float value) throws OrtException
向此训练会话检查点添加一个浮点属性。- 参数
name
- 属性名称。value
- 属性值。- 抛出
OrtException
- 如果调用失败。
-
addProperty
public void addProperty(java.lang.String name, int value) throws OrtException
向此训练会话检查点添加一个整型属性。- 参数
name
- 属性名称。value
- 属性值。- 抛出
OrtException
- 如果调用失败。
-
addProperty
public void addProperty(java.lang.String name, java.lang.String value) throws OrtException
向此训练会话检查点添加一个字符串属性。- 参数
name
- 属性名称。value
- 属性值。- 抛出
OrtException
- 如果调用失败。
-
getFloatProperty
public float getFloatProperty(java.lang.String name) throws OrtException
从此训练会话检查点获取一个浮点属性。- 参数
name
- 属性名称。- 返回值
- 属性值。
- 抛出
OrtException
- 如果属性不存在或类型错误。
-
getIntProperty
public int getIntProperty(java.lang.String name) throws OrtException
从此训练会话检查点获取一个整型属性。- 参数
name
- 属性名称。- 返回值
- 属性值。
- 抛出
OrtException
- 如果属性不存在或类型错误。
-
getStringProperty
public java.lang.String getStringProperty(java.lang.String name) throws OrtException
从此训练会话检查点获取一个字符串属性。- 参数
name
- 属性名称。- 返回值
- 属性值。
- 抛出
OrtException
- 如果属性不存在或类型错误。
-
close
public void close()
- 指定者
close
在接口java.lang.AutoCloseable
中
-
saveCheckpoint
public void saveCheckpoint(java.nio.file.Path outputPath, boolean saveOptimizer) throws OrtException
将训练会话状态保存到提供的检查点目录中。- 参数
outputPath
- 检查点目录的路径。saveOptimizer
- 是否应保存优化器状态。- 抛出
OrtException
- 如果原生调用失败。
-
lazyResetGrad
public void lazyResetGrad() throws OrtException
确保在下次调用trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>)
之前,梯度被重置为零。注意,这是一个延迟调用,梯度是在运行下一个
trainStep(java.util.Map<java.lang.String, ? extends ai.onnxruntime.OnnxTensorLike>)
时清除的,而不是在此之前。- 抛出
OrtException
- 如果原生调用失败。
-
setSeed
public static void setSeed(long seed) throws OrtException
设置 ONNX Runtime 使用的 RNG 种子。注意,此设置在所有 OrtTrainingSession 实例中是全局的。
- 参数
seed
- RNG 种子。- 抛出
OrtException
- 如果原生调用失败。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs) throws OrtException
执行单个训练步骤,累积梯度。- 参数
inputs
- 输入(必须同时包含特征和目标)。- 返回值
- 训练步骤产生的所有输出。
- 抛出
OrtException
- 如果原生调用失败。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
执行单个训练步骤,累积梯度。- 参数
inputs
- 输入(必须同时包含特征和目标)。runOptions
- 控制此特定调用的运行选项。- 返回值
- 训练步骤产生的所有输出。
- 抛出
OrtException
- 如果原生调用失败。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs) throws OrtException
执行单个训练步骤,累积梯度。- 参数
inputs
- 输入(必须同时包含特征和目标)。requestedOutputs
- 请求的输出。- 返回值
- 训练步骤产生的请求输出。
- 抛出
OrtException
- 如果原生调用失败。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs) throws OrtException
执行单个训练步骤,累积梯度。输出根据提供的映射遍历顺序排序。
注意:固定的输出不属于
OrtSession.Result
对象,并且在结果对象关闭时不会关闭。- 参数
inputs
- 输入(必须同时包含特征和目标)。pinnedOutputs
- 用户已分配的请求输出。- 返回值
- 训练步骤产生的请求输出。
- 抛出
OrtException
- 如果原生调用失败。
-
trainStep
public OrtSession.Result trainStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
执行单个训练步骤,累积梯度。输出根据提供的集合遍历顺序排序,固定的输出在前,然后是请求的输出。如果请求的输出和固定的输出中出现相同的输出名称,则抛出
IllegalArgumentException
。注意:固定的输出不属于
OrtSession.Result
对象,并且在结果对象关闭时不会关闭。- 参数
inputs
- 输入(必须同时包含特征和目标)。requestedOutputs
- ORT 将分配的请求输出。pinnedOutputs
- 用户已分配的请求输出。runOptions
- 控制此特定调用的运行选项。- 返回值
- 训练步骤产生的请求输出。
- 抛出
OrtException
- 如果原生调用失败。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs) throws OrtException
使用提供的输入执行单个评估步骤。- 参数
inputs
- 模型输入。- 返回值
- 所有模型输出。
- 抛出
OrtException
- 如果原生调用失败。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, OrtSession.RunOptions runOptions) throws OrtException
使用提供的输入执行单个评估步骤。- 参数
inputs
- 模型输入。runOptions
- 控制此特定调用的运行选项。- 返回值
- 所有模型输出。
- 抛出
OrtException
- 如果原生调用失败。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs) throws OrtException
使用提供的输入执行单个评估步骤。- 参数
inputs
- 模型输入。requestedOutputs
- 请求的输出名称。- 返回值
- 请求的输出。
- 抛出
OrtException
- 如果原生调用失败。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs) throws OrtException
使用提供的输入执行单个评估步骤。输出根据提供的映射遍历顺序排序。
注意:固定的输出不属于
OrtSession.Result
对象,并且在结果对象关闭时不会关闭。- 参数
inputs
- 用于评分的输入。pinnedOutputs
- 用户已分配的请求输出。- 返回值
- 请求的输出。
- 抛出
OrtException
- 如果原生调用失败。
-
evalStep
public OrtSession.Result evalStep(java.util.Map<java.lang.String,? extends OnnxTensorLike> inputs, java.util.Set<java.lang.String> requestedOutputs, java.util.Map<java.lang.String,? extends OnnxValue> pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException
使用提供的输入执行单个评估步骤。输出根据提供的集合遍历顺序排序,固定的输出在前,然后是请求的输出。如果请求的输出和固定的输出中出现相同的输出名称,则抛出
IllegalArgumentException
。注意:固定的输出不属于
OrtSession.Result
对象,并且在结果对象关闭时不会关闭。- 参数
inputs
- 用于评分的输入。requestedOutputs
- ORT 将分配的请求输出。pinnedOutputs
- 用户已分配的请求输出。runOptions
- 控制此特定调用的运行选项。- 返回值
- 请求的输出。
- 抛出
OrtException
- 如果原生调用失败。
-
setLearningRate
public void setLearningRate(float learningRate) throws OrtException
设置训练会话的学习率。仅当会话中没有学习率调度器时才应使用。不用于设置学习率调度器的初始学习率。
- 参数
learningRate
- 学习率。- 抛出
OrtException
- 如果调用失败。
-
getLearningRate
public float getLearningRate() throws OrtException
获取此训练会话的当前学习率。- 返回值
- 当前学习率。
- 抛出
OrtException
- 如果调用失败。
-
optimizerStep
public void optimizerStep() throws OrtException
使用优化器模型将梯度更新应用于可训练参数。- 抛出
OrtException
- 如果原生调用失败。
-
optimizerStep
public void optimizerStep(OrtSession.RunOptions runOptions) throws OrtException
使用优化器模型将梯度更新应用于可训练参数。运行选项可用于控制日志记录和提前终止调用。
- 参数
runOptions
- 控制模型执行的选项。- 抛出
OrtException
- 如果原生调用失败。
-
registerLinearLRScheduler
public void registerLinearLRScheduler(long warmupSteps, long totalSteps, float initialLearningRate) throws OrtException
注册一个带有线性预热的线性学习率调度器。- 参数
warmupSteps
- 将学习率从零增加到initialLearningRate
所需的步数。totalSteps
- 此调度器操作的总步数。initialLearningRate
- 最大学习率。- 抛出
OrtException
- 如果原生调用失败。
-
schedulerStep
public void schedulerStep() throws OrtException
根据注册的学习率调度器更新学习率。- 抛出
OrtException
- 如果原生调用失败。
-
exportModelForInference
public void exportModelForInference(java.nio.file.Path outputPath, java.lang.String[] outputNames) throws OrtException
将评估模型导出为适用于推理的模型,并将所需节点设置为输出节点。注意,此方法从提供给训练会话的路径重新加载评估模型,并且此路径必须仍然有效。
- 参数
outputPath
- 写入推理模型的路径。outputNames
- 输出节点的名称。- 抛出
OrtException
- 如果原生调用失败。
-
-