ONNX Runtime generate() C API

注意:此 API 处于预览阶段,可能会发生变化。

概述

模型 API

创建模型

从给定目录创建模型。该目录应包含一个名为 genai_config.json 的文件,该文件对应于配置规范

参数

  • 输入: config_path 模型配置文件目录的路径。路径应使用 UTF-8 编码。
  • 输出: out 创建的模型。

返回值

如果模型创建失败,返回包含错误消息的 OgaResult

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out);

销毁模型

销毁给定的模型。

参数

  • 输入: model 要销毁的模型。

返回值

void

OGA_EXPORT void OGA_API_CALL OgaDestroyModel(OgaModel* model);

生成

根据给定的生成器参数,从模型执行中生成一个包含多个标记数组的数组。

参数

  • 输入: model 用于生成的模型。
  • 输入: generator_params 用于生成的参数。
  • 输出: out 生成的标记序列。调用方在使用完这些序列后,负责使用 OgaDestroySequences 释放它们。

返回值

如果生成失败,返回包含错误消息的 OgaResult。

OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerate(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out);

分词器 API

创建分词器

参数

  • 输入: model. 应为其创建分词器的模型。

返回值

如果分词器创建失败,返回包含错误消息的 OgaResult

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateTokenizer(const OgaModel* model, OgaTokenizer** out);

销毁分词器

OGA_EXPORT void OGA_API_CALL OgaDestroyTokenizer(OgaTokenizer*);

编码

编码单个字符串并将编码后的标记序列添加到 OgaSequences 中。当不再需要 OgaSequences 时,必须使用 OgaDestroySequences 释放它。

参数

返回值

OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer*, const char* str, OgaSequences* sequences);

解码

解码单个标记序列并返回一个以 null 结尾的 utf8 字符串。out_string 必须使用 OgaDestroyString 释放。

参数

返回值

OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerDecode(const OgaTokenizer*, const int32_t* tokens, size_t token_count, const char** out_string);

批量编码

参数

  • OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncodeBatch(const OgaTokenizer*, const char** strings, size_t count, TokenSequences** out);
    

批量解码

OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerDecodeBatch(const OgaTokenizer*, const OgaSequences* tokens, const char*** out_strings);

销毁分词器字符串

OGA_EXPORT void OGA_API_CALL OgaTokenizerDestroyStrings(const char** strings, size_t count);

创建分词器流

OgaTokenizerStream 用于增量解码标记字符串,一次解码一个标记。

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateTokenizerStream(const OgaTokenizer*, OgaTokenizerStream** out);

销毁分词器流

参数

OGA_EXPORT void OGA_API_CALL OgaDestroyTokenizerStream(OgaTokenizerStream*);

解码流

在流中解码单个标记。如果这导致生成一个词,它将通过 'out' 返回。调用者负责将每个块连接起来以生成完整结果。'out' 在下一次调用 OgaTokenizerStreamDecode 之前或 OgaTokenizerStream 被销毁之前有效。

OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerStreamDecode(OgaTokenizerStream*, int32_t token, const char** out);

生成器参数 API

创建生成器参数

从给定模型创建 OgaGeneratorParams。

参数

  • 输入: model 用于生成的模型。
  • 输出: out 创建的生成器参数。

返回值

如果生成器参数创建失败,返回包含错误消息的 OgaResult

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* model, OgaGeneratorParams** out);

销毁生成器参数

销毁给定的生成器参数。

参数

  • 输入: generator_params 要销毁的生成器参数。

返回值

void

OGA_EXPORT void OGA_API_CALL OgaDestroyGeneratorParams(OgaGeneratorParams* generator_params);

设置搜索选项 (数字)

设置一个搜索选项,该选项为数字类型

参数

  • generator_params: 要设置参数的生成器参数对象
  • name: 参数名称
  • value: 要设置的值

返回值

如果生成器参数创建失败,返回包含错误消息的 OgaResult

OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchNumber(OgaGeneratorParams* generator_params, const char* name, double value);

设置搜索选项 (布尔值)

设置一个搜索选项,该选项为布尔类型。

参数

  • generator_params: 要设置参数的生成器参数对象
  • name: 参数名称
  • value: 要设置的值

返回值

如果生成器参数创建失败,返回包含错误消息的 OgaResult

OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* generator_params, const char* name, bool value);

尝试使用最大批量大小进行图捕获

图捕获将计算图中的动态元素固定为常量值。这可以在某些环境中提供更高效的执行。要在图捕获模式下执行,需要提前知道最大批量大小。如果内存不足以分配指定的最大批量大小,此函数可能会失败。

参数

  • generator_params: 要设置参数的生成器参数对象
  • max_batch_size: 要分配的最大批量大小

返回值

如果无法使用指定的批量大小配置图捕获模式,返回包含错误消息的 OgaResult

OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* generator_params, int32_t max_batch_size);

设置输入

设置生成器参数的输入 id。输入 id 用于启动生成。

