在 AzureML 上使用 ONNX Runtime 部署高性能问答模型

本教程从 HuggingFace 获取 BERT 模型,将其转换为 ONNX,并通过 AzureML 使用 ONNX Runtime 部署 ONNX 模型。

在以下部分中,我们使用使用斯坦福问答数据集 (SQuAD) 数据集训练的 HuggingFace BERT 模型作为示例。您也可以训练或微调您自己的问答模型。

问答场景接收一个问题和一段称为上下文的文本,并生成答案,答案是从上下文中提取的文本字符串。此场景对问题和上下文进行标记化和编码,将输入馈送到 Transformer 模型中,并通过生成上下文中可能性最高的开始和结束标记来生成答案,然后将这些标记映射回单词。

Example question and context showing major processing units of tokenizer, BERT model, and post processing to extract indices of max start and end probabilities to produce the answer

然后,模型和评分代码使用在线终结点部署在 AzureML 上。

目录

先决条件

本教程的源代码 发布在 GitHub 上。

要在 AzureML 上运行,您需要

  • Azure 订阅
  • Azure 机器学习工作区(如果您还没有工作区,请参阅 AzureML 配置笔记本 以了解如何创建工作区)
  • Azure 机器学习 SDK
  • Azure CLI 和 Azure 机器学习 CLI 扩展(> 版本 2.2.2)

您可能还会发现以下资源很有用

如果您无权访问 AzureML 订阅,则可以在本地运行本教程。

环境

要直接安装依赖项,请运行以下命令

pip install torch
pip install transformers
pip install azureml azureml.core
pip install onnxruntime
pip install matplotlib

要从您的 conda 环境创建一个 Jupyter 内核,请运行以下命令。替换为您的内核名称。

conda install -c anaconda ipykernel
python -m ipykernel install --user --name=<kernel name>

安装 AzureML CLI 扩展,该扩展在以下部署步骤中使用

az login
az extension add --name ml
# Remove the azure-cli-ml extension if it is installed, as it is not compatible with the az ml extension
az extension remove azure-cli-ml

获取 PyTorch 模型并将其转换为 ONNX 格式

在下面的代码中,我们从 HuggingFace 获取一个针对使用 SQUAD 数据集进行问答微调的 BERT 模型。

如果您想从头开始预训练 BERT 模型,请按照 预训练 BERT 模型 中的说明进行操作。如果您想使用自己的数据集微调模型,请参考 AzureML BERT Eval SquadAzureML BERT Eval GLUE

导出模型

使用 PyTorch ONNX 导出器创建 ONNX 格式的模型,以便与 ONNX Runtime 一起运行。

import torch
from transformers import BertForQuestionAnswering

model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
model_path = "./" + model_name + ".onnx"
model = BertForQuestionAnswering.from_pretrained(model_name)

# set the model to inference mode
# It is important to call torch_model.eval() or torch_model.train(False) before exporting the model, 
# to turn the model to inference mode. This is required since operators like dropout or batchnorm 
# behave differently in inference and training mode.
model.eval()

# Generate dummy inputs to the model. Adjust if necessary
inputs = {
        'input_ids':   torch.randint(32, [1, 32], dtype=torch.long), # list of numerical ids for the tokenized text
        'attention_mask': torch.ones([1, 32], dtype=torch.long),     # dummy list of ones
        'token_type_ids':  torch.ones([1, 32], dtype=torch.long)     # dummy list of ones
    }

symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
torch.onnx.export(model,                                         # model being run
                  (inputs['input_ids'],
                   inputs['attention_mask'], 
                   inputs['token_type_ids']),                    # model input (or a tuple for multiple inputs)
                  model_path,                                    # where to save the model (can be a file or file-like object)
                  opset_version=11,                              # the ONNX version to export the model to
                  do_constant_folding=True,                      # whether to execute constant folding for optimization
                  input_names=['input_ids',
                               'input_mask', 
                               'segment_ids'],                   # the model's input names
                  output_names=['start_logits', "end_logits"],   # the model's output names
                  dynamic_axes={'input_ids': symbolic_names,
                                'input_mask' : symbolic_names,
                                'segment_ids' : symbolic_names,
                                'start_logits' : symbolic_names, 
                                'end_logits': symbolic_names})   # variable length axes

