使用机器学习超分辨率技术在移动设备上提高图像分辨率
了解如何使用 ONNX Runtime Mobile 构建一个包含预处理和后处理模型的应用来提高图像分辨率。
您可以使用此教程构建适用于 Android 或 iOS 的应用。
该应用接收图像输入,在点击按钮时执行超分辨率操作,并在下方显示分辨率提高后的图像,如下图所示。
目录
准备模型
本教程中使用的机器学习模型基于本页底部引用的 PyTorch 教程中使用的模型。
我们提供了一个方便的 Python 脚本,可以将 PyTorch 模型导出为 ONNX 格式并添加预处理和后处理。
-
在运行此脚本之前,请安装以下 Python 包
pip install torch pip install pillow pip install onnx pip install onnxruntime pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ onnxruntime-extensions
关于版本说明:使用 ONNX opset 18(支持带有抗锯齿功能的 Resize 操作符)可以获得最佳的超分辨率结果,该版本受 onnx 1.13.0 及 onnxruntime 1.14.0 及更高版本支持。onnxruntime-extensions 包目前是预发布版本,正式版本即将推出。
-
然后从 onnxruntime-extensions GitHub 仓库下载脚本和测试图像(如果您尚未克隆此仓库)
curl https://raw.githubusercontent.com/microsoft/onnxruntime-extensions/main/tutorials/superresolution_e2e.py > superresolution_e2e.py curl https://raw.githubusercontent.com/microsoft/onnxruntime-extensions/main/tutorials/data/super_res_input.png > data/super_res_input.png
-
运行脚本以导出核心模型并为其添加预处理和后处理
python superresolution_e2e.py
脚本运行后,您应该在运行脚本的位置看到两个 ONNX 文件
pytorch_superresolution.onnx
pytorch_superresolution_with_pre_and_post_processing.onnx
如果您将这两个模型加载到 netron 中,您可以看到它们在输入和输出上的区别。下面前两张图片显示了原始模型,其输入是通道数据的批次;后两张图片显示了输入和输出都是图像字节。
现在是时候编写应用代码了。
Android 应用
先决条件
- Android Studio Dolphin 2021.3.1 Patch +(安装在 Mac/Windows/Linux 上)
- Android SDK 29+
- Android NDK r22+
- 一台 Android 设备或一个 Android 模拟器
示例代码
您可以在 GitHub 上找到 Android 超分辨率应用的完整源代码。
要从源代码运行该应用,请克隆上面的仓库,并将 build.gradle
文件加载到 Android Studio 中,然后构建并运行!
要逐步构建该应用,请按照以下部分进行操作。
从零开始编写代码
项目设置
在 Android Studio 中为手机和平板电脑创建一个新项目,并选择空白模板。将应用命名为 super_resolution
或类似名称。
依赖项
将以下依赖项添加到应用的 build.gradle
implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release'
项目资源
-
将模型文件添加为原始资源
在
src/main/res
文件夹中创建一个名为raw
的文件夹,然后将 ONNX 模型移动或复制到 raw 文件夹中。 -
将测试图像添加为资产
在主项目文件夹中创建一个名为
assets
的文件夹,并将您想要运行超分辨率的图像复制到该文件夹中,文件名为test_superresolution.png
主应用类代码
创建一个名为 MainActivity.kt 的文件,并将以下代码段添加到其中。
-
添加导入语句
import ai.onnxruntime.* import ai.onnxruntime.extensions.OrtxPackage import android.annotation.SuppressLint import android.os.Bundle import android.widget.Button import android.widget.ImageView import android.widget.Toast import androidx.activity.* import androidx.appcompat.app.AppCompatActivity import kotlinx.android.synthetic.main.activity_main.* import kotlinx.coroutines.* import java.io.InputStream import java.util.* import java.util.concurrent.ExecutorService import java.util.concurrent.Executors
-
创建主活动类并添加类变量
class MainActivity : AppCompatActivity() { private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment() private lateinit var ortSession: OrtSession private var inputImage: ImageView? = null private var outputImage: ImageView? = null private var superResolutionButton: Button? = null ... }
-
添加
onCreate()
方法在这里我们初始化 ONNX Runtime 会话。会话持有一个模型引用,用于在应用中执行推理。它还接受一个会话选项参数,您可以在其中指定不同的执行提供程序(如 NNAPI 等硬件加速器)。在这种情况下,我们默认在 CPU 上运行。但是,我们注册了包含模型输入和输出处图像编码和解码操作符的自定义操作库。
override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) setContentView(R.layout.activity_main) inputImage = findViewById(R.id.imageView1) outputImage = findViewById(R.id.imageView2); superResolutionButton = findViewById(R.id.super_resolution_button) inputImage?.setImageBitmap( BitmapFactory.decodeStream(readInputImage()) ); // Initialize Ort Session and register the onnxruntime extensions package that contains the custom operators. // Note: These are used to decode the input image into the format the original model requires, // and to encode the model output into png format val sessionOptions: OrtSession.SessionOptions = OrtSession.SessionOptions() sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath()) ortSession = ortEnv.createSession(readModel(), sessionOptions) superResolutionButton?.setOnClickListener { try { performSuperResolution(ortSession) Toast.makeText(baseContext, "Super resolution performed!", Toast.LENGTH_SHORT) .show() } catch (e: Exception) { Log.e(TAG, "Exception caught when perform super resolution", e) Toast.makeText(baseContext, "Failed to perform super resolution", Toast.LENGTH_SHORT) .show() } } }
-
添加 onDestroy 方法
override fun onDestroy() { super.onDestroy() ortEnv.close() ortSession.close() }
-
添加 updateUI 方法
private fun updateUI(result: Result) { outputImage?.setImageBitmap(result.outputBitmap) }
-
添加 readModel 方法
此方法从资源文件夹读取 ONNX 模型。
private fun readModel(): ByteArray { val modelID = R.pytorch_superresolution_with_pre_post_processing_op18 return resources.openRawResource(modelID).readBytes() }
-
添加读取输入图像的方法
此方法从 assets 文件夹读取测试图像。目前它读取的是内置到应用中的固定图像。很快示例将扩展到直接从相机或相册读取图像。
private fun readInputImage(): InputStream { return assets.open("test_superresolution.png") }
-
添加执行推理的方法
此方法调用了应用的核心方法:
SuperResPerformer.upscale()
,该方法在模型上运行推理。此方法的代码将在下一节中展示。private fun performSuperResolution(ortSession: OrtSession) { var superResPerformer = SuperResPerformer() var result = superResPerformer.upscale(readInputImage(), ortEnv, ortSession) updateUI(result); }
-
添加 TAG 对象
companion object { const val TAG = "ORTSuperResolution" }
模型推理类代码
创建一个名为 SuperResPerformer.kt
的文件,并将以下代码段添加到其中。
-
添加导入
import ai.onnxruntime.OnnxJavaType import ai.onnxruntime.OrtSession import ai.onnxruntime.OnnxTensor import ai.onnxruntime.OrtEnvironment import android.graphics.Bitmap import android.graphics.BitmapFactory import java.io.InputStream import java.nio.ByteBuffer import java.util.*
-
创建结果类
internal data class Result( var outputBitmap: Bitmap? = null ) {}
-
创建超分辨率执行器类
这个类及其主要函数
upscale
是大多数 ONNX Runtime 调用所在的地方。- OrtEnvironment 单例维护环境属性和配置的日志级别
- 使用 OnnxTensor.createTensor() 创建由输入图像字节组成的张量,适合作为模型的输入
- OnnxJavaType.UINT8 是输入张量的 ByteBuffer 的数据类型
- 使用 OrtSession.run() 在模型上运行推理(预测),以获取输出的放大图像
internal class SuperResPerformer( ) { fun upscale(inputStream: InputStream, ortEnv: OrtEnvironment, ortSession: OrtSession): Result { var result = Result() // Step 1: convert image into byte array (raw image bytes) val rawImageBytes = inputStream.readBytes() // Step 2: get the shape of the byte array and make ort tensor val shape = longArrayOf(rawImageBytes.size.toLong()) val inputTensor = OnnxTensor.createTensor( ortEnv, ByteBuffer.wrap(rawImageBytes), shape, OnnxJavaType.UINT8 ) inputTensor.use { // Step 3: call ort inferenceSession run val output = ortSession.run(Collections.singletonMap("image", inputTensor)) // Step 4: output analysis output.use { val rawOutput = (output?.get(0)?.value) as ByteArray val outputImageBitmap = byteArrayToBitmap(rawOutput) // Step 5: set output result result.outputBitmap = outputImageBitmap } } return result }
构建并运行应用
在 Android Studio 中
- 选择 Build -> Make Project
- 选择 Run -> app
应用在设备模拟器中运行。连接您的 Android 设备以在设备上运行应用。
iOS 应用
先决条件
- 安装 Xcode 13.0 及更高版本(最好是最新版本)
- 一台 iOS 设备或一个 iOS 模拟器
- Xcode 命令行工具
xcode-select --install
- CocoaPods
sudo gem install cocoapods
- 有效的 Apple Developer ID(如果您计划在设备上运行)
示例代码
您可以在 GitHub 上找到 iOS 超分辨率应用的完整源代码。
要从源代码运行该应用
-
克隆 onnxruntime-inference-examples 仓库
git clone https://github.com/microsoft/onnxruntime-inference-examples cd onnxruntime-inference-examples/mobile/examples/super_resolution/ios
-
安装所需的 pod 文件
pod install
-
在 XCode 中打开生成的
ORTSuperResolution.xcworkspace
文件(可选:仅在设备上运行时需要)选择您的开发团队
-
运行应用
连接您的 iOS 设备或模拟器,构建并运行应用
点击
Perform Super Resolution
按钮查看应用效果
要逐步开发该应用,请按照以下部分进行操作。
从零开始编写代码
创建项目
在 XCode 中使用 APP 模板创建一个新项目
依赖项
安装以下 pods
# Pods for OrtSuperResolution
pod 'onnxruntime-c'
# Pre-release version pods
pod 'onnxruntime-extensions-c', '0.5.0-dev+261962.e3663fb'
项目资源
-
将模型文件添加到项目
将本教程开头生成的模型文件复制到项目文件夹的根目录。
-
将测试图像添加为资产
将您想要运行超分辨率的图像复制到项目文件夹的根目录。
主应用
打开名为 ORTSuperResolutionApp.swift
的文件,并添加以下代码
import SwiftUI
@main
struct ORTSuperResolutionApp: App {
var body: some Scene {
WindowGroup {
ContentView()
}
}
}
内容视图
打开名为 ContentView.swift
的文件,并添加以下代码
import SwiftUI
struct ContentView: View {
@State private var performSuperRes = false
func runOrtSuperResolution() -> UIImage? {
do {
let outputImage = try ORTSuperResolutionPerformer.performSuperResolution()
return outputImage
} catch let error as NSError {
print("Error: \(error.localizedDescription)")
return nil
}
}
var body: some View {
ScrollView {
VStack {
VStack {
Text("ORTSuperResolution").font(.title).bold()
.frame(width: 400, height: 80)
.border(Color.purple, width: 4)
.background(Color.purple)
Text("Input low resolution image: ").frame(width: 350, height: 40, alignment:.leading)
Image("cat_224x224").frame(width: 250, height: 250)
Button("Perform Super Resolution") {
performSuperRes.toggle()
}
if performSuperRes {
Text("Output high resolution image: ").frame(width: 350, height: 40, alignment:.leading)
if let outputImage = runOrtSuperResolution() {
Image(uiImage: outputImage)
} else {
Text("Unable to perform super resolution. ").frame(width: 350, height: 40, alignment:.leading)
}
}
Spacer()
}
}
.padding()
}
}
}
struct ContentView_Previews: PreviewProvider {
static var previews: some View {
ContentView()
}
}
Swift / Objective C 桥接头文件
创建一个名为 ORTSuperResolution-Bridging-Header.h
的文件,并添加以下导入语句
#import "ORTSuperResolutionPerformer.h"
超分辨率代码
-
创建一个名为
ORTSuperResolutionPerformer.h
的文件,并添加以下代码#ifndef ORTSuperResolutionPerformer_h #define ORTSuperResolutionPerformer_h #import <Foundation/Foundation.h> #import <UIKit/UIKit.h> NS_ASSUME_NONNULL_BEGIN @interface ORTSuperResolutionPerformer : NSObject + (nullable UIImage*)performSuperResolutionWithError:(NSError**)error; @end NS_ASSUME_NONNULL_END #endif
-
创建一个名为
ORTSuperResolutionPerformer.mm
的文件,并添加以下代码#import "ORTSuperResolutionPerformer.h" #import <Foundation/Foundation.h> #import <UIKit/UIKit.h> #include <array> #include <cstdint> #include <stdexcept> #include <string> #include <vector> #include <onnxruntime_cxx_api.h> #include <onnxruntime_extensions.h> @implementation ORTSuperResolutionPerformer + (nullable UIImage*)performSuperResolutionWithError:(NSError **)error { UIImage* output_image = nil; try { // Register custom ops const auto ort_log_level = ORT_LOGGING_LEVEL_INFO; auto ort_env = Ort::Env(ort_log_level, "ORTSuperResolution"); auto session_options = Ort::SessionOptions(); if (RegisterCustomOps(session_options, OrtGetApiBase()) != nullptr) { throw std::runtime_error("RegisterCustomOps failed"); } // Step 1: Load model NSString *model_path = [NSBundle.mainBundle pathForResource:@"pt_super_resolution_with_pre_post_processing_opset16" ofType:@"onnx"]; if (model_path == nullptr) { throw std::runtime_error("Failed to get model path"); } // Step 2: Create Ort Inference Session auto sess = Ort::Session(ort_env, [model_path UTF8String], session_options); // Read input image // note: need to set Xcode settings to prevent it from messing with PNG files: // in "Build Settings": // - set "Compress PNG Files" to "No" // - set "Remove Text Metadata From PNG Files" to "No" NSString *input_image_path = [NSBundle.mainBundle pathForResource:@"cat_224x224" ofType:@"png"]; if (input_image_path == nullptr) { throw std::runtime_error("Failed to get image path"); } // Step 3: Prepare input tensors and input/output names NSMutableData *input_data = [NSMutableData dataWithContentsOfFile:input_image_path]; const int64_t input_data_length = input_data.length; const auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); const auto input_tensor = Ort::Value::CreateTensor(memoryInfo, [input_data mutableBytes], input_data_length, &input_data_length, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); constexpr auto input_names = std::array{"image"}; constexpr auto output_names = std::array{"image_out"}; // Step 4: Call inference session run const auto outputs = sess.Run(Ort::RunOptions(), input_names.data(), &input_tensor, 1, output_names.data(), 1); if (outputs.size() != 1) { throw std::runtime_error("Unexpected number of outputs"); } // Step 5: Analyze model outputs const auto &output_tensor = outputs.front(); const auto output_type_and_shape_info = output_tensor.GetTensorTypeAndShapeInfo(); const auto output_shape = output_type_and_shape_info.GetShape(); if (const auto output_element_type = output_type_and_shape_info.GetElementType(); output_element_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) { throw std::runtime_error("Unexpected output element type"); } const uint8_t *output_data_raw = output_tensor.GetTensorData<uint8_t>(); // Step 6: Convert raw bytes into NSData and return as displayable UIImage NSData *output_data = [NSData dataWithBytes:output_data_raw length:(output_shape[0])]; output_image = [UIImage imageWithData:output_data]; } catch (std::exception &e) { NSLog(@"%s error: %s", __FUNCTION__, e.what()); static NSString *const kErrorDomain = @"ORTSuperResolution"; constexpr NSInteger kErrorCode = 0; if (error) { NSString *description = [NSString stringWithCString:e.what() encoding:NSASCIIStringEncoding]; *error = [NSError errorWithDomain:kErrorDomain code:kErrorCode userInfo:@{NSLocalizedDescriptionKey : description}]; } return nullptr; } if (error) { *error = nullptr; } return output_image; } @end
构建并运行应用
在 XCode 中,点击三角形构建图标来构建并运行应用!