在移动设备上使用机器学习超分辨率提高图像分辨率

了解如何使用 ONNX Runtime Mobile 构建应用程序以提高图像分辨率,该模型包含预处理和后处理。

您可以使用本教程为 Android 或 iOS 构建应用程序。

该应用程序接收图像输入,在单击按钮时执行超分辨率操作,并在下方显示分辨率提高的图像,如下面的屏幕截图所示。

Super resolution on a cat

目录

准备模型

本教程中使用的机器学习模型基于本页面底部引用的 PyTorch 教程中使用的模型。

我们提供了一个方便的 Python 脚本,可将 PyTorch 模型导出为 ONNX 格式,并添加预处理和后处理。

  1. 在运行此脚本之前,请安装以下 Python 包

     pip install torch
     pip install pillow
     pip install onnx
     pip install onnxruntime
     pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ onnxruntime-extensions
    

    关于版本的说明:最佳超分辨率结果是通过 ONNX opset 18(及其对具有抗锯齿功能的 Resize 运算符的支持)实现的,onnx 1.13.0 和 onnxruntime 1.14.0 及更高版本支持该版本。onnxruntime-extensions 包是预发布版本。发布版本即将推出。

  2. 然后从 onnxruntime-extensions GitHub 存储库下载脚本和测试图像(如果您尚未克隆此存储库)

     curl https://raw.githubusercontent.com/microsoft/onnxruntime-extensions/main/tutorials/superresolution_e2e.py > superresolution_e2e.py
     curl https://raw.githubusercontent.com/microsoft/onnxruntime-extensions/main/tutorials/data/super_res_input.png > data/super_res_input.png
    
  3. 运行脚本以导出核心模型并向其添加预处理和后处理

     python superresolution_e2e.py 
    

脚本运行后,您应该在运行脚本的位置的文件夹中看到两个 ONNX 文件

pytorch_superresolution.onnx
pytorch_superresolution_with_pre_and_post_processing.onnx

如果您将这两个模型加载到 netron 中,您可以看到两者之间输入和输出的差异。下面的前两张图像显示了原始模型,其输入是通道数据批次,而后两张图像显示了输入和输出是图像字节。

ONNX model without pre and post processing

ONNX model inputs and outputs without pre and post processing

ONNX model with pre and post processing

ONNX model inputs and outputs with pre and post processing

现在是编写应用程序代码的时候了。

Android 应用程序

先决条件

  • Android Studio Dolphin 2021.3.1 Patch +(安装在 Mac/Windows/Linux 上)
  • Android SDK 29+
  • Android NDK r22+
  • Android 设备或 Android 模拟器

示例代码

您可以在 GitHub 中找到 Android 超分辨率应用程序的完整源代码

要从源代码运行该应用程序,请克隆上述存储库并将 build.gradle 文件加载到 Android Studio 中,构建并运行!

要逐步构建应用程序,请按照以下部分进行操作。

从头开始编写代码

设置项目

在 Android Studio 中为手机和平板电脑创建一个新项目,然后选择空白模板。将应用程序命名为 super_resolution 或类似名称。

依赖项

将以下依赖项添加到应用程序 build.gradle

implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release'

项目资源

  1. 将模型文件添加为原始资源

    src/main/res 文件夹中创建一个名为 raw 的文件夹,并将 ONNX 模型移动或复制到 raw 文件夹中。

  2. 将测试图像添加为资产

    在主项目文件夹中创建一个名为 assets 的文件夹,并将您要运行超分辨率的图像复制到该文件夹中,文件名为 test_superresolution.png

主应用程序类代码