使用 ONNX Runtime 运行 ONNX 模型

以下代码使用 ONNX Runtime 运行 ONNX 模型。您可以在将其部署到 Azure 机器学习之前在本地对其进行测试。

init() 函数在启动时调用,执行一次性操作,例如创建分词器和 ONNX Runtime 会话。

当我们使用 Azure ML 终结点运行模型时,将调用 run() 函数。添加必要的 preprocess() 和 postprocess() 步骤。

为了进行本地测试和比较,您也可以运行 PyTorch 模型。

import os
import logging
import json
import numpy as np
import onnxruntime
import transformers
import torch

# The pre process function take a question and a context, and generates the tensor inputs to the model:
# - input_ids: the words in the question encoded as integers
# - attention_mask: not used in this model
# - token_type_ids: a list of 0s and 1s that distinguish between the words of the question and the words of the context
# This function also returns the words contained in the question and the context, so that the answer can be decoded into a phrase. 
def preprocess(question, context):
    encoded_input = tokenizer(question, context)
    tokens = tokenizer.convert_ids_to_tokens(encoded_input.input_ids)
    return (encoded_input.input_ids, encoded_input.attention_mask, encoded_input.token_type_ids, tokens)

# The post process function maps the list of start and end log probabilities onto a text answer, using the text tokens from the question
# and context. 
def postprocess(tokens, start, end):
    results = {}
    answer_start = np.argmax(start)
    answer_end = np.argmax(end)
    if answer_end >= answer_start:
        answer = tokens[answer_start]
        for i in range(answer_start+1, answer_end+1):
            if tokens[i][0:2] == "##":
                answer += tokens[i][2:]
            else:
                answer += " " + tokens[i]
        results['answer'] = answer.capitalize()
    else:
        results['error'] = "I am unable to find the answer to this question. Can you please ask another question?"
    return results

# Perform the one-off initialization for the prediction. The init code is run once when the endpoint is setup.
def init():
    global tokenizer, session, model

    model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
    model = transformers.BertForQuestionAnswering.from_pretrained(model_name)

    # use AZUREML_MODEL_DIR to get your deployed model(s). If multiple models are deployed, 
    # model_path = os.path.join(os.getenv('AZUREML_MODEL_DIR'), '$MODEL_NAME/$VERSION/$MODEL_FILE_NAME')
    model_dir = os.getenv('AZUREML_MODEL_DIR')
    if model_dir == None:
        model_dir = "./"
    model_path = os.path.join(model_dir, model_name + ".onnx")

    # Create the tokenizer
    tokenizer = transformers.BertTokenizer.from_pretrained(model_name)

    # Create an ONNX Runtime session to run the ONNX model
    session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])  


# Run the PyTorch model, for functional and performance comparison
def run_pytorch(raw_data):
    inputs = json.loads(raw_data)

    model.eval()

    logging.info("Question:", inputs["question"])
    logging.info("Context: ", inputs["context"])

    input_ids, input_mask, segment_ids, tokens = preprocess(inputs["question"], inputs["context"])
    model_outputs = model(torch.tensor([input_ids]),  token_type_ids=torch.tensor([segment_ids]))
    return postprocess(tokens, model_outputs.start_logits.detach().numpy(), model_outputs.end_logits.detach().numpy())

# Run the ONNX model with ONNX Runtime
def run(raw_data):
    logging.info("Request received")
    inputs = json.loads(raw_data)
    logging.info(inputs)

    # Preprocess the question and context into tokenized ids
    input_ids, input_mask, segment_ids, tokens = preprocess(inputs["question"], inputs["context"])
    logging.info("Running inference")
    
    # Format the inputs for ONNX Runtime
    model_inputs = {
        'input_ids':   [input_ids], 
        'input_mask':  [input_mask],
        'segment_ids': [segment_ids]
        }
                  
    outputs = session.run(['start_logits', 'end_logits'], model_inputs)
    logging.info("Post-processing")  

    # Post process the output of the model into an answer (or an error if the question could not be answered)
    results = postprocess(tokens, outputs[0], outputs[1])
    logging.info(results)
    return results


if __name__ == '__main__':
    init()

    input = "{\"question\": \"What is Dolly Parton's middle name?\", \"context\": \"Dolly Rebecca Parton is an American singer-songwriter\"}"

    run_pytorch(input)
    print(run(input))

