ONNX Runtime generate() Java API

注意:此 API 处于预览阶段,可能会有更改。

安装和导入

Java API 由 ai.onnxruntime.genai Java 包提供。包的发布正在进行中。要从源代码构建包,请参阅从源代码构建指南

import ai.onnxruntime.genai.*;

SimpleGenAI 类

SimpleGenAI 类提供了 GenAI API 的一个简单使用示例。它使用一个基于提示生成文本的模型,一次处理一个提示。用法

使用模型路径创建该类的实例。该路径也应包含 GenAI 配置文件。

SimpleGenAI genAI = new SimpleGenAI(folderPath);

使用提示文本调用 createGeneratorParams。根据需要使用 setSearchOption 通过 GeneratorParams 对象设置任何其他搜索选项。

GeneratorParams generatorParams = genAI.createGeneratorParams(promptText);
// .. set additional generator params before calling generate()

使用 GeneratorParams 对象和可选的监听器调用 generate。

String fullResponse = genAI.generate(generatorParams, listener);

监听器用作回调机制,以便可以在生成令牌时使用它们。创建一个实现 Consumer<String> 接口的类,并提供该类的一个实例作为 listener 参数。

构造函数

public SimpleGenAI(String modelPath) throws GenAIException

抛出

GenAIException - 失败时。

生成方法

根据 GeneratorParams 中的提示和设置生成文本。

注意:这只处理单个输入序列(即单个提示,相当于批处理大小为 1)。

public String generate(GeneratorParams generatorParams, Consumer<String> listener) throws GenAIException

参数

  • generatorParams:用于运行模型的提示和设置。
  • listener:可选的回调,用于在令牌生成时提供它们。

注意:令牌生成将被阻塞,直到监听器的 accept 方法返回。

抛出

GenAIException - 失败时。

返回

生成的文本。

示例

SimpleGenAI generator = new SimpleGenAI(modelPath);
GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?");
Consumer<String> listener = token -> logger.info("onTokenGenerate: " + token);
String result = generator.generate(params, listener);

logger.info("Result: " + result);

createGenerateParams 方法

创建生成器参数并添加提示文本。用户可以在运行 generate 之前通过 GeneratorParams 对象设置其他搜索选项。

public GeneratorParams createGeneratorParams(String prompt) throws GenAIException

参数

  • prompt:要编码的提示文本。

抛出

GenAIException - 失败时。

返回

生成器参数。

异常类

一个异常,包含由原生层生成的错误消息和代码。

构造函数

public GenAIException(String message)

示例

catch (GenAIException e) {
  throw new GenAIException("Token generation loop failed.", e);
}

模型类

构造函数

Model(String modelPath)

创建分词器方法

为此模型创建 Tokenizer 实例。模型包含确定要使用的分词器的配置信息。

public Tokenizer createTokenizer() throws GenAIException

抛出

GenAIException - 如果调用 GenAI 原生 API 失败

返回

分词器实例。

生成方法

public Sequences generate(GeneratorParams generatorParams) throws GenAIException

参数

  • generatorParams:生成器参数。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

生成的序列。

示例

Sequences output = model.generate(generatorParams);

createGeneratorParams 方法

创建用于执行模型的 GeneratorParams 实例。

注意:GeneratorParams 内部使用 Model,因此 Model 实例必须保持有效。

public GeneratorParams createGeneratorParams() throws GenAIException

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

GeneratorParams 实例。

示例

GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?");

分词器类

编码方法

将字符串编码为令牌 ID 序列。

public Sequences encode(String string) throws GenAIException

参数

  • string:要编码为令牌 ID 的文本。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

一个包含单个序列的 Sequences 对象。

示例

Sequences encodedPrompt = tokenizer.encode(prompt);

解码方法

将令牌 ID 序列解码为文本。

public String decode(int[] sequence) throws GenAIException

参数

  • sequence:要解码为文本的令牌 ID 集合

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

序列的文本表示。

示例

String result = tokenizer.decode(output_ids);

encodeBatch 方法

将字符串数组编码为每个输入的令牌 ID 序列。

public Sequences encodeBatch(String[] strings) throws GenAIException

参数

  • strings:要编码为令牌 ID 的字符串集合。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

一个 Sequences 对象,每个输入字符串包含一个序列。

示例

Sequences encoded = tokenizer.encodeBatch(inputs);

decodeBatch 方法

将一批令牌 ID 序列解码为文本。

public String[] decodeBatch(Sequences sequences) throws GenAIException

参数

  • sequences:一个 Sequences 对象,包含一个或多个令牌 ID 序列。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

一个字符串数组,其中包含每个序列的文本表示。

示例

