使用 C# 和 ONNX Runtime 推理 Stable Diffusion

在本教程中,我们将学习如何在 C# 中对流行的 Stable Diffusion 深度学习模型进行推理。Stable Diffusion 模型接受文本提示并创建表示该文本的图像。请参阅下面的示例

"make a picture of green tree with flowers around it and a red sky" 
Image of browser inferencing on sample images. Image of browser inferencing on sample images.

目录

先决条件

本教程可以在本地运行,也可以通过利用 Azure Machine Learning 计算在云中运行。

在本地运行

使用 Azure Machine Learning 在云中运行

使用 Hugging Face 下载 Stable Diffusion 模型

Hugging Face 网站拥有丰富的开源模型库。我们将利用并下载来自 Hugging Face 的 ONNX Stable Diffusion 模型

选择模型版本仓库后,点击Files and Versions,然后选择ONNX分支。如果没有可用的 ONNX 模型分支,请使用main分支并将其转换为 ONNX。有关更多信息,请参阅 PyTorch 的 ONNX 转换教程

  • 克隆仓库
    git lfs install
    git clone https://hugging-face.cn/CompVis/stable-diffusion-v1-4 -b onnx
    
  • 将包含 ONNX 文件的文件夹复制到 C# 项目文件夹 \StableDiffusion\StableDiffusion。需要复制的文件夹有:unetvae_decodertext_encodersafety_checker

使用 Hugging Face 的 Diffusers 在 Python 中理解模型

在将预构建模型投入实际使用时,花点时间理解此管道中的模型非常有用。此代码基于 Hugging Face Diffusers 库和博客。如果您想了解更多关于其工作原理的信息,请查阅这篇精彩的博客文章

使用 C# 推理

现在让我们开始分解如何在 C# 中进行推理!unet 模型接收用户提示的文本嵌入,该嵌入由连接文本和图像的 CLIP 模型创建。潜在的噪声图像被创建为起点。调度器算法和 unet 模型协同工作,对图像进行去噪,以创建表示文本提示的图像。让我们看看代码。

主函数

主函数设置提示、推理步数和指导尺度。然后调用 UNet.Inference 函数运行推理。

需要设置的属性是

  • prompt - 用于图像的文本提示
  • num_inference_steps - 运行推理的步数。步数越多,推理循环运行时间越长,但图像质量应该会提高。
  • guidance_scale - 无分类器指导的尺度。数字越高,越会尝试看起来像提示,但图像质量可能会下降。
  • batch_size - 创建图像的数量
  • height - 图像高度。默认值为 512,必须是 8 的倍数。
  • width - 图像宽度。默认值为 512,必须是 8 的倍数。

* 注意:请查阅 Hugging Face 博客以获取更多详细信息。

//Default args
var prompt = "make a picture of green tree with flowers around it and a red sky";
// Number of steps
var num_inference_steps = 10;

// Scale for classifier-free guidance
var guidance_scale = 7.5;
//num of images requested
var batch_size = 1;
// Load the tokenizer and text encoder to tokenize and encodethe text.
var textTokenized = TextProcessing.TokenizeText(prompt);
var textPromptEmbeddings = TextProcessing.TextEncode(textTokenized).ToArray();
// Create uncond_input of blank tokens
var uncondInputTokens = TextProcessing.CreateUncondInput();
var uncondEmbedding = TextProcessing.TextEncode(uncondInputTokens).ToArray();
// Concat textEmeddings and uncondEmbedding
DenseTensor<float> textEmbeddings = new DenseTensor<float>(ne[] { 2, 77, 768 });
for (var i = 0; i < textPromptEmbeddings.Length; i++)
{
    textEmbeddings[0, i / 768, i % 768] = uncondEmbedding[i];
    textEmbeddings[1, i / 768, i % 768] = textPromptEmbeddings[i];
}
var height = 512;
var width = 512;
// Inference Stable Diff
var image = UNet.Inference(num_inference_steps, textEmbeddings,guidance_scale, batch_size, height, width);
// If image failed or was unsafe it will return null.
if( image == null )
{
    Console.WriteLine("Unable to create image, please try again.");
}

使用 ONNX Runtime Extensions 进行分词

TextProcessing 类包含用于对文本提示进行分词并使用 CLIP 模型文本编码器对其进行编码的函数。

我们不必在 C# 中重新实现 CLIP 分词器,而是可以利用 ONNX Runtime Extensions 中跨平台的 CLIP 分词器实现。ONNX Runtime Extensions 有一个 custom_op_cliptok.onnx 文件分词器,用于对文本提示进行分词。该分词器是一个简单的分词器,将文本分割成单词,然后将单词转换为 token。

  • 文本提示:一个句子或短语,代表您想要创建的图像。
    make a picture of green tree with flowers aroundit and a red sky
    
  • 文本分词:文本提示被分词成 token 列表。每个 token ID 是一个表示句子中单词的数字,然后用空白 token 填充以创建最大长度为 77 的 token。然后将 token ID 转换为形状为 (1, 77) 的张量。

  • 下面是使用 ONNX Runtime Extensions 对文本提示进行分词的代码。
