ONNX Runtime generate() Java API

注意:此 API 处于预览阶段,可能会发生变化。

安装和导入

此 Java API 由 ai.onnxruntime.genai Java 包提供。包的发布正在进行中。要从源代码构建此包,请参阅从源代码构建指南

import ai.onnxruntime.genai.*;

SimpleGenAI 类

SimpleGenAI 类提供了 GenAI API 的一个简单用法示例。它适用于根据提示生成文本的模型,一次处理一个提示。用法

使用模型路径创建该类的实例。该路径还应包含 GenAI 配置文件。

SimpleGenAI genAI = new SimpleGenAI(folderPath);

使用提示文本调用 createGeneratorParams。如有需要,使用 setSearchOption 通过 GeneratorParams 对象设置任何其他搜索选项。

GeneratorParams generatorParams = genAI.createGeneratorParams(promptText);
// .. set additional generator params before calling generate()

使用 GeneratorParams 对象和可选的监听器调用 generate。

String fullResponse = genAI.generate(generatorParams, listener);

监听器用作回调机制,以便在生成令牌时使用它们。创建一个实现 Consumer<String> 接口的类,并将该类的一个实例作为 listener 参数提供。

构造函数

public SimpleGenAI(String modelPath) throws GenAIException

抛出

GenAIException - 失败时抛出。

generate 方法

根据 GeneratorParams 中的提示和设置生成文本。

注意:这仅处理单个输入序列(即单个提示,相当于批处理大小为 1)。

public String generate(GeneratorParams generatorParams, Consumer<String> listener) throws GenAIException

参数

  • generatorParams: 运行模型所使用的提示和设置。
  • listener: 可选回调,用于在生成令牌时提供令牌。

注意:令牌生成将被阻塞,直到监听器的 accept 方法返回。

抛出

GenAIException - 失败时抛出。

返回

生成的文本。

示例

SimpleGenAI generator = new SimpleGenAI(modelPath);
GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?");
Consumer<String> listener = token -> logger.info("onTokenGenerate: " + token);
String result = generator.generate(params, listener);

logger.info("Result: " + result);

createGenerateParams 方法

创建生成器参数并添加提示文本。用户可以在运行 generate 之前通过 GeneratorParams 对象设置其他搜索选项。

public GeneratorParams createGeneratorParams(String prompt) throws GenAIException

参数

  • prompt: 要编码的提示文本。

抛出

GenAIException - 失败时抛出。

返回

生成器参数。

Exception 类

一个包含由原生层产生的错误消息和代码的异常。

构造函数

public GenAIException(String message)

示例

catch (GenAIException e) {
  throw new GenAIException("Token generation loop failed.", e);
}

Model 类

构造函数

Model(String modelPath)

createTokenizer 方法

为此模型创建一个 Tokenizer 实例。模型包含确定要使用的分词器的配置信息。

public Tokenizer createTokenizer() throws GenAIException

抛出

GenAIException - 如果调用 GenAI 原生 API 失败

返回

分词器实例。

generate 方法

public Sequences generate(GeneratorParams generatorParams) throws GenAIException

参数

  • generatorParams: 生成器参数。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

生成的序列。

示例

Sequences output = model.generate(generatorParams);

createGeneratorParams 方法

创建一个 GeneratorParams 实例,用于执行模型。

注意:GeneratorParams 内部使用 Model,因此 Model 实例必须保持有效。

public GeneratorParams createGeneratorParams() throws GenAIException

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

GeneratorParams 实例。

示例

GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?");

Tokenizer 类

encode 方法

将字符串编码为令牌 ID 序列。

public Sequences encode(String string) throws GenAIException

参数

  • string: 要编码为令牌 ID 的文本。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

一个包含单个序列的 Sequences 对象。

示例

Sequences encodedPrompt = tokenizer.encode(prompt);

decode 方法

将令牌 ID 序列解码为文本。

public String decode(int[] sequence) throws GenAIException

参数

  • sequence: 要解码为文本的令牌 ID 集合

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

序列的文本表示。

示例

String result = tokenizer.decode(output_ids);

encodeBatch 方法

将字符串数组编码为每个输入的令牌 ID 序列。

public Sequences encodeBatch(String[] strings) throws GenAIException

参数

  • strings: 要编码为令牌 ID 的字符串集合。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

一个 Sequences 对象,每个输入字符串包含一个序列。

示例

Sequences encoded = tokenizer.encodeBatch(inputs);

decodeBatch 方法

将一批令牌 ID 序列解码为文本。

public String[] decodeBatch(Sequences sequences) throws GenAIException

参数

  • sequences: 一个包含一个或多个令牌 ID 序列的 Sequences 对象。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

一个字符串数组,包含每个序列的文本表示。

示例

