ONNX Runtime generate() Java API
注意:此 API 处于预览阶段,可能会发生变化。
- 安装和导入
- SimpleGenAI 类
- Exception 类
- Model 类
- Tokenizer 类
- TokenizerStream 类
- Tensor 类
- GeneratorParams 类
- Generator 类
- Sequences 类
- Adapter 类
安装和导入
此 Java API 由 ai.onnxruntime.genai Java 包提供。包的发布正在进行中。要从源代码构建此包,请参阅从源代码构建指南。
import ai.onnxruntime.genai.*;
SimpleGenAI 类
SimpleGenAI
类提供了 GenAI API 的一个简单用法示例。它适用于根据提示生成文本的模型,一次处理一个提示。用法
使用模型路径创建该类的实例。该路径还应包含 GenAI 配置文件。
SimpleGenAI genAI = new SimpleGenAI(folderPath);
使用提示文本调用 createGeneratorParams。如有需要,使用 setSearchOption
通过 GeneratorParams 对象设置任何其他搜索选项。
GeneratorParams generatorParams = genAI.createGeneratorParams(promptText);
// .. set additional generator params before calling generate()
使用 GeneratorParams 对象和可选的监听器调用 generate。
String fullResponse = genAI.generate(generatorParams, listener);
监听器用作回调机制,以便在生成令牌时使用它们。创建一个实现 Consumer<String>
接口的类,并将该类的一个实例作为 listener
参数提供。
构造函数
public SimpleGenAI(String modelPath) throws GenAIException
抛出
GenAIException
- 失败时抛出。
generate 方法
根据 GeneratorParams 中的提示和设置生成文本。
注意:这仅处理单个输入序列(即单个提示,相当于批处理大小为 1)。
public String generate(GeneratorParams generatorParams, Consumer<String> listener) throws GenAIException
参数
generatorParams
: 运行模型所使用的提示和设置。listener
: 可选回调,用于在生成令牌时提供令牌。
注意:令牌生成将被阻塞,直到监听器的 accept
方法返回。
抛出
GenAIException
- 失败时抛出。
返回
生成的文本。
示例
SimpleGenAI generator = new SimpleGenAI(modelPath);
GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?");
Consumer<String> listener = token -> logger.info("onTokenGenerate: " + token);
String result = generator.generate(params, listener);
logger.info("Result: " + result);
createGenerateParams 方法
创建生成器参数并添加提示文本。用户可以在运行 generate
之前通过 GeneratorParams 对象设置其他搜索选项。
public GeneratorParams createGeneratorParams(String prompt) throws GenAIException
参数
prompt
: 要编码的提示文本。
抛出
GenAIException
- 失败时抛出。
返回
生成器参数。
Exception 类
一个包含由原生层产生的错误消息和代码的异常。
构造函数
public GenAIException(String message)
示例
catch (GenAIException e) {
throw new GenAIException("Token generation loop failed.", e);
}
Model 类
构造函数
Model(String modelPath)
createTokenizer 方法
为此模型创建一个 Tokenizer 实例。模型包含确定要使用的分词器的配置信息。
public Tokenizer createTokenizer() throws GenAIException
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败
返回
分词器实例。
generate 方法
public Sequences generate(GeneratorParams generatorParams) throws GenAIException
参数
generatorParams
: 生成器参数。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
生成的序列。
示例
Sequences output = model.generate(generatorParams);
createGeneratorParams 方法
创建一个 GeneratorParams 实例,用于执行模型。
注意:GeneratorParams 内部使用 Model,因此 Model 实例必须保持有效。
public GeneratorParams createGeneratorParams() throws GenAIException
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
GeneratorParams 实例。
示例
GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?");
Tokenizer 类
encode 方法
将字符串编码为令牌 ID 序列。
public Sequences encode(String string) throws GenAIException
参数
string
: 要编码为令牌 ID 的文本。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
一个包含单个序列的 Sequences 对象。
示例
Sequences encodedPrompt = tokenizer.encode(prompt);
decode 方法
将令牌 ID 序列解码为文本。
public String decode(int[] sequence) throws GenAIException
参数
sequence
: 要解码为文本的令牌 ID 集合
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
序列的文本表示。
示例
String result = tokenizer.decode(output_ids);
encodeBatch 方法
将字符串数组编码为每个输入的令牌 ID 序列。
public Sequences encodeBatch(String[] strings) throws GenAIException
参数
strings
: 要编码为令牌 ID 的字符串集合。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
一个 Sequences 对象,每个输入字符串包含一个序列。
示例
Sequences encoded = tokenizer.encodeBatch(inputs);
decodeBatch 方法
将一批令牌 ID 序列解码为文本。
public String[] decodeBatch(Sequences sequences) throws GenAIException
参数
sequences
: 一个包含一个或多个令牌 ID 序列的 Sequences 对象。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
一个字符串数组,包含每个序列的文本表示。
示例
String[] decoded = tokenizer.decodeBatch(encoded);
createStream 方法
创建一个 TokenizerStream 对象用于流式分词。这与 Generator 类一起使用,以便在生成每个令牌时提供它。
public TokenizerStream createStream() throws GenAIException
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
新的 TokenizerStream 实例。
TokenizerStream 类
此类的作用是在使用 Generator.generateNextToken 时转换单个令牌。
decode 方法
public String decode(int token) throws GenAIException
参数
token
: 令牌的 int 值
抛出
GenAIException
Tensor 类
使用给定数据、形状和元素类型构造 Tensor。
public Tensor(ByteBuffer data, long[] shape, ElementType elementType) throws GenAIException
参数
data
: Tensor 的数据。必须是直接 ByteBuffer。shape
: Tensor 的形状。elementType
: Tensor 中元素的类型。
抛出
GenAIException
示例
创建一个具有 32 位浮点数据的 2x2 Tensor。
long[] shape = {2, 2};
ByteBuffer data = ByteBuffer.allocateDirect(4 * Float.BYTES);
FloatBuffer floatBuffer = data.asFloatBuffer();
floatBuffer.put(new float[] {1.0f, 2.0f, 3.0f, 4.0f});
Tensor tensor = new Tensor(data, shape, Tensor.ElementType.float32);
GeneratorParams 类
GeneratorParams
类表示用于使用模型生成序列的参数。使用 setInput 设置提示,使用 setSearchOption 设置任何其他搜索选项。
创建 Generator Params 对象
GeneratorParams params = new GeneratorParams(model);
setSearchOption 方法
public void setSearchOption(String optionName, double value) throws GenAIException
抛出
GenAIException
示例
设置搜索选项以限制模型生成长度。
generatorParams.setSearchOption("max_length", 10);
setSearchOption 方法
public void setSearchOption(String optionName, boolean value) throws GenAIException
抛出
GenAIException
示例
generatorParams.setSearchOption("early_stopping", true);
setInput 方法
设置模型执行的提示。 sequences
是使用 Tokenizer.Encode 或 EncodeBatch 创建的。
public void setInput(Sequences sequences) throws GenAIException
参数
sequences
: 包含编码提示的序列。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
示例
generatorParams.setInput(encodedPrompt);
setInput 方法
设置模型执行的提示令牌 ID。tokenIds
是编码后的参数。
public void setInput(int[] tokenIds, int sequenceLength, int batchSize)
throws GenAIException
参数
tokenIds
: 编码后的提示令牌 IDsequenceLength
: 每个序列的长度。batchSize
: 批处理大小。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
注意:批处理中的所有序列长度必须相同。
示例
generatorParams.setInput(tokenIds, sequenceLength, batchSize);
Generator 类
Generator 类使用模型和生成器参数生成输出。预期用法是循环直到 isDone 返回 false。在循环内,先调用 computeLogits,然后调用 generateNextToken。
可以使用 getLastTokenInSequence 检索新生成的令牌,并使用 TokenizerStream.Decode 解码。
生成过程完成后,如有需要,可以使用 GetSequence 检索完整的生成序列。
创建 Generator
使用给定的模型和生成器参数构造 Generator 对象。
Generator(Model model, GeneratorParams generatorParams)
参数
model
: 模型。params
: 生成器参数。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
isDone 方法
检查生成过程是否完成。
public boolean isDone()
返回
如果生成过程完成,则返回 true;否则返回 false。
computeLogits 方法
计算序列中下一个令牌的 logits。
public void computeLogits() throws GenAIException
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
getSequence 方法
检索指定序列索引的令牌 ID 序列。
public int[] getSequence(long sequenceIndex) throws GenAIException
参数
sequenceIndex
: 序列的索引。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
一个整数数组,包含令牌 ID 序列。
示例
int[] outputIds = output.getSequence(i);
generateNextToken 方法
生成序列中的下一个令牌。
public void generateNextToken() throws GenAIException
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
getLastTokenInSequence 方法
检索指定序列索引的序列中最后一个令牌。
public int getLastTokenInSequence(long sequenceIndex) throws GenAIException
参数
sequenceIndex
: 序列的索引。
抛出
GenAIException
- 如果调用 GenAI 原生 API 失败。
返回
序列中的最后一个令牌。
Sequences 类
表示编码后的提示/响应集合。
numSequences 方法
获取集合中序列的数量。这等同于批处理大小。
public long numSequences()
返回
序列的数量。
示例
int numSequences = (int) sequences.numSequences();
getSequence 方法
获取指定索引处的序列。
public int[] getSequence(long sequenceIndex)
参数
sequenceIndex
: 序列的索引。
返回
序列作为整数数组。
Adapter 类
即将推出!