创建一个名为 MainActivity.kt 的文件,并将以下代码片段添加到其中。

  1. 添加导入语句

    import ai.onnxruntime.*
    import ai.onnxruntime.extensions.OrtxPackage
    import android.annotation.SuppressLint
    import android.os.Bundle
    import android.widget.Button
    import android.widget.ImageView
    import android.widget.Toast
    import androidx.activity.*
    import androidx.appcompat.app.AppCompatActivity
    import kotlinx.android.synthetic.main.activity_main.*
    import kotlinx.coroutines.*
    import java.io.InputStream
    import java.util.*
    import java.util.concurrent.ExecutorService
    import java.util.concurrent.Executors
    
  2. 创建主活动类并添加类变量

    class MainActivity : AppCompatActivity() {
        private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment()
        private lateinit var ortSession: OrtSession
        private var inputImage: ImageView? = null
        private var outputImage: ImageView? = null
        private var superResolutionButton: Button? = null
    
        ...
    }
    
  3. 添加 onCreate() 方法

    这是我们初始化 ONNX Runtime 会话 的地方。会话保存对用于在应用程序中执行推理的模型的引用。它还接受会话选项参数,您可以在其中指定不同的执行提供程序(硬件加速器,例如 NNAPI)。在本例中,我们默认在 CPU 上运行。但是,我们确实注册了自定义操作库,在该库中可以找到模型输入和输出端的图像编码和解码运算符。

     override fun onCreate(savedInstanceState: Bundle?) {
         super.onCreate(savedInstanceState)
         setContentView(R.layout.activity_main)
    
         inputImage = findViewById(R.id.imageView1)
         outputImage = findViewById(R.id.imageView2);
         superResolutionButton = findViewById(R.id.super_resolution_button)
         inputImage?.setImageBitmap(
             BitmapFactory.decodeStream(readInputImage())
         );
    
         // Initialize Ort Session and register the onnxruntime extensions package that contains the custom operators.
         // Note: These are used to decode the input image into the format the original model requires,
         // and to encode the model output into png format
         val sessionOptions: OrtSession.SessionOptions = OrtSession.SessionOptions()
         sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath())
         ortSession = ortEnv.createSession(readModel(), sessionOptions)
    
         superResolutionButton?.setOnClickListener {
             try {
                 performSuperResolution(ortSession)
                 Toast.makeText(baseContext, "Super resolution performed!", Toast.LENGTH_SHORT)
                     .show()
             } catch (e: Exception) {
                 Log.e(TAG, "Exception caught when perform super resolution", e)
                 Toast.makeText(baseContext, "Failed to perform super resolution", Toast.LENGTH_SHORT)
                     .show()
             }
         }
     }
    
  4. 添加 onDestroy 方法

     override fun onDestroy() {
         super.onDestroy()
         ortEnv.close()
         ortSession.close()
     }
    
    
  5. 添加 updateUI 方法

    private fun updateUI(result: Result) {
        outputImage?.setImageBitmap(result.outputBitmap)
    }
    
  6. 添加 readModel 方法

    此方法从 resources 文件夹读取 ONNX 模型。

    private fun readModel(): ByteArray {
        val modelID = R.pytorch_superresolution_with_pre_post_processing_op18
        return resources.openRawResource(modelID).readBytes()
    }   
    
  7. 添加读取输入图像的方法

    此方法从 assets 文件夹读取测试图像。目前,它读取内置于应用程序中的固定图像。示例很快将扩展为直接从相机或相机胶卷读取图像。

    private fun readInputImage(): InputStream {
        return assets.open("test_superresolution.png")
    }   
    
  8. 添加执行推理的方法

    此方法调用应用程序核心的方法:SuperResPerformer.upscale(),该方法在模型上运行推理。此代码在下一节中显示。

     private fun performSuperResolution(ortSession: OrtSession) {
         var superResPerformer = SuperResPerformer()
         var result = superResPerformer.upscale(readInputImage(), ortEnv, ortSession)
         updateUI(result);
     }   
    
  9. 添加 TAG 对象

    companion object {
        const val TAG = "ORTSuperResolution"
    }
    

模型推理类代码