String[] decoded = tokenizer.decodeBatch(encoded);

createStream 方法

创建一个 TokenizerStream 对象用于流式分词。这与 Generator 类一起使用,以便在生成每个令牌时提供它。

public TokenizerStream createStream() throws GenAIException

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

新的 TokenizerStream 实例。

TokenizerStream 类

此类的作用是在使用 Generator.generateNextToken 时转换单个令牌。

decode 方法

public String decode(int token) throws GenAIException

参数

  • token: 令牌的 int 值

抛出

GenAIException

Tensor 类

使用给定数据、形状和元素类型构造 Tensor。

public Tensor(ByteBuffer data, long[] shape, ElementType elementType) throws GenAIException

参数

  • data: Tensor 的数据。必须是直接 ByteBuffer。
  • shape: Tensor 的形状。
  • elementType: Tensor 中元素的类型。

抛出

GenAIException

示例

创建一个具有 32 位浮点数据的 2x2 Tensor。

long[] shape = {2, 2};
ByteBuffer data = ByteBuffer.allocateDirect(4 * Float.BYTES);
FloatBuffer floatBuffer = data.asFloatBuffer();
floatBuffer.put(new float[] {1.0f, 2.0f, 3.0f, 4.0f});

Tensor tensor = new Tensor(data, shape, Tensor.ElementType.float32);

GeneratorParams 类

GeneratorParams 类表示用于使用模型生成序列的参数。使用 setInput 设置提示,使用 setSearchOption 设置任何其他搜索选项。

创建 Generator Params 对象

GeneratorParams params = new GeneratorParams(model);

setSearchOption 方法

public void setSearchOption(String optionName, double value) throws GenAIException

抛出

GenAIException

示例

设置搜索选项以限制模型生成长度。

generatorParams.setSearchOption("max_length", 10);

setSearchOption 方法

public void setSearchOption(String optionName, boolean value) throws GenAIException

抛出

GenAIException

示例

generatorParams.setSearchOption("early_stopping", true);

setInput 方法

设置模型执行的提示。 sequences 是使用 Tokenizer.Encode 或 EncodeBatch 创建的。

public void setInput(Sequences sequences) throws GenAIException

参数

  • sequences: 包含编码提示的序列。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

示例

generatorParams.setInput(encodedPrompt);

setInput 方法

设置模型执行的提示令牌 ID。tokenIds 是编码后的参数。

public void setInput(int[] tokenIds, int sequenceLength, int batchSize)
 throws GenAIException

参数

  • tokenIds: 编码后的提示令牌 ID
  • sequenceLength: 每个序列的长度。
  • batchSize: 批处理大小。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

注意:批处理中的所有序列长度必须相同。

示例

generatorParams.setInput(tokenIds, sequenceLength, batchSize);

Generator 类

Generator 类使用模型和生成器参数生成输出。预期用法是循环直到 isDone 返回 false。在循环内,先调用 computeLogits,然后调用 generateNextToken。

可以使用 getLastTokenInSequence 检索新生成的令牌,并使用 TokenizerStream.Decode 解码。

生成过程完成后,如有需要,可以使用 GetSequence 检索完整的生成序列。

创建 Generator

使用给定的模型和生成器参数构造 Generator 对象。

Generator(Model model, GeneratorParams generatorParams)

参数

  • model: 模型。
  • params: 生成器参数。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

isDone 方法

检查生成过程是否完成。

public boolean isDone()

返回

如果生成过程完成,则返回 true;否则返回 false。

computeLogits 方法

计算序列中下一个令牌的 logits。

public void computeLogits() throws GenAIException

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

getSequence 方法

检索指定序列索引的令牌 ID 序列。

public int[] getSequence(long sequenceIndex) throws GenAIException

参数

  • sequenceIndex: 序列的索引。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

一个整数数组,包含令牌 ID 序列。

示例

int[] outputIds = output.getSequence(i);

generateNextToken 方法

生成序列中的下一个令牌。

public void generateNextToken() throws GenAIException

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

getLastTokenInSequence 方法

检索指定序列索引的序列中最后一个令牌。

public int getLastTokenInSequence(long sequenceIndex) throws GenAIException

参数

  • sequenceIndex: 序列的索引。

抛出

GenAIException - 如果调用 GenAI 原生 API 失败。

返回

序列中的最后一个令牌。

Sequences 类

表示编码后的提示/响应集合。

numSequences 方法

获取集合中序列的数量。这等同于批处理大小。

public long numSequences()

返回

序列的数量。

示例

int numSequences = (int) sequences.numSequences();

getSequence 方法

获取指定索引处的序列。

public int[] getSequence(long sequenceIndex)

参数

  • sequenceIndex: 序列的索引。

返回

序列作为整数数组。

Adapter 类

即将推出!