在移动设备上使用机器学习超分辨率提高图像分辨率
了解如何使用 ONNX Runtime Mobile 构建应用程序以提高图像分辨率,该模型包含预处理和后处理。
您可以使用本教程为 Android 或 iOS 构建应用程序。
该应用程序接收图像输入,在单击按钮时执行超分辨率操作,并在下方显示分辨率提高的图像,如下面的屏幕截图所示。
目录
准备模型
本教程中使用的机器学习模型基于本页面底部引用的 PyTorch 教程中使用的模型。
我们提供了一个方便的 Python 脚本,可将 PyTorch 模型导出为 ONNX 格式,并添加预处理和后处理。
-
在运行此脚本之前,请安装以下 Python 包
pip install torch pip install pillow pip install onnx pip install onnxruntime pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ onnxruntime-extensions
关于版本的说明:最佳超分辨率结果是通过 ONNX opset 18(及其对具有抗锯齿功能的 Resize 运算符的支持)实现的,onnx 1.13.0 和 onnxruntime 1.14.0 及更高版本支持该版本。onnxruntime-extensions 包是预发布版本。发布版本即将推出。
-
然后从 onnxruntime-extensions GitHub 存储库下载脚本和测试图像(如果您尚未克隆此存储库)
curl https://raw.githubusercontent.com/microsoft/onnxruntime-extensions/main/tutorials/superresolution_e2e.py > superresolution_e2e.py curl https://raw.githubusercontent.com/microsoft/onnxruntime-extensions/main/tutorials/data/super_res_input.png > data/super_res_input.png
-
运行脚本以导出核心模型并向其添加预处理和后处理
python superresolution_e2e.py
脚本运行后,您应该在运行脚本的位置的文件夹中看到两个 ONNX 文件
pytorch_superresolution.onnx
pytorch_superresolution_with_pre_and_post_processing.onnx
如果您将这两个模型加载到 netron 中,您可以看到两者之间输入和输出的差异。下面的前两张图像显示了原始模型,其输入是通道数据批次,而后两张图像显示了输入和输出是图像字节。
现在是编写应用程序代码的时候了。
Android 应用程序
先决条件
- Android Studio Dolphin 2021.3.1 Patch +(安装在 Mac/Windows/Linux 上)
- Android SDK 29+
- Android NDK r22+
- Android 设备或 Android 模拟器
示例代码
您可以在 GitHub 中找到 Android 超分辨率应用程序的完整源代码。
要从源代码运行该应用程序,请克隆上述存储库并将 build.gradle
文件加载到 Android Studio 中,构建并运行!
要逐步构建应用程序,请按照以下部分进行操作。
从头开始编写代码
设置项目
在 Android Studio 中为手机和平板电脑创建一个新项目,然后选择空白模板。将应用程序命名为 super_resolution
或类似名称。
依赖项
将以下依赖项添加到应用程序 build.gradle
implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release'
项目资源
-
将模型文件添加为原始资源
在
src/main/res
文件夹中创建一个名为raw
的文件夹,并将 ONNX 模型移动或复制到 raw 文件夹中。 -
将测试图像添加为资产
在主项目文件夹中创建一个名为
assets
的文件夹,并将您要运行超分辨率的图像复制到该文件夹中,文件名为test_superresolution.png
主应用程序类代码
创建一个名为 MainActivity.kt 的文件,并将以下代码片段添加到其中。
-
添加导入语句
import ai.onnxruntime.* import ai.onnxruntime.extensions.OrtxPackage import android.annotation.SuppressLint import android.os.Bundle import android.widget.Button import android.widget.ImageView import android.widget.Toast import androidx.activity.* import androidx.appcompat.app.AppCompatActivity import kotlinx.android.synthetic.main.activity_main.* import kotlinx.coroutines.* import java.io.InputStream import java.util.* import java.util.concurrent.ExecutorService import java.util.concurrent.Executors
-
创建主活动类并添加类变量
class MainActivity : AppCompatActivity() { private var ortEnv: OrtEnvironment = OrtEnvironment.getEnvironment() private lateinit var ortSession: OrtSession private var inputImage: ImageView? = null private var outputImage: ImageView? = null private var superResolutionButton: Button? = null ... }
-
添加
onCreate()
方法这是我们初始化 ONNX Runtime 会话 的地方。会话保存对用于在应用程序中执行推理的模型的引用。它还接受会话选项参数,您可以在其中指定不同的执行提供程序(硬件加速器,例如 NNAPI)。在本例中,我们默认在 CPU 上运行。但是,我们确实注册了自定义操作库,在该库中可以找到模型输入和输出端的图像编码和解码运算符。
override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) setContentView(R.layout.activity_main) inputImage = findViewById(R.id.imageView1) outputImage = findViewById(R.id.imageView2); superResolutionButton = findViewById(R.id.super_resolution_button) inputImage?.setImageBitmap( BitmapFactory.decodeStream(readInputImage()) ); // Initialize Ort Session and register the onnxruntime extensions package that contains the custom operators. // Note: These are used to decode the input image into the format the original model requires, // and to encode the model output into png format val sessionOptions: OrtSession.SessionOptions = OrtSession.SessionOptions() sessionOptions.registerCustomOpLibrary(OrtxPackage.getLibraryPath()) ortSession = ortEnv.createSession(readModel(), sessionOptions) superResolutionButton?.setOnClickListener { try { performSuperResolution(ortSession) Toast.makeText(baseContext, "Super resolution performed!", Toast.LENGTH_SHORT) .show() } catch (e: Exception) { Log.e(TAG, "Exception caught when perform super resolution", e) Toast.makeText(baseContext, "Failed to perform super resolution", Toast.LENGTH_SHORT) .show() } } }
-
添加 onDestroy 方法
override fun onDestroy() { super.onDestroy() ortEnv.close() ortSession.close() }
-
添加 updateUI 方法
private fun updateUI(result: Result) { outputImage?.setImageBitmap(result.outputBitmap) }
-
添加 readModel 方法
此方法从 resources 文件夹读取 ONNX 模型。
private fun readModel(): ByteArray { val modelID = R.pytorch_superresolution_with_pre_post_processing_op18 return resources.openRawResource(modelID).readBytes() }
-
添加读取输入图像的方法
此方法从 assets 文件夹读取测试图像。目前,它读取内置于应用程序中的固定图像。示例很快将扩展为直接从相机或相机胶卷读取图像。
private fun readInputImage(): InputStream { return assets.open("test_superresolution.png") }
-
添加执行推理的方法
此方法调用应用程序核心的方法:
SuperResPerformer.upscale()
,该方法在模型上运行推理。此代码在下一节中显示。private fun performSuperResolution(ortSession: OrtSession) { var superResPerformer = SuperResPerformer() var result = superResPerformer.upscale(readInputImage(), ortEnv, ortSession) updateUI(result); }
-
添加 TAG 对象
companion object { const val TAG = "ORTSuperResolution" }
模型推理类代码
创建一个名为 SuperResPerformer.kt
的文件,并将以下代码片段添加到其中。
-
添加导入
import ai.onnxruntime.OnnxJavaType import ai.onnxruntime.OrtSession import ai.onnxruntime.OnnxTensor import ai.onnxruntime.OrtEnvironment import android.graphics.Bitmap import android.graphics.BitmapFactory import java.io.InputStream import java.nio.ByteBuffer import java.util.*
-
创建结果类
internal data class Result( var outputBitmap: Bitmap? = null ) {}
-
创建超分辨率执行器类
此类及其主函数
upscale
是 ONNX Runtime 大部分调用的地方。- OrtEnvironment 单例维护环境属性和配置的日志记录级别
- OnnxTensor.createTensor() 用于创建由输入图像字节组成的张量,适合作为模型的输入
- OnnxJavaType.UINT8 是输入张量的 ByteBuffer 的数据类型
- OrtSession.run() 在模型上运行推理(预测)以获取输出的放大图像
internal class SuperResPerformer( ) { fun upscale(inputStream: InputStream, ortEnv: OrtEnvironment, ortSession: OrtSession): Result { var result = Result() // Step 1: convert image into byte array (raw image bytes) val rawImageBytes = inputStream.readBytes() // Step 2: get the shape of the byte array and make ort tensor val shape = longArrayOf(rawImageBytes.size.toLong()) val inputTensor = OnnxTensor.createTensor( ortEnv, ByteBuffer.wrap(rawImageBytes), shape, OnnxJavaType.UINT8 ) inputTensor.use { // Step 3: call ort inferenceSession run val output = ortSession.run(Collections.singletonMap("image", inputTensor)) // Step 4: output analysis output.use { val rawOutput = (output?.get(0)?.value) as ByteArray val outputImageBitmap = byteArrayToBitmap(rawOutput) // Step 5: set output result result.outputBitmap = outputImageBitmap } } return result }
构建并运行应用程序
在 Android Studio 中
- 选择 Build -> Make Project
- Run -> app
应用程序在设备模拟器中运行。连接到您的 Android 设备以在设备上运行该应用程序。
iOS 应用程序
先决条件
- 安装 Xcode 13.0 及更高版本(最好是最新版本)
- iOS 设备或 iOS 模拟器
- Xcode 命令行工具
xcode-select --install
- CocoaPods
sudo gem install cocoapods
- 有效的 Apple Developer ID(如果您计划在设备上运行)
示例代码
您可以在 GitHub 中找到 iOS 超分辨率应用程序的完整源代码。
要从源代码运行应用程序
-
克隆 onnxruntime-inference-examples 存储库
git clone https://github.com/microsoft/onnxruntime-inference-examples cd onnxruntime-inference-examples/mobile/examples/super_resolution/ios
-
安装所需的 pod 文件
pod install
-
在 XCode 中打开生成的
ORTSuperResolution.xcworkspace
文件(可选:仅当您在设备上运行时才需要)选择您的开发团队
-
运行应用程序
连接您的 iOS 设备或模拟器,构建并运行应用程序
单击
执行超分辨率
按钮以查看应用程序的实际效果
要逐步开发应用程序,请按照以下部分进行操作。
从头开始编写代码
创建项目
使用 APP 模板在 XCode 中创建一个新项目
依赖项
安装以下 pod
# Pods for OrtSuperResolution
pod 'onnxruntime-c'
# Pre-release version pods
pod 'onnxruntime-extensions-c', '0.5.0-dev+261962.e3663fb'
项目资源
-
将模型文件添加到项目
将本教程开始时生成的模型文件复制到项目文件夹的根目录。
-
将测试图像添加为资产
将您要运行超分辨率的图像复制到项目文件夹的根目录。
主应用程序
打开名为 ORTSuperResolutionApp.swift
的文件并添加以下代码
import SwiftUI
@main
struct ORTSuperResolutionApp: App {
var body: some Scene {
WindowGroup {
ContentView()
}
}
}
内容视图
打开名为 ContentView.swift
的文件并添加以下代码
import SwiftUI
struct ContentView: View {
@State private var performSuperRes = false
func runOrtSuperResolution() -> UIImage? {
do {
let outputImage = try ORTSuperResolutionPerformer.performSuperResolution()
return outputImage
} catch let error as NSError {
print("Error: \(error.localizedDescription)")
return nil
}
}
var body: some View {
ScrollView {
VStack {
VStack {
Text("ORTSuperResolution").font(.title).bold()
.frame(width: 400, height: 80)
.border(Color.purple, width: 4)
.background(Color.purple)
Text("Input low resolution image: ").frame(width: 350, height: 40, alignment:.leading)
Image("cat_224x224").frame(width: 250, height: 250)
Button("Perform Super Resolution") {
performSuperRes.toggle()
}
if performSuperRes {
Text("Output high resolution image: ").frame(width: 350, height: 40, alignment:.leading)
if let outputImage = runOrtSuperResolution() {
Image(uiImage: outputImage)
} else {
Text("Unable to perform super resolution. ").frame(width: 350, height: 40, alignment:.leading)
}
}
Spacer()
}
}
.padding()
}
}
}
struct ContentView_Previews: PreviewProvider {
static var previews: some View {
ContentView()
}
}
Swift / Objective C 桥接头文件
创建一个名为 ORTSuperResolution-Bridging-Header.h
的文件并添加以下导入语句
#import "ORTSuperResolutionPerformer.h"
超分辨率代码
-
创建一个名为
ORTSuperResolutionPerformer.h
的文件并添加以下代码#ifndef ORTSuperResolutionPerformer_h #define ORTSuperResolutionPerformer_h #import <Foundation/Foundation.h> #import <UIKit/UIKit.h> NS_ASSUME_NONNULL_BEGIN @interface ORTSuperResolutionPerformer : NSObject + (nullable UIImage*)performSuperResolutionWithError:(NSError**)error; @end NS_ASSUME_NONNULL_END #endif
-
创建一个名为
ORTSuperResolutionPerformer.mm
的文件并添加以下代码#import "ORTSuperResolutionPerformer.h" #import <Foundation/Foundation.h> #import <UIKit/UIKit.h> #include <array> #include <cstdint> #include <stdexcept> #include <string> #include <vector> #include <onnxruntime_cxx_api.h> #include <onnxruntime_extensions.h> @implementation ORTSuperResolutionPerformer + (nullable UIImage*)performSuperResolutionWithError:(NSError **)error { UIImage* output_image = nil; try { // Register custom ops const auto ort_log_level = ORT_LOGGING_LEVEL_INFO; auto ort_env = Ort::Env(ort_log_level, "ORTSuperResolution"); auto session_options = Ort::SessionOptions(); if (RegisterCustomOps(session_options, OrtGetApiBase()) != nullptr) { throw std::runtime_error("RegisterCustomOps failed"); } // Step 1: Load model NSString *model_path = [NSBundle.mainBundle pathForResource:@"pt_super_resolution_with_pre_post_processing_opset16" ofType:@"onnx"]; if (model_path == nullptr) { throw std::runtime_error("Failed to get model path"); } // Step 2: Create Ort Inference Session auto sess = Ort::Session(ort_env, [model_path UTF8String], session_options); // Read input image // note: need to set Xcode settings to prevent it from messing with PNG files: // in "Build Settings": // - set "Compress PNG Files" to "No" // - set "Remove Text Metadata From PNG Files" to "No" NSString *input_image_path = [NSBundle.mainBundle pathForResource:@"cat_224x224" ofType:@"png"]; if (input_image_path == nullptr) { throw std::runtime_error("Failed to get image path"); } // Step 3: Prepare input tensors and input/output names NSMutableData *input_data = [NSMutableData dataWithContentsOfFile:input_image_path]; const int64_t input_data_length = input_data.length; const auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); const auto input_tensor = Ort::Value::CreateTensor(memoryInfo, [input_data mutableBytes], input_data_length, &input_data_length, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8); constexpr auto input_names = std::array{"image"}; constexpr auto output_names = std::array{"image_out"}; // Step 4: Call inference session run const auto outputs = sess.Run(Ort::RunOptions(), input_names.data(), &input_tensor, 1, output_names.data(), 1); if (outputs.size() != 1) { throw std::runtime_error("Unexpected number of outputs"); } // Step 5: Analyze model outputs const auto &output_tensor = outputs.front(); const auto output_type_and_shape_info = output_tensor.GetTensorTypeAndShapeInfo(); const auto output_shape = output_type_and_shape_info.GetShape(); if (const auto output_element_type = output_type_and_shape_info.GetElementType(); output_element_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) { throw std::runtime_error("Unexpected output element type"); } const uint8_t *output_data_raw = output_tensor.GetTensorData<uint8_t>(); // Step 6: Convert raw bytes into NSData and return as displayable UIImage NSData *output_data = [NSData dataWithBytes:output_data_raw length:(output_shape[0])]; output_image = [UIImage imageWithData:output_data]; } catch (std::exception &e) { NSLog(@"%s error: %s", __FUNCTION__, e.what()); static NSString *const kErrorDomain = @"ORTSuperResolution"; constexpr NSInteger kErrorCode = 0; if (error) { NSString *description = [NSString stringWithCString:e.what() encoding:NSASCIIStringEncoding]; *error = [NSError errorWithDomain:kErrorDomain code:kErrorCode userInfo:@{NSLocalizedDescriptionKey : description}]; } return nullptr; } if (error) { *error = nullptr; } return output_image; } @end
构建并运行应用程序
在 XCode 中,选择三角形构建图标以构建并运行应用程序!