创建一个名为 SuperResPerformer.kt 的文件,并将以下代码片段添加到其中。

  1. 添加导入

    import ai.onnxruntime.OnnxJavaType
    import ai.onnxruntime.OrtSession
    import ai.onnxruntime.OnnxTensor
    import ai.onnxruntime.OrtEnvironment
    import android.graphics.Bitmap
    import android.graphics.BitmapFactory
    import java.io.InputStream
    import java.nio.ByteBuffer
    import java.util.*
    
  2. 创建结果类

    internal data class Result(
        var outputBitmap: Bitmap? = null
    ) {}
    
  3. 创建超分辨率执行器类

    此类及其主函数 upscale 是 ONNX Runtime 大部分调用的地方。

    internal class SuperResPerformer(
    ) {
    
        fun upscale(inputStream: InputStream, ortEnv: OrtEnvironment, ortSession: OrtSession): Result {
            var result = Result()
    
            // Step 1: convert image into byte array (raw image bytes)
            val rawImageBytes = inputStream.readBytes()
    
            // Step 2: get the shape of the byte array and make ort tensor
            val shape = longArrayOf(rawImageBytes.size.toLong())
    
            val inputTensor = OnnxTensor.createTensor(
                ortEnv,
                ByteBuffer.wrap(rawImageBytes),
                shape,
                OnnxJavaType.UINT8
            )
            inputTensor.use {
                // Step 3: call ort inferenceSession run
                val output = ortSession.run(Collections.singletonMap("image", inputTensor))
    
                // Step 4: output analysis
                output.use {
                    val rawOutput = (output?.get(0)?.value) as ByteArray
                    val outputImageBitmap =
                        byteArrayToBitmap(rawOutput)
    
                    // Step 5: set output result
                    result.outputBitmap = outputImageBitmap
                }
            }
            return result
        }
    

构建并运行应用程序

在 Android Studio 中

  • 选择 Build -> Make Project
  • Run -> app

应用程序在设备模拟器中运行。连接到您的 Android 设备以在设备上运行该应用程序。

iOS 应用程序

先决条件

  • 安装 Xcode 13.0 及更高版本(最好是最新版本)
  • iOS 设备或 iOS 模拟器
  • Xcode 命令行工具 xcode-select --install
  • CocoaPods sudo gem install cocoapods
  • 有效的 Apple Developer ID(如果您计划在设备上运行)

示例代码

您可以在 GitHub 中找到 iOS 超分辨率应用程序的完整源代码

要从源代码运行应用程序

  1. 克隆 onnxruntime-inference-examples 存储库

    git clone https://github.com/microsoft/onnxruntime-inference-examples
    cd onnxruntime-inference-examples/mobile/examples/super_resolution/ios
    
  2. 安装所需的 pod 文件

    pod install
    
  3. 在 XCode 中打开生成的 ORTSuperResolution.xcworkspace 文件

    (可选:仅当您在设备上运行时才需要)选择您的开发团队

  4. 运行应用程序

    连接您的 iOS 设备或模拟器,构建并运行应用程序

    单击 执行超分辨率 按钮以查看应用程序的实际效果

要逐步开发应用程序,请按照以下部分进行操作。

从头开始编写代码

创建项目

使用 APP 模板在 XCode 中创建一个新项目

依赖项

安装以下 pod

  # Pods for OrtSuperResolution
  pod 'onnxruntime-c'
  
  # Pre-release version pods
  pod 'onnxruntime-extensions-c', '0.5.0-dev+261962.e3663fb'

项目资源

  1. 将模型文件添加到项目

    将本教程开始时生成的模型文件复制到项目文件夹的根目录。

  2. 将测试图像添加为资产

    将您要运行超分辨率的图像复制到项目文件夹的根目录。

主应用程序

打开名为 ORTSuperResolutionApp.swift 的文件并添加以下代码

import SwiftUI

@main
struct ORTSuperResolutionApp: App {
    var body: some Scene {
        WindowGroup {
            ContentView()
        }
    }
}

内容视图

打开名为 ContentView.swift 的文件并添加以下代码

import SwiftUI

struct ContentView: View {
    @State private var performSuperRes = false
    
    func runOrtSuperResolution() -> UIImage? {
        do {
            let outputImage = try ORTSuperResolutionPerformer.performSuperResolution()
            return outputImage
        } catch let error as NSError {
            print("Error: \(error.localizedDescription)")
            return nil
        }
    }
    
    var body: some View {
        ScrollView {
            VStack {
                VStack {
                    Text("ORTSuperResolution").font(.title).bold()
                        .frame(width: 400, height: 80)
                        .border(Color.purple, width: 4)
                        .background(Color.purple)
                    
                    Text("Input low resolution image: ").frame(width: 350, height: 40, alignment:.leading)
                    
                    Image("cat_224x224").frame(width: 250, height: 250)
                    
                    Button("Perform Super Resolution") {
                        performSuperRes.toggle()
                    }
                    
                    if performSuperRes {
                        Text("Output high resolution image: ").frame(width: 350, height: 40, alignment:.leading)
                        
                        if let outputImage = runOrtSuperResolution() {
                            Image(uiImage: outputImage)
                        } else {
                            Text("Unable to perform super resolution. ").frame(width: 350, height: 40, alignment:.leading)
                        }
                    }
                    Spacer()
                }
            }
            .padding()
        }
    }
}

struct ContentView_Previews: PreviewProvider {
    static var previews: some View {
        ContentView()
    }
}

Swift / Objective C 桥接头文件

创建一个名为 ORTSuperResolution-Bridging-Header.h 的文件并添加以下导入语句

#import "ORTSuperResolutionPerformer.h"

超分辨率代码

  1. 创建一个名为 ORTSuperResolutionPerformer.h 的文件并添加以下代码

    #ifndef ORTSuperResolutionPerformer_h
    #define ORTSuperResolutionPerformer_h
    
    #import <Foundation/Foundation.h>
    #import <UIKit/UIKit.h>
    
    NS_ASSUME_NONNULL_BEGIN
    
    @interface ORTSuperResolutionPerformer : NSObject
    
    + (nullable UIImage*)performSuperResolutionWithError:(NSError**)error;
    
    @end
    
    NS_ASSUME_NONNULL_END
    
    #endif
    
  2. 创建一个名为 ORTSuperResolutionPerformer.mm 的文件并添加以下代码

     #import "ORTSuperResolutionPerformer.h"
     #import <Foundation/Foundation.h>
     #import <UIKit/UIKit.h>
    
     #include <array>
     #include <cstdint>
     #include <stdexcept>
     #include <string>
     #include <vector>
    
     #include <onnxruntime_cxx_api.h>
     #include <onnxruntime_extensions.h>
    
    
     @implementation ORTSuperResolutionPerformer
    
     + (nullable UIImage*)performSuperResolutionWithError:(NSError **)error {
            
         UIImage* output_image = nil;
            
         try {
                
             // Register custom ops
                
             const auto ort_log_level = ORT_LOGGING_LEVEL_INFO;
             auto ort_env = Ort::Env(ort_log_level, "ORTSuperResolution");
             auto session_options = Ort::SessionOptions();
                
             if (RegisterCustomOps(session_options, OrtGetApiBase()) != nullptr) {
                 throw std::runtime_error("RegisterCustomOps failed");
             }
                
             // Step 1: Load model
                
             NSString *model_path = [NSBundle.mainBundle pathForResource:@"pt_super_resolution_with_pre_post_processing_opset16"
                                                                 ofType:@"onnx"];
             if (model_path == nullptr) {
                 throw std::runtime_error("Failed to get model path");
             }
                
             // Step 2: Create Ort Inference Session
                
             auto sess = Ort::Session(ort_env, [model_path UTF8String], session_options);
                
             // Read input image
                
             // note: need to set Xcode settings to prevent it from messing with PNG files:
             // in "Build Settings":
             // - set "Compress PNG Files" to "No"
             // - set "Remove Text Metadata From PNG Files" to "No"
             NSString *input_image_path =
             [NSBundle.mainBundle pathForResource:@"cat_224x224" ofType:@"png"];
             if (input_image_path == nullptr) {
                 throw std::runtime_error("Failed to get image path");
             }
                
             // Step 3: Prepare input tensors and input/output names
                
             NSMutableData *input_data =
             [NSMutableData dataWithContentsOfFile:input_image_path];
             const int64_t input_data_length = input_data.length;
             const auto memoryInfo =
             Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
                
             const auto input_tensor = Ort::Value::CreateTensor(memoryInfo, [input_data mutableBytes], input_data_length,
                                                             &input_data_length, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
                
             constexpr auto input_names = std::array{"image"};
             constexpr auto output_names = std::array{"image_out"};
                
             // Step 4: Call inference session run
                
             const auto outputs = sess.Run(Ort::RunOptions(), input_names.data(),
                                         &input_tensor, 1, output_names.data(), 1);
             if (outputs.size() != 1) {
                 throw std::runtime_error("Unexpected number of outputs");
             }
                
             // Step 5: Analyze model outputs
                
             const auto &output_tensor = outputs.front();
             const auto output_type_and_shape_info = output_tensor.GetTensorTypeAndShapeInfo();
             const auto output_shape = output_type_and_shape_info.GetShape();
                
             if (const auto output_element_type =
                 output_type_and_shape_info.GetElementType();
                 output_element_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) {
                 throw std::runtime_error("Unexpected output element type");
             }
                
             const uint8_t *output_data_raw = output_tensor.GetTensorData<uint8_t>();
                
             // Step 6: Convert raw bytes into NSData and return as displayable UIImage
                
             NSData *output_data = [NSData dataWithBytes:output_data_raw length:(output_shape[0])];
             output_image = [UIImage imageWithData:output_data];
                
         } catch (std::exception &e) {
             NSLog(@"%s error: %s", __FUNCTION__, e.what());
                
             static NSString *const kErrorDomain = @"ORTSuperResolution";
             constexpr NSInteger kErrorCode = 0;
             if (error) {
                 NSString *description =
                 [NSString stringWithCString:e.what() encoding:NSASCIIStringEncoding];
                 *error =
                 [NSError errorWithDomain:kErrorDomain
                                     code:kErrorCode
                                 userInfo:@{NSLocalizedDescriptionKey : description}];
             }
             return nullptr;
         }
            
         if (error) {
             *error = nullptr;
         }
         return output_image;
     }
    
     @end
    

构建并运行应用程序

在 XCode 中,选择三角形构建图标以构建并运行应用程序!

资源

原始 PyTorch 教程