参数

  • 输入: generator_params 要设置输入 id 的生成器参数。
  • 输入: input_ids 输入 id 数组,大小为 input_ids_count = batch_size * sequence_length。
  • 输入: input_ids_count 输入 id 的总数。
  • 输入: sequence_length 输入 id 的序列长度。
  • 输入: batch_size 输入 id 的批量大小。

返回值

如果输入 id 设置失败,返回包含错误消息的 OgaResult

OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams* generator_params, const int32_t* input_ids, size_t input_ids_count, size_t sequence_length, size_t batch_size);

设置输入序列

设置生成器参数的输入 id 序列。输入 id 序列用于启动生成。

参数

  • 输入: generator_params 要设置输入 id 的生成器参数。
  • 输入: sequences 输入 id 序列。

返回值

如果输入 id 序列设置失败,返回包含错误消息的 OgaResult。

OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams* generator_params, const OgaSequences* sequences);

设置模型输入

设置除 input_ids 之外的额外模型输入。

参数

  • generator_params: 要设置输入的生成器参数
  • name: 要设置的参数名称
  • tensor: 参数的值

返回值

如果输入设置失败,返回包含错误消息的 OgaResult。

OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(OgaGeneratorParams*, OgaTensor* tensor);

生成器 API

创建生成器

从给定模型和生成器参数创建生成器。

参数

  • 输入: model 用于生成的模型。
  • 输入: params 用于生成的参数。
  • 输出: out 创建的生成器。

返回值

如果生成器创建失败,返回包含错误消息的 OgaResult

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out);

销毁生成器

销毁给定的生成器。

参数

  • 输入: generator 要销毁的生成器。

返回值

void

OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* generator);

检查生成是否完成

如果生成器已完成生成所有序列,则返回 true。

参数

  • 输入: generator 要检查是否已完成生成所有序列的生成器。

返回值

如果生成器已完成生成所有序列,则返回 True,否则返回 false。

OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator);

运行模型的一次迭代

根据输入 id 和先前的状态计算模型的 logits。计算出的 logits 存储在生成器中。

参数

  • 输入: generator 要计算 logits 的生成器。

返回值

如果 logits 计算失败,返回包含错误消息的 OgaResult。

OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator);

生成下一个标记

使用配置的生成参数,根据计算出的 logits 生成下一个标记。

参数

  • 输入: generator 要为其生成下一个标记的生成器。

返回值

如果下一个标记生成失败,返回包含错误消息的 OgaResult。

OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator);

获取标记数量

返回给定索引处序列中的标记数量。

参数

  • 输入: generator 用于获取给定索引处序列标记数量的生成器。
  • 输入: index. 要返回标记的索引。

返回值

给定索引处序列中的标记数量。

OGA_EXPORT size_t OGA_API_CALL OgaGenerator_GetSequenceCount(const OgaGenerator* generator, size_t index);

获取序列

返回指向给定索引处序列数据的指针。序列中的标记数量由 OgaGenerator_GetSequenceCount 提供。

参数

  • 输入: generator 用于获取给定索引处序列数据指针的生成器。序列数据由 OgaGenerator 拥有,并在 OgaGenerator 销毁时释放。如果需要在 OgaGenerator 销毁后使用数据,调用方必须复制数据。
  • 输入: index. 要获取序列的索引。

返回值

指向标记序列的指针

OGA_EXPORT const int32_t* OGA_API_CALL OgaGenerator_GetSequenceData(const OgaGenerator* generator, size_t index);

设置运行时选项

一个用于设置运行时选项的 API,未来会向此通用 API 添加更多参数以支持运行时选项。例如,要使用此 API 终止当前会话,可以调用 SetRuntimeOption,其中 key 为 “terminate_session”,value 为 “1”:OgaGenerator_SetRuntimeOption(generator, “terminate_session”, “1”)

有关当前运行时选项的更多详细信息,请参见此处

参数

  • 输入: generator 需要设置运行时选项的生成器。
  • 输入: key 设置运行时选项的键。
  • 输入: value 提供给键的值。

返回值

void

OGA_EXPORT void OGA_API_CALL OgaGenerator_SetRuntimeOption(OgaGenerator* generator, const char* key, const char* value);

适配器 API

此 API 用于加载和切换微调适配器,例如 LoRA 适配器。

创建适配器管理器

创建用于管理适配器的对象。此对象用于加载所有模型适配器,并负责管理已加载适配器的引用计数。

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateAdapters(const OgaModel* model, OgaAdapters** out);

参数

  • model: 已创建的 OgaModel 对象

结果

  • out: 指向已创建的 OgaAdapters 列表的引用

加载适配器

从给定的适配器文件路径和适配器名称加载模型适配器。

OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadAdapter(OgaAdapters* adapters, const char* adapter_file_path, const char* adapter_name);

参数

  • adapters: 要加载适配器的 OgaAdapters 对象。
  • adapter_file_path: 要加载的适配器文件路径。
  • adapter_name: 用于查询适配器的唯一标识符

返回值

如果适配器加载失败,返回包含错误消息的 OgaResult

卸载适配器