public static int[] TokenizeText(string text)
{
            // Create Tokenizer and tokenize the sentence.
            var tokenizerOnnxPath = Directory.GetCurrentDirectory().ToString() + ("\\text_tokenizer\\custom_op_cliptok.onnx");

            // Create session options for custom op of extensions
            using var sessionOptions = new SessionOptions();
            var customOp = "ortextensions.dll";
            sessionOptions.RegisterCustomOpLibraryV2(customOp, out var libraryHandle);

            // Create an InferenceSession from the onnx clip tokenizer.
            using var tokenizeSession = new InferenceSession(tokenizerOnnxPath, sessionOptions);

            // Create input tensor from text
            using var inputTensor = OrtValue.CreateTensorWithEmptyStrings(OrtAllocator.DefaultInstance, new long[] { 1 });
            inputTensor.StringTensorSetElementAt(text.AsSpan(), 0);

            var inputs = new Dictionary<string, OrtValue>
            {
                {  "string_input", inputTensor }
            };

            // Run session and send the input data in to get inference output. 
            using var runOptions = new RunOptions();
            using var tokens = tokenizeSession.Run(runOptions, inputs, tokenizeSession.OutputNames);

            var inputIds = tokens[0].GetTensorDataAsSpan<long>();

            // Cast inputIds to Int32
            var InputIdsInt = new int[inputIds.Length];
            for(int i = 0; i < inputIds.Length; i++)
            {
                InputIdsInt[i] = (int)inputIds[i];
            }

            Console.WriteLine(String.Join(" ", InputIdsInt));

            var modelMaxLength = 77;
            // Pad array with 49407 until length is modelMaxLength
            if (InputIdsInt.Length < modelMaxLength)
            {
                var pad = Enumerable.Repeat(49407, 77 - InputIdsInt.Length).ToArray();
                InputIdsInt = InputIdsInt.Concat(pad).ToArray();
            }
            return InputIdsInt;
}

tensor([[49406,  1078,   320,  1674,   539,  1901,  2677,   593,  4023,  1630,
           585,   537,   320,   736,  2390, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407]])

使用 CLIP 文本编码器模型进行文本嵌入

这些 token 被发送到文本编码器模型并转换为形状为 (1, 77, 768) 的张量,其中第一维是批量大小,第二维是 token 数,第三维是嵌入大小。文本编码器是一个 OpenAI CLIP 模型,它将文本与图像连接起来。

文本编码器创建文本嵌入,该嵌入经过训练,将文本提示编码成一个向量,用于指导图像生成。然后将文本嵌入与 uncond 嵌入连接起来,创建发送到 unet 模型进行推理的文本嵌入。

  • 文本嵌入:一个数字向量,代表由分词结果创建的文本提示。文本嵌入由 text_encoder 模型创建。
        public static float[] TextEncoder(int[] tokenizedInput)
        {
            // Create input tensor. OrtValue will not copy, will read from managed memory
            using var input_ids = OrtValue.CreateTensorValueFromMemory<int>(tokenizedInput,
                new long[] { 1, tokenizedInput.Count() });

            var textEncoderOnnxPath = Directory.GetCurrentDirectory().ToString() + ("\\text_encoder\\model.onnx");

            using var encodeSession = new InferenceSession(textEncoderOnnxPath);

            // Pre-allocate the output so it goes to a managed buffer
            // we know the shape
            var lastHiddenState = new float[1 * 77 * 768];
            using var outputOrtValue = OrtValue.CreateTensorValueFromMemory<float>(lastHiddenState, new long[] { 1, 77, 768 });

            string[] input_names = { "input_ids" };
            OrtValue[] inputs = { input_ids };

            string[] output_names = { encodeSession.OutputNames[0] };
            OrtValue[] outputs = { outputOrtValue };

            // Run inference.
            using var runOptions = new RunOptions();
            encodeSession.Run(runOptions, input_names, inputs, output_names, outputs);

            return lastHiddenState;
        }
