ONNX Runtime generate() Java API
注意:此 API 处于预览阶段,可能会有更改。
- 安装和导入
- SimpleGenAI 类
- 异常类
- 模型类
- 分词器类
- TokenizerStream 类
- 张量类
- GeneratorParams 类
- Generator 类
- Sequences 类
- 适配器类
安装和导入
Java API 由 ai.onnxruntime.genai Java 包提供。包的发布正在进行中。要从源代码构建包,请参阅从源代码构建指南。
import ai.onnxruntime.genai.*;
SimpleGenAI 类
SimpleGenAI
类提供了 GenAI API 的一个简单使用示例。它使用一个基于提示生成文本的模型,一次处理一个提示。用法
使用模型路径创建该类的实例。该路径也应包含 GenAI 配置文件。
SimpleGenAI genAI = new SimpleGenAI(folderPath);
使用提示文本调用 createGeneratorParams。根据需要使用 setSearchOption
通过 GeneratorParams 对象设置任何其他搜索选项。
GeneratorParams generatorParams = genAI.createGeneratorParams(promptText);
// .. set additional generator params before calling generate()
使用 GeneratorParams 对象和可选的监听器调用 generate。
String fullResponse = genAI.generate(generatorParams, listener);
监听器用作回调机制,以便可以在生成令牌时使用它们。创建一个实现 Consumer<String>
接口的类,并提供该类的一个实例作为 listener
参数。
构造函数
public SimpleGenAI(String modelPath) throws GenAIException
抛出
GenAIException
- 失败时。
生成方法
根据 GeneratorParams 中的提示和设置生成文本。
注意:这只处理单个输入序列(即单个提示,相当于批处理大小为 1)。
public String generate(GeneratorParams generatorParams, Consumer<String> listener) throws GenAIException
参数
generatorParams
:用于运行模型的提示和设置。listener
:可选的回调,用于在令牌生成时提供它们。
注意:令牌生成将被阻塞,直到监听器的 accept
方法返回。
抛出
GenAIException
- 失败时。
返回
生成的文本。
示例
SimpleGenAI generator = new SimpleGenAI(modelPath);
GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?");
Consumer<String> listener = token -> logger.info("onTokenGenerate: " + token);
String result = generator.generate(params, listener);
logger.info("Result: " + result);
createGenerateParams 方法
创建生成器参数并添加提示文本。用户可以在运行 generate
之前通过 GeneratorParams 对象设置其他搜索选项。
public GeneratorParams createGeneratorParams(String prompt) throws GenAIException
参数
prompt
:要编码的提示文本。
抛出
GenAIException
- 失败时。
返回
生成器参数。
异常类
一个异常,包含由原生层生成的错误消息和代码。
构造函数
public GenAIException(String message)
示例
catch (GenAIException e) {
throw new GenAIException("Token generation loop failed.", e);
}
模型类
构造函数
Model(String modelPath)
创建分词器方法
为此模型创建 Tokenizer 实例。模型包含确定要使用的分词器的配置信息。
public Tokenizer createTokenizer() throws GenAIException
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败
返回
分词器实例。
生成方法
public Sequences generate(GeneratorParams generatorParams) throws GenAIException
参数
generatorParams
:生成器参数。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
生成的序列。
示例
Sequences output = model.generate(generatorParams);
createGeneratorParams 方法
创建用于执行模型的 GeneratorParams 实例。
注意:GeneratorParams 内部使用 Model,因此 Model 实例必须保持有效。
public GeneratorParams createGeneratorParams() throws GenAIException
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
GeneratorParams 实例。
示例
GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?");
分词器类
编码方法
将字符串编码为令牌 ID 序列。
public Sequences encode(String string) throws GenAIException
参数
string
:要编码为令牌 ID 的文本。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
一个包含单个序列的 Sequences 对象。
示例
Sequences encodedPrompt = tokenizer.encode(prompt);
解码方法
将令牌 ID 序列解码为文本。
public String decode(int[] sequence) throws GenAIException
参数
sequence
:要解码为文本的令牌 ID 集合
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
序列的文本表示。
示例
String result = tokenizer.decode(output_ids);
encodeBatch 方法
将字符串数组编码为每个输入的令牌 ID 序列。
public Sequences encodeBatch(String[] strings) throws GenAIException
参数
strings
:要编码为令牌 ID 的字符串集合。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
一个 Sequences 对象,每个输入字符串包含一个序列。
示例
Sequences encoded = tokenizer.encodeBatch(inputs);
decodeBatch 方法
将一批令牌 ID 序列解码为文本。
public String[] decodeBatch(Sequences sequences) throws GenAIException
参数
sequences
:一个 Sequences 对象,包含一个或多个令牌 ID 序列。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
一个字符串数组,其中包含每个序列的文本表示。
示例
String[] decoded = tokenizer.decodeBatch(encoded);
createStream 方法
创建 TokenizerStream 对象用于流式分词。这与 Generator 类一起使用,以便在生成每个令牌时提供它们。
public TokenizerStream createStream() throws GenAIException
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
新的 TokenizerStream 实例。
TokenizerStream 类
此类用于在使用 Generator.generateNextToken 时转换单个令牌。
解码方法
public String decode(int token) throws GenAIException
参数
token
:令牌的 int 值
抛出
GenAIException
张量类
使用给定数据、形状和元素类型构造一个张量。
public Tensor(ByteBuffer data, long[] shape, ElementType elementType) throws GenAIException
参数
data
:张量的数据。必须是直接 ByteBuffer。shape
:张量的形状。elementType
:张量中元素的类型。
抛出
GenAIException
示例
创建一个具有 32 位浮点数据的 2x2 张量。
long[] shape = {2, 2};
ByteBuffer data = ByteBuffer.allocateDirect(4 * Float.BYTES);
FloatBuffer floatBuffer = data.asFloatBuffer();
floatBuffer.put(new float[] {1.0f, 2.0f, 3.0f, 4.0f});
Tensor tensor = new Tensor(data, shape, Tensor.ElementType.float32);
GeneratorParams 类
GeneratorParams
类表示用于使用模型生成序列的参数。使用 setInput 设置提示,并使用 setSearchOption 设置任何其他搜索选项。
创建一个 Generator Params 对象
GeneratorParams params = new GeneratorParams(model);
setSearchOption 方法
public void setSearchOption(String optionName, double value) throws GenAIException
抛出
GenAIException
示例
设置搜索选项以限制模型生成长度。
generatorParams.setSearchOption("max_length", 10);
setSearchOption 方法
public void setSearchOption(String optionName, boolean value) throws GenAIException
抛出
GenAIException
示例
generatorParams.setSearchOption("early_stopping", true);
setInput 方法
设置模型执行的提示。 sequences
是通过使用 Tokenizer.Encode 或 EncodeBatch 创建的。
public void setInput(Sequences sequences) throws GenAIException
参数
sequences
:包含编码提示的序列。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
示例
generatorParams.setInput(encodedPrompt);
setInput 方法
设置模型执行的提示/令牌 ID。 tokenIds
是编码后的参数。
public void setInput(int[] tokenIds, int sequenceLength, int batchSize)
throws GenAIException
参数
tokenIds
:编码提示的令牌 IDsequenceLength
:每个序列的长度。batchSize
:批次大小。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
注意:批次中的所有序列必须具有相同的长度。
示例
generatorParams.setInput(tokenIds, sequenceLength, batchSize);
Generator 类
Generator 类使用模型和生成器参数生成输出。预期的用法是循环直到 isDone 返回 false。在循环中,先调用 computeLogits,然后调用 generateNextToken。
新生成的令牌可以使用 getLastTokenInSequence 获取,并使用 TokenizerStream.Decode 解码。
生成过程完成后,如果需要,可以使用 GetSequence 检索完整的生成序列。
创建一个 Generator
使用给定模型和生成器参数构造一个 Generator 对象。
Generator(Model model, GeneratorParams generatorParams)
参数
model
:模型。params
:生成器参数。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
isDone 方法
检查生成过程是否完成。
public boolean isDone()
返回
如果生成过程完成,则返回 true,否则返回 false。
computeLogits 方法
计算序列中下一个令牌的 logits。
public void computeLogits() throws GenAIException
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
getSequence 方法
检索指定序列索引的令牌 ID 序列。
public int[] getSequence(long sequenceIndex) throws GenAIException
参数
sequenceIndex
:序列的索引。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
一个包含令牌 ID 序列的整数数组。
示例
int[] outputIds = output.getSequence(i);
generateNextToken 方法
生成序列中的下一个令牌。
public void generateNextToken() throws GenAIException
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
getLastTokenInSequence 方法
检索指定序列索引的序列中的最后一个令牌。
public int getLastTokenInSequence(long sequenceIndex) throws GenAIException
参数
sequenceIndex
:序列的索引。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
序列中的最后一个令牌。
Sequences 类
表示编码提示/响应的集合。
numSequences 方法
获取集合中的序列数量。这等同于批次大小。
public long numSequences()
返回
序列的数量。
示例
int numSequences = (int) sequences.numSequences();
getSequence 方法
获取指定索引处的序列。
public int[] getSequence(long sequenceIndex)
参数
sequenceIndex
:序列的索引。
返回
序列(整数数组形式)。
适配器类
即将推出!