从先前加载的适配器集合中卸载具有给定标识符的适配器。如果找不到适配器,或者无法卸载(当它正在使用时),则返回错误。

OGA_EXPORT OgaResult* OGA_API_CALL OgaUnloadAdapter(OgaAdapters* adapters, const char* adapter_name);

参数

  • adapters: 要从其卸载适配器的 OgaAdapters 对象。
  • adapter_name: 要卸载的适配器名称。

返回值

如果适配器卸载失败,返回包含错误消息的 OgaResult。如果调用此方法时,适配器尚未加载或正被仍在使用的 OgaGenerator 标记为活动状态,则可能会发生此情况。

设置活动适配器

将具有给定适配器名称的适配器设置为给定 OgaGenerator 对象的活动适配器。

OGA_EXPORT OgaResult* OGA_API_CALL OgaSetActiveAdapter(OgaGenerator* generator, OgaAdapters* adapters, const char* adapter_name);

参数

  • generator: 要设置活动适配器的 OgaGenerator 对象。
  • adapters: 管理模型适配器的 OgaAdapters 对象。
  • adapter_name: 要设置为活动的适配器名称。

返回值

如果适配器无法设置为活动状态,返回包含错误消息的 OgaResult。如果调用此方法时,适配器尚未加载,则可能会发生此情况。

枚举和结构体

typedef enum OgaDataType {
  OgaDataType_int32,
  OgaDataType_float32,
  OgaDataType_string,  // UTF8 string
} OgaDataType;
typedef struct OgaResult OgaResult;
typedef struct OgaGeneratorParams OgaGeneratorParams;
typedef struct OgaGenerator OgaGenerator;
typedef struct OgaModel OgaModel;
typedef struct OgaBuffer OgaBuffer;

实用函数

设置 GPU 设备 ID

OGA_EXPORT OgaResult* OGA_API_CALL OgaSetCurrentGpuDeviceId(int device_id);

获取 GPU 设备 ID

OGA_EXPORT OgaResult* OGA_API_CALL OgaGetCurrentGpuDeviceId(int* device_id);

获取错误消息

参数

  • 输入: result 包含错误消息的 OgaResult。

返回值

OgaResult 中包含的错误消息。const char* 由 OgaResult 拥有,并在 OgaResult 销毁时释放。

OGA_EXPORT const char* OGA_API_CALL OgaResultGetError(OgaResult* result);

销毁结果对象

参数

  • 输入: result 要销毁的 OgaResult。

返回值

void

OGA_EXPORT void OGA_API_CALL OgaDestroyResult(OgaResult*);

销毁字符串

参数

  • 输入: 要销毁的字符串

返回值

OGA_EXPORT void OGA_API_CALL OgaDestroyString(const char*);

销毁缓冲区

参数

  • 输入: 要销毁的缓冲区

返回值

void

OGA_EXPORT void OGA_API_CALL OgaDestroyBuffer(OgaBuffer*);

获取缓冲区类型

参数

  • 输入: 缓冲区

返回值

缓冲区的类型

OGA_EXPORT OgaDataType OGA_API_CALL OgaBufferGetType(const OgaBuffer*);

获取缓冲区维度数量

参数

  • 输入: 缓冲区

返回值

缓冲区中的维度数量

OGA_EXPORT size_t OGA_API_CALL OgaBufferGetDimCount(const OgaBuffer*);

获取缓冲区维度

获取缓冲区的维度

参数

  • 输入: 缓冲区
  • 输出:维度数组

返回值

OgaResult

OGA_EXPORT OgaResult* OGA_API_CALL OgaBufferGetDims(const OgaBuffer*, size_t* dims, size_t dim_count);

获取缓冲区数据

获取缓冲区的数据

参数

返回值

void

OGA_EXPORT const void* OGA_API_CALL OgaBufferGetData(const OgaBuffer*);

创建序列

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateSequences(OgaSequences** out);

销毁序列

参数

  • 输入:要销毁的序列 OgaSequences。

返回值

void

返回值

OGA_EXPORT void OGA_API_CALL OgaDestroySequences(OgaSequences* sequences);

获取序列数量

返回 OgaSequences 中的序列数量

参数

  • 输入:序列

返回值

OgaSequences 中的序列数量

OGA_EXPORT size_t OGA_API_CALL OgaSequencesCount(const OgaSequences* sequences);

获取序列中的 token 数量

返回给定索引处序列中的 token 数量

参数

  • 输入:序列

返回值

给定索引处序列中的 token 数量

OGA_EXPORT size_t OGA_API_CALL OgaSequencesGetSequenceCount(const OgaSequences* sequences, size_t sequence_index);

获取序列数据

返回指向给定索引处序列数据的指针。序列中的 token 数量由 OgaSequencesGetSequenceCount 提供

参数

  • 输入:序列

返回值

指向给定索引处序列数据的指针。该指针在 OgaSequences 被销毁之前有效。

OGA_EXPORT const int32_t* OGA_API_CALL OgaSequencesGetSequenceData(const OgaSequences* sequences, size_t sequence_index);