迁移 ONNX Runtime generate() API 从 0.5.2 到 0.6.0
了解如何从 ONNX Runtime generate() 版本 0.5.2 迁移到版本 0.6.0。
版本 0.6.0 增加了对“聊天模式”的支持,也称为续写、连续解码和交互式解码。随着聊天模式的引入,API 发生了重大更改。
总而言之,新 API 在 Generator 中添加了一个 AppendTokens 方法,该方法允许进行多轮对话。以前,输入是在创建 Generator 之前在 GeneratorParams 中设置的。
在对话循环之外调用 AppendTokens 可以用于实现系统提示缓存。
注意:聊天模式和系统提示缓存仅支持批量大小为 1。此外,它们目前在 CPU、使用 CUDA EP 的 NVIDIA GPU 以及所有使用 Web GPU 原生 EP 的 GPU 上受支持。它们在 NPU 或使用 DirecML EP 运行的 GPU 上不受支持。对于问答(Q&A)模式,仍然需要下面描述的迁移。
Python
迁移 Python 问答(单轮)代码到 0.6.0
- 在创建 generator 对象后,将调用
params.input_ids = input_tokens
替换为generator.append_tokens(input_tokens)
。 - 删除对
generator.compute_logits()
的调用 - 如果应用程序有问答循环,请在
append_token
调用之间删除generator
以重置模型的状态。
向 Python 应用程序添加系统提示缓存
-
创建系统提示并对其进行标记化,然后调用
generator.append_tokens(system_tokens)
。此调用可以在询问用户提示之前完成。system_tokens = tokenizer.encode(system_prompt) generator.append_tokens(system_tokens)
向 Python 应用程序添加聊天模式
-
在您的应用程序中创建一个循环,并在用户每次提供新输入时调用
generator.append_tokens(prompt)
while True: user_input = input("Input: ") input_tokens = tokenizer.encode(user_input) generator.append_tokens(input_tokens) while not generator.is_done(): generator.generate_next_token() new_token = generator.get_next_tokens()[0] print(tokenizer_stream.decode(new_token), end='', flush=True) except KeyboardInterrupt: print()
C++
迁移 C++ 问答(单轮)代码到 0.6.0
- 将调用
params->SetInputSequences(*sequences)
替换为generator->AppendTokenSequences(*sequences)
- 删除对
generator->ComputeLogits()
的调用
向 C++ 应用程序添加系统提示缓存
-
创建系统提示并对其进行标记化,然后调用
generator->AppendTokenSequences(*sequences)
。此调用可以在询问用户提示之前完成。auto sequences = OgaSequences::Create(); tokenizer->Encode(system_prompt.c_str(), *sequences); generator->AppendTokenSequences(*sequences); generator.append_tokens(system_tokens)
向您的 C++ 应用程序添加聊天模式
- 向您的应用程序添加聊天循环
std::cout << "Generating response..." << std::endl; auto params = OgaGeneratorParams::Create(*model); params->SetSearchOption("max_length", 1024); auto generator = OgaGenerator::Create(*model, *params); while (true) { std::string text; std::cout << "Prompt: " << std::endl; std::getline(std::cin, prompt); auto sequences = OgaSequences::Create(); tokenizer->Encode(prompt.c_str(), *sequences); generator->AppendTokenSequences(*sequences); while (!generator->IsDone()) { generator->GenerateNextToken(); const auto num_tokens = generator->GetSequenceCount(0); const auto new_token = generator->GetSequenceData(0)[num_tokens - 1]; std::cout << tokenizer_stream->Decode(new_token) << std::flush; } }
C#
迁移 C# 问答(单轮)代码到 0.6.0
- 将调用
generatorParams.SetInputSequences(sequences)
替换为generator.AppendTokenSequences(sequences)
- 删除对
generator.ComputeLogits()
的调用
向您的 C# 应用程序添加系统提示缓存
-
创建系统提示并对其进行标记化,然后调用
generator->AppendTokenSequences()
。此调用可以在询问用户提示之前完成。var systemPrompt = "..." auto sequences = OgaSequences::Create(); tokenizer->Encode(systemPrompt, *sequences); generator->AppendTokenSequences(*sequences);
向您的 C# 应用程序添加聊天模式
- 向您的应用程序添加聊天循环
using var tokenizerStream = tokenizer.CreateStream(); using var generator = new Generator(model, generatorParams); Console.WriteLine("Prompt:"); prompt = Console.ReadLine(); // Example Phi-3 template var sequences = tokenizer.Encode($"<|user|>{prompt}<|end|><|assistant|>"); do { generator.AppendTokenSequences(sequences); var watch = System.Diagnostics.Stopwatch.StartNew(); while (!generator.IsDone()) { generator.GenerateNextToken(); Console.Write(tokenizerStream.Decode(generator.GetSequence(0)[^1])); } Console.WriteLine(); watch.Stop(); var runTimeInSeconds = watch.Elapsed.TotalSeconds; var outputSequence = generator.GetSequence(0); var totalTokens = outputSequence.Length; Console.WriteLine($"Streaming Tokens: {totalTokens} Time: {runTimeInSeconds:0.00} Tokens per second: {totalTokens / runTimeInSeconds:0.00}"); Console.WriteLine("Next prompt:"); var nextPrompt = Console.ReadLine(); sequences = tokenizer.Encode($"<|user|>{prompt}<|end|><|assistant|>"); } while (prompt != null);
Java
即将推出