torch.Size([1, 77, 768])
tensor([[[-0.3884,  0.0229, -0.0522,  ..., -0.4899, -0.3066,  0.0675],
         [ 0.0520, -0.6046,  1.9268,  ..., -0.3985,  0.9645, -0.4424],
         [-0.8027, -0.4533,  1.7525,  ..., -1.0365,  0.6296,  1.0712],
         ...,
         [-0.6833,  0.3571, -1.1353,  ..., -1.4067,  0.0142,  0.3566],
         [-0.7049,  0.3517, -1.1524,  ..., -1.4381,  0.0090,  0.3777],
         [-0.6155,  0.4283, -1.1282,  ..., -1.4256, -0.0285,  0.3206]]],

推理循环:UNet 模型、Timesteps 和 LMS 调度器

调度器

调度器算法和 unet 模型协同工作,对图像进行去噪,以创建表示文本提示的图像。可以使用不同的调度器算法,要了解更多信息,请查看 Hugging Face 的这篇博客。在本例中,我们将使用 `LMSDiscreteScheduler,它是基于 HuggingFace scheduling_lms_discrete.py 创建的。

Timesteps

推理循环是运行调度器算法和 unet 模型的主要循环。循环运行的 timesteps 数量由调度器算法根据推理步数和其他参数计算得出。

对于本例,我们有 10 个推理步,计算出以下 timesteps

// Get path to model to create inference session.
var modelPath = Directory.GetCurrentDirectory().ToString() + ("\\unet\\model.onnx");
var scheduler = new LMSDiscreteScheduler();
var timesteps = scheduler.SetTimesteps(numInferenceSteps);
tensor([999., 888., 777., 666., 555., 444., 333., 222., 111.,   0.])

Latents

latents 是用于模型输入的噪声图像张量。它使用 GenerateLatentSample 函数创建形状为 (1, 4, 64, 64) 的随机张量。seed 可以设置为随机数或固定数。如果 seed 设置为固定数,则每次都会使用相同的潜在张量。这对于调试或您希望每次创建相同的图像时非常有用。

var seed = new Random().Next();
var latents = GenerateLatentSample(batchSize, height, width,seed, scheduler.InitNoiseSigma);

Image of browser inferencing on sample images.

推理循环

对于每个推理步,潜在图像被复制以创建形状为 (2, 4, 64, 64) 的张量,然后对其进行缩放并通过 unet 模型进行推理。输出张量 (2, 4, 64, 64) 被分割并应用指导。然后将结果张量作为去噪过程的一部分发送到 LMSDiscreteScheduler 步,并返回调度器步骤的结果张量,循环再次完成,直到达到 num_inference_steps

var modelPath = Directory.GetCurrentDirectory().ToString() + ("\\unet\\model.onnx");
var scheduler = new LMSDiscreteScheduler();
var timesteps = scheduler.SetTimesteps(numInferenceSteps);

var seed = new Random().Next();
var latents = GenerateLatentSample(batchSize, height, width, seed, scheduler.InitNoiseSigma);

// Create Inference Session
using var options = new SessionOptions();
using var unetSession = new InferenceSession(modelPath, options);

var latentInputShape = new int[] { 2, 4, height / 8, width / 8 };
var splitTensorsShape = new int[] { 1, 4, height / 8, width / 8 };

for (int t = 0; t < timesteps.Length; t++)
{
    // torch.cat([latents] * 2)
    var latentModelInput = TensorHelper.Duplicate(latents.ToArray(), latentInputShape);

    // Scale the input
    latentModelInput = scheduler.ScaleInput(latentModelInput, timesteps[t]);

    // Create model input of text embeddings, scaled latent image and timestep
    var input = CreateUnetModelInput(textEmbeddings, latentModelInput, timesteps[t]);

    // Run Inference
    using var output = unetSession.Run(input);
    var outputTensor = output[0].Value as DenseTensor<float>;

    // Split tensors from 2,4,64,64 to 1,4,64,64
    var splitTensors = TensorHelper.SplitTensor(outputTensor, splitTensorsShape);
    var noisePred = splitTensors.Item1;
    var noisePredText = splitTensors.Item2;

    // Perform guidance
    noisePred = performGuidance(noisePred, noisePredText, guidanceScale);

    // LMS Scheduler Step
    latents = scheduler.Step(noisePred, timesteps[t], latents);
}

使用 VAEDecoder 对 output 进行后处理

推理循环完成后,对结果张量进行缩放,然后发送到 vae_decoder 模型以解码图像。最后,将解码后的图像张量转换为图像并保存到磁盘。

public static Tensor<float> Decoder(List<NamedOnnxValue> input)
{
    // Load the model which will be used to decode the latents into image space. 
   var vaeDecoderModelPath = Directory.GetCurrentDirectory().ToString() + ("\\vae_decoder\\model.onnx");
    
    // Create an InferenceSession from the Model Path.
    var vaeDecodeSession = new InferenceSession(vaeDecoderModelPath);

   // Run session and send the input data in to get inference output. 
    var output = vaeDecodeSession.Run(input);
    var result = (output.ToList().First().Value as Tensor<float>);
    return result;
}

public static Image<Rgba32> ConvertToImage(Tensor<float> output, int width = 512, int height = 512, string imageName = "sample")
{
    var result = new Image<Rgba32>(width, height);
    for (var y = 0; y < height; y++)
    {
        for (var x = 0; x < width; x++)
        {
            result[x, y] = new Rgba32(
                (byte)(Math.Round(Math.Clamp((output[0, 0, y, x] / 2 + 0.5), 0, 1) * 255)),
                (byte)(Math.Round(Math.Clamp((output[0, 1, y, x] / 2 + 0.5), 0, 1) * 255)),
                (byte)(Math.Round(Math.Clamp((output[0, 2, y, x] / 2 + 0.5), 0, 1) * 255))
            );
        }
    }
    result.Save($@"C:/code/StableDiffusion/{imageName}.png");
    return result;
}

结果图像

image

结论

这是在 C# 中运行 Stable Diffusion 的一个高级概述。它涵盖了主要概念并提供了如何实现的示例。要获取完整代码,请查看 Stable Diffusion C# 示例

资源