使用 C# 的推理 BERT NLP 深度学习和 ONNX Runtime
在本教程中,我们将学习如何在 C# 中对流行的 BERT 自然语言处理深度学习模型进行推理。
为了能够在 C# 中预处理我们的文本,我们将利用开源的 BERTTokenizers,其中包括大多数 BERT 模型的 tokenizer。请参阅下文了解支持的模型。
- BERT Base
- BERT Large
- BERT German
- BERT Multilingual
- BERT Base Uncased
- BERT Large Uncased
有许多模型(包括本教程中的模型)是基于这些基本模型进行微调的。模型的 tokenizer 仍然与微调所基于的基础模型相同。
目录
先决条件
本教程可以在本地运行,也可以通过利用 Azure 机器学习计算运行。
在本地运行
在云端使用 Azure 机器学习运行
使用 Hugging Face 下载 BERT 模型
Hugging Face 具有出色的 API,用于下载开源模型,然后我们可以使用 python 和 Pytorch 将它们导出为 ONNX 格式。当使用尚未成为 ONNX 模型动物园一部分的开源模型时,这是一个很好的选择。
在 Python 中下载和导出模型的步骤
使用 transformers
API 下载名为 bert-large-uncased-whole-word-masking-finetuned-squad
的 BertForQuestionAnswering
模型
import torch
from transformers import BertForQuestionAnswering
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
model_path = "./" + model_name + ".onnx"
model = BertForQuestionAnswering.from_pretrained(model_name)
# set the model to inference mode
# It is important to call torch_model.eval() or torch_model.train(False) before exporting the model,
# to turn the model to inference mode. This is required since operators like dropout or batchnorm
# behave differently in inference and training mode.
model.eval()
现在我们已经下载了模型,我们需要将其导出为 ONNX
格式。这内置于 Pytorch 中,通过 torch.onnx.export
函数。
-
inputs
变量指示输入形状。您可以创建如下所示的虚拟输入,或使用来自模型测试的示例输入。 -
将
opset_version
设置为模型支持的最高兼容版本。了解有关 opset 版本的更多信息 此处。 -
为模型设置
input_names
和output_names
。 -
为动态长度输入设置
dynamic_axes
,因为对于每个问题推理,sentence
和context
变量的长度将不同。
# Generate dummy inputs to the model. Adjust if necessary.
inputs = {
# list of numerical ids for the tokenized text
'input_ids': torch.randint(32, [1, 32], dtype=torch.long),
# dummy list of ones
'attention_mask': torch.ones([1, 32], dtype=torch.long),
# dummy list of ones
'token_type_ids': torch.ones([1, 32], dtype=torch.long)
}
symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
torch.onnx.export(model,
# model being run
(inputs['input_ids'],
inputs['attention_mask'],
inputs['token_type_ids']), # model input (or a tuple for multiple inputs)
model_path, # where to save the model (can be a file or file-like object)
opset_version=11, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input_ids',
'input_mask',
'segment_ids'], # the model's input names
output_names=['start_logits', "end_logits"], # the model's output names
dynamic_axes={'input_ids': symbolic_names,
'input_mask' : symbolic_names,
'segment_ids' : symbolic_names,
'start_logits' : symbolic_names,
'end_logits': symbolic_names}) # variable length axes/dynamic input
了解 Python 中的模型
当采用预构建模型并将其投入运行时,花一些时间了解模型的预处理和后处理以及输入/输出形状和标签非常有用。许多模型都提供了 Python 中的示例代码。我们将使用 C# 推理我们的模型,但首先让我们测试一下,看看如何在 Python 中完成。这将有助于我们下一步的 C# 逻辑。
-
测试模型的代码在 本教程 中提供。查看在 Python 中测试和推理此模型的源代码。以下是运行模型的示例
input
句子和示例output
。 -
示例
input
input = "{\"question\": \"What is Dolly Parton's middle name?\", \"context\": \"Dolly Rebecca Parton is an American singer-songwriter\"}"
print(run(input))
- 这是上述问题的输出结果。您可以使用
input_ids
来验证 C# 中的 tokenization。
Output:
{'input_ids': [101, 2054, 2003, 19958, 2112, 2239, 1005, 1055, 2690, 2171, 1029, 102, 19958, 9423, 2112, 2239, 2003, 2019, 2137, 3220, 1011, 6009, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
{'answer': 'Rebecca'}
使用 C# 进行推理
现在我们已经在 Python 中测试了模型,是时候在 C# 中构建它了。我们需要做的第一件事是创建我们的项目。在此示例中,我们将使用控制台应用程序,但是您可以在任何 C# 应用程序中使用此代码。
- 打开 Visual Studio 并 创建控制台应用程序
安装 Nuget 包
- 安装 Nuget 包
BERTTokenizers
、Microsoft.ML.OnnxRuntime
、Microsoft.ML.OnnxRuntime.Managed
、Microsoft.ML
dotnet add package Microsoft.ML.OnnxRuntime --version 1.16.0 dotnet add package Microsoft.ML.OnnxRuntime.Managed --version 1.16.0 dotnet add package Microsoft.ML dotnet add package BERTTokenizers --version 1.1.0
创建应用程序
- 导入包
using BERTTokenizers;
using Microsoft.ML.Data;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
- 添加
namespace
、class
和Main
函数。
namespace MyApp // Note: actual namespace depends on the project name.
{
internal class BertTokenizeProgram
{
static void Main(string[] args)
{
}
}
}
创建用于编码的 BertInput 类
- 添加
BertInput
结构
public struct BertInput
{
public long[] InputIds { get; set; }
public long[] AttentionMask { get; set; }
public long[] TypeIds { get; set; }
}
使用 BertUncasedLargeTokenizer
对句子进行 tokenization
- 创建一个句子(问题和上下文)并使用
BertUncasedLargeTokenizer
对句子进行 tokenization。基本模型是bert-large-uncased
,因此我们使用库中的BertUncasedLargeTokenizer
。请务必检查您的 BERT 模型的基础模型是什么,以确认您使用的是正确的 tokenizer。
var sentence = "{\"question\": \"Where is Bob Dylan From?\", \"context\": \"Bob Dylan is from Duluth, Minnesota and is an American singer-songwriter\"}";
Console.WriteLine(sentence);
// Create Tokenizer and tokenize the sentence.
var tokenizer = new BertUncasedLargeTokenizer();
// Get the sentence tokens.
var tokens = tokenizer.Tokenize(sentence);
// Console.WriteLine(String.Join(", ", tokens));
// Encode the sentence and pass in the count of the tokens in the sentence.
var encoded = tokenizer.Encode(tokens.Count(), sentence);
// Break out encoding to InputIds, AttentionMask and TypeIds from list of (input_id, attention_mask, type_id).
var bertInput = new BertInput()
{
InputIds = encoded.Select(t => t.InputIds).ToArray(),
AttentionMask = encoded.Select(t => t.AttentionMask).ToArray(),
TypeIds = encoded.Select(t => t.TokenTypeIds).ToArray(),
};
创建推理所需的 name -> OrtValue
对的 inputs
- 获取模型,在输入缓冲区之上创建 3 个 OrtValue,并将它们包装到 Dictionary 中以馈送到 Run()。请注意,几乎所有的 Onnxruntime 类都包装了本机数据结构,因此,必须进行释放以防止内存泄漏。
// Get path to model to create inference session.
var modelPath = @"C:\code\bert-nlp-csharp\BertNlpTest\BertNlpTest\bert-large-uncased-finetuned-qa.onnx";
using var runOptions = new RunOptions();
using var session = new InferenceSession(modelPath);
// Create input tensors over the input data.
using var inputIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputIds,
new long[] { 1, bertInput.InputIds.Length });
using var attMaskOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.AttentionMask,
new long[] { 1, bertInput.AttentionMask.Length });
using var typeIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.TypeIds,
new long[] { 1, bertInput.TypeIds.Length });
// Create input data for session. Request all outputs in this case.
var inputs = new Dictionary<string, OrtValue>
{
{ "input_ids", inputIdsOrtValue },
{ "input_mask", attMaskOrtValue },
{ "segment_ids", typeIdsOrtValue }
};
运行推理
- 创建
InferenceSession
,运行推理并打印出结果。
// Run session and send the input data in to get inference output.
using var output = session.Run(runOptions, inputs, session.OutputNames);
后处理 output
并打印结果
- 在这里,我们获取起始位置 (
startLogit
) 和结束位置 (endLogits
) 的索引。然后,我们获取输入句子的原始tokens
,并获取预测的 token id 的词汇表值。
// Get the Index of the Max value from the output lists.
// We intentionally do not copy to an array or to a list to employ algorithms.
// Hopefully, more algos will be available in the future for spans.
// so we can directly read from native memory and do not duplicate data that
// can be large for some models
// Local function
int GetMaxValueIndex(ReadOnlySpan<float> span)
{
float maxVal = span[0];
int maxIndex = 0;
for (int i = 1; i < span.Length; ++i)
{
var v = span[i];
if (v > maxVal)
{
maxVal = v;
maxIndex = i;
}
}
return maxIndex;
}
var startLogits = output[0].GetTensorDataAsSpan<float>();
int startIndex = GetMaxValueIndex(startLogits);
var endLogits = output[output.Count - 1].GetTensorDataAsSpan<float>();
int endIndex = GetMaxValueIndex(endLogits);
var predictedTokens = tokens
.Skip(startIndex)
.Take(endIndex + 1 - startIndex)
.Select(o => tokenizer.IdToToken((int)o.VocabularyIndex))
.ToList();
// Print the result.
Console.WriteLine(String.Join(" ", predictedTokens));
使用 Azure Web App 部署
在本示例中,我们创建了一个简单的控制台应用程序,但是这可以很容易地在 C# Web 应用程序之类的东西中实现。查看有关如何 快速入门:部署 ASP.NET Web 应用程序 的文档。
后续步骤
有许多不同的 BERT 模型已经针对不同的任务和不同的基础模型进行了微调,您可以针对您的特定任务进行微调。此代码适用于大多数 BERT 模型,只需更新您的特定模型的输入、输出和预处理/后处理即可。