String[] decoded = tokenizer.decodeBatch(encoded);

createStream 方法

创建 TokenizerStream 对象用于流式分词。这与 Generator 类一起使用,以便在生成每个令牌时提供它们。

public TokenizerStream createStream() throws GenAIException

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

新的 TokenizerStream 实例。

TokenizerStream 类

此类用于在使用 Generator.generateNextToken 时转换单个令牌。

解码方法

public String decode(int token) throws GenAIException

参数

  • token:令牌的 int 值

抛出

GenAIException

张量类

使用给定数据、形状和元素类型构造一个张量。

public Tensor(ByteBuffer data, long[] shape, ElementType elementType) throws GenAIException

参数

  • data:张量的数据。必须是直接 ByteBuffer。
  • shape:张量的形状。
  • elementType:张量中元素的类型。

抛出

GenAIException

示例

创建一个具有 32 位浮点数据的 2x2 张量。

long[] shape = {2, 2};
ByteBuffer data = ByteBuffer.allocateDirect(4 * Float.BYTES);
FloatBuffer floatBuffer = data.asFloatBuffer();
floatBuffer.put(new float[] {1.0f, 2.0f, 3.0f, 4.0f});

Tensor tensor = new Tensor(data, shape, Tensor.ElementType.float32);

GeneratorParams 类

GeneratorParams 类表示用于使用模型生成序列的参数。使用 setInput 设置提示,并使用 setSearchOption 设置任何其他搜索选项。

创建一个 Generator Params 对象

GeneratorParams params = new GeneratorParams(model);

setSearchOption 方法

public void setSearchOption(String optionName, double value) throws GenAIException

抛出

GenAIException

示例

设置搜索选项以限制模型生成长度。

generatorParams.setSearchOption("max_length", 10);

setSearchOption 方法

public void setSearchOption(String optionName, boolean value) throws GenAIException

抛出

GenAIException

示例

generatorParams.setSearchOption("early_stopping", true);

setInput 方法

设置模型执行的提示。 sequences 是通过使用 Tokenizer.Encode 或 EncodeBatch 创建的。

public void setInput(Sequences sequences) throws GenAIException

参数

  • sequences:包含编码提示的序列。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

示例

generatorParams.setInput(encodedPrompt);

setInput 方法

设置模型执行的提示/令牌 ID。 tokenIds 是编码后的参数。

public void setInput(int[] tokenIds, int sequenceLength, int batchSize)
 throws GenAIException

参数

  • tokenIds:编码提示的令牌 ID
  • sequenceLength:每个序列的长度。
  • batchSize:批次大小。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

注意:批次中的所有序列必须具有相同的长度。

示例

generatorParams.setInput(tokenIds, sequenceLength, batchSize);

Generator 类

Generator 类使用模型和生成器参数生成输出。预期的用法是循环直到 isDone 返回 false。在循环中,先调用 computeLogits,然后调用 generateNextToken。

新生成的令牌可以使用 getLastTokenInSequence 获取,并使用 TokenizerStream.Decode 解码。

生成过程完成后,如果需要,可以使用 GetSequence 检索完整的生成序列。

创建一个 Generator

使用给定模型和生成器参数构造一个 Generator 对象。

Generator(Model model, GeneratorParams generatorParams)

参数

  • model:模型。
  • params:生成器参数。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

isDone 方法

检查生成过程是否完成。

public boolean isDone()

返回

如果生成过程完成,则返回 true,否则返回 false。

computeLogits 方法

计算序列中下一个令牌的 logits。

public void computeLogits() throws GenAIException

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

getSequence 方法

检索指定序列索引的令牌 ID 序列。

public int[] getSequence(long sequenceIndex) throws GenAIException

参数

  • sequenceIndex:序列的索引。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

一个包含令牌 ID 序列的整数数组。

示例

int[] outputIds = output.getSequence(i);

generateNextToken 方法

生成序列中的下一个令牌。

public void generateNextToken() throws GenAIException

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

getLastTokenInSequence 方法

检索指定序列索引的序列中的最后一个令牌。

public int getLastTokenInSequence(long sequenceIndex) throws GenAIException

参数

  • sequenceIndex:序列的索引。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

序列中的最后一个令牌。

Sequences 类

表示编码提示/响应的集合。

numSequences 方法

获取集合中的序列数量。这等同于批次大小。

public long numSequences()

返回

序列的数量。

示例

int numSequences = (int) sequences.numSequences();

getSequence 方法

获取指定索引处的序列。

public int[] getSequence(long sequenceIndex)

参数

  • sequenceIndex:序列的索引。

返回

序列(整数数组形式)。

适配器类

即将推出!