通过 AzureML 使用 ONNX Runtime 部署模型

现在我们有了 ONNX 模型以及使用 ONNX Runtime 运行它的代码,我们可以使用 Azure ML 部署它。

Component diagram showing AzureML deployment with ONNX Runtime including environment, ONNX model and scoring code

检查您的环境

import azureml.core
import onnxruntime
import torch
import transformers

print("Transformers version: ", transformers.__version__)
torch_version = torch.__version__
print("Torch (ONNX exporter) version: ", torch_version)
print("Azure SDK version:", azureml.core.VERSION)
print("ONNX Runtime version: ", onnxruntime.__version__)

加载您的 Azure ML 工作区

我们首先从配置笔记本中先前创建的现有工作区实例化一个工作区对象。

请注意,以下代码假定您在笔记本的同一目录或名为 .azureml 的子目录中有一个包含订阅信息的 config.json 文件。您也可以使用 Workspace.get() 方法显式提供工作区名称、订阅名称和资源组。

import os
from azureml.core import Workspace

ws = Workspace.from_config()
print(ws.name, ws.location, ws.resource_group, ws.subscription_id, sep = '\n')
Register your model with Azure ML
Now we upload the model and register it in the workspace.

from azureml.core.model import Model

model = Model.register(model_path = model_path,                 # Name of the registered model in your workspace.
                       model_name = model_name,            # Local ONNX model to upload and register as a model
                       model_framework=Model.Framework.ONNX ,   # Framework used to create the model.
                       model_framework_version=torch_version,   # Version of ONNX used to create the model.
                       tags = {"onnx": "demo"},
                       description = "HuggingFace BERT model fine-tuned with SQuAd and exported from PyTorch",
                       workspace = ws)

显示您注册的模型

您可以列出您在此工作区中注册的所有模型。

models = ws.models
for name, m in models.items():
    print("Name:", name,"\tVersion:", m.version, "\tDescription:", m.description, m.tags)
    
#     # If you'd like to delete the models from workspace
#     model_to_delete = Model(ws, name)
#     model_to_delete.delete()

将模型和评分代码部署为 AzureML 终结点

注意:Python SDK 的终结点接口尚未公开发布,因此在本节中,我们将使用 Azure ML CLI。

yml 文件夹中有三个 YML 文件

  • env.yml:conda 环境规范,将从中生成终结点的执行环境
  • endpoint.yml:终结点规范,仅包含终结点的名称和授权方法
  • deployment.yml:部署规范,其中包含评分代码、模型和环境的规范。您可以为每个终结点创建多个部署,并将不同数量的流量路由到这些部署。对于本示例,我们将仅创建一个部署。

部署可能需要长达 15 分钟。另请注意,笔记本目录中的所有文件都将上传到构成终结点基础的 Docker 容器中,包括 ONNX 模型的任何本地副本(该模型已在上一步中部署到 AzureML)。为了减少部署时间,请在创建终结点之前删除任何大型文件的本地副本。

az ml online-endpoint create --name question-answer-ort --file yml/endpoint.yml --subscription {ws.subscription_id} --resource-group {ws.resource_group} --workspace-name {ws.name} 
az ml online-deployment create --endpoint-name question-answer-ort --name blue --file yml/deployment.yml --all-traffic --subscription {ws.subscription_id} --resource-group {ws.resource_group} --workspace-name {ws.name} 

测试已部署的终结点

以下命令运行已部署的问答模型。test-data.json 文件中有一个测试问题。您可以使用自己的问题和上下文编辑此文件。

az ml online-endpoint invoke --name question-answer-ort --request-file test-data.json --subscription {ws.subscription_id} --resource-group {ws.resource_group} --workspace-name {ws.name} 

如果您已完成到此步骤,则说明您已部署了一个使用 ONNX 模型回答问题的有效终结点。

您可以提供自己的问题和上下文来回答问题!

清理 Azure 资源

以下命令删除您已部署的 AzureML 终结点。您可能还希望清理您的 AzureML 工作区、计算和注册的模型。

az ml online-endpoint delete --name question-answer-ort --yes --subscription {ws.subscription_id} --resource-group {ws.resource_group} --workspace-name {ws.name}