设备端训练:构建 Android 应用程序
在本教程中,我们将探讨如何构建一个集成 ONNX Runtime 设备端训练解决方案的 Android 应用程序。设备端训练是指直接在边缘设备上训练机器学习模型的过程,而无需依赖云服务或外部服务器。
以下是本教程结束时应用程序的外观
简介
我们将指导您完成创建 Android 应用程序的步骤,该应用程序可以使用设备端训练技术训练简单的图像分类模型。本教程展示了 迁移学习
技术,其中从在一个任务上训练模型获得的知识被用来提高模型在不同但相关的任务上的性能。迁移学习不是从头开始学习过程,而是允许我们将预训练模型学习到的知识或特征迁移到新任务。
在本教程中,我们将利用 MobileNetV2
模型,该模型已在 ImageNet 等大规模图像数据集(具有 1,000 个类别)上进行了训练。我们将使用此模型将自定义数据分类为四个类别之一。MobileNetV2 的初始层充当特征提取器,捕获适用于各种任务的通用视觉特征,只有最终的分类器层将针对手头的任务进行训练。
在本教程中,我们将使用数据学习
- 使用预先打包的动物数据集将动物分类为四个类别之一。
- 使用自定义名人数据集将名人分类为四个类别之一。
目录
先决条件
要学习本教程,您应该基本了解如何使用 Java 或 Kotlin 开发 Android 应用程序。熟悉 C++ 以及机器学习概念(如神经网络和图像分类)也将有所帮助。
- 用于准备训练工件的 Python 开发环境
- Android Studio 4.1+
- Android SDK 29+
- Android NDK r21+
- 具有摄像头的 Android 设备,处于开发者模式并启用 USB 调试
注意 整个 Android 应用程序也在
onnxruntime-training-examples
GitHub 存储库上提供。
离线阶段 - 构建训练工件
-
我们从预训练的 PyTorch 模型开始,并将其导出到 ONNX。
MobileNetV2
模型已在 imagenet 数据集上预训练,该数据集具有 1000 个类别的数据。对于我们的图像分类任务,我们只想将图像分类为 4 个类别。因此,我们将模型的最后一层更改为输出 4 个 logits 而不是 1,000 个。有关如何将 PyTorch 模型导出到 ONNX 的更多详细信息,请参见此处。
import torch import torchvision model = torchvision.models.mobilenet_v2( weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2) # The original model is trained on imagenet which has 1000 classes. # For our image classification scenario, we need to classify among 4 categories. # So we need to change the last layer of the model to have 4 outputs. model.classifier[1] = torch.nn.Linear(1280, 4) # Export the model to ONNX. model_name = "mobilenetv2" torch.onnx.export(model, torch.randn(1, 3, 224, 224), f"training_artifacts/{model_name}.onnx", input_names=["input"], output_names=["output"], dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
-
import onnx # Load the onnx model. onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx") # Define the parameters that require their gradients to be computed # (trainable parameters) and those that do not (frozen/non trainable parameters). requires_grad = ["classifier.1.weight", "classifier.1.bias"] frozen_params = [ param.name for param in onnx_model.graph.initializer if param.name not in requires_grad ]
-
在本教程中,我们将使用
CrossEntropyLoss
损失和AdamW
优化器。有关工件生成的更多详细信息,请参见此处。from onnxruntime.training import artifacts # Generate the training artifacts. artifacts.generate_artifacts( onnx_model, requires_grad=requires_grad, frozen_params=frozen_params, loss=artifacts.LossType.CrossEntropyLoss, optimizer=artifacts.OptimType.AdamW, artifact_directory="training_artifacts" )
就这样!训练工件已在
training_artifacts
文件夹中生成。这标志着离线阶段的结束。这些工件已准备好部署到 Android 设备进行训练。
训练阶段 - Android 应用程序开发
-
a. 打开 Android Studio 并单击
New Project
b. 单击
Native C++
->Next
。填写New Project
详细信息,如下所示- 名称 -
ORT Personalize
- 包名 -
com.example.ortpersonalize
- 语言 -
Kotlin
单击
Next
。c. 选择
C++17
工具链 ->Finish
d. 完成!Android Studio 项目已设置完成。您现在应该能够看到带有某些样板代码的 Android Studio 编辑器。
- 名称 -
-
a. 在 Android Studio 项目中的 cpp 目录下创建两个名为
lib
和include\onnxruntime
的新文件夹。b. 前往 Maven Central。转到
Versions
->Browse
-> 并下载onnxruntime-training-android
存档包(aar 文件)。c. 将
aar
扩展名重命名为zip
。因此onnxruntime-training-android-1.15.0.aar
变为onnxruntime-training-android-1.15.0.zip
。d. 解压缩 zip 文件的内容。
e. 将
libonnxruntime.so
共享库从jni\arm64-v8a
文件夹复制到 Android 项目中新创建的lib
文件夹下。f. 将
headers
文件夹的内容复制到新创建的include\onnxruntime
文件夹。g. 在
native-lib.cpp
文件中,包含训练 cxx 头文件。#include "onnxruntime_training_cxx_api.h"
h. 将
abiFilters
添加到build.gradle (Module)
文件,以便选择arm64-v8a
。此设置必须在build.gradle
中的defaultConfig
下添加ndk { abiFilters 'arm64-v8a' }
请注意,
build.gradle
文件的defaultConfig
部分应如下所示defaultConfig { applicationId "com.example.ortpersonalize" minSdk 29 targetSdk 33 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" externalNativeBuild { cmake { cppFlags '-std=c++17' } } + ndk { + abiFilters 'arm64-v8a' + } }
i. 将
onnxruntime
共享库添加到CMakeLists.txt
,以便cmake
可以找到并针对共享库进行构建。为此,在CMakeLists.txt
中添加ortpersonalize
库后,添加以下行add_library(onnxruntime SHARED IMPORTED) set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)
通过在上面两行之后添加此行,让
CMake
知道 ONNX Runtime 头文件可以在哪里找到target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)
通过将
onnxruntime
库添加到target_link_libraries
,将 Android C++ 项目链接到onnxruntime
库target_link_libraries( # Specifies the target library. ortpersonalize # Links the target library to the log library # included in the NDK. ${log-lib} onnxruntime)
请注意,
CMakeLists.txt
文件应如下所示project("ortpersonalize") add_library( # Sets the name of the library. ortpersonalize # Sets the library as a shared library. SHARED # Provides a relative path to your source file(s). native-lib.cpp + utils.cpp + inference.cpp + train.cpp) + add_library(onnxruntime SHARED IMPORTED) + set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so) + target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime) find_library( # Sets the name of the path variable. log-lib # Specifies the name of the NDK library that # you want CMake to locate. log) target_link_libraries( # Specifies the target library. ortpersonalize # Links the target library to the log library # included in the NDK. ${log-lib} + onnxruntime)
j. 构建应用程序并等待成功,以确认应用程序已包含 ONNX Runtime 头文件,并且可以成功链接到共享 onnxruntime 库。
-
a. 通过右键单击 app -> New -> Folder -> Assets Folder 并将其放在 main 下,在 Android Studio 项目的左侧窗格中创建一个新的
assets
文件夹。b. 将步骤 2 中生成的训练工件复制到此文件夹。
c. 现在,前往
onnxruntime-training-examples
存储库并下载数据集 (images.zip
) 到您的计算机并解压缩。此数据集是从 Kaggle 上提供的原始animals-10
数据集修改而来的,由 Corrado Alessio 创建。d. 将下载的
images
文件夹复制到 Android Studio 中的assets/images
目录。项目的左侧窗格应如下所示
-
a. 我们将在 C++ 中实现以下四个函数,这些函数将从应用程序中调用
createSession
:将在应用程序启动时调用。它将创建新的CheckpointState
和TrainingSession
对象。releaseSession
:将在应用程序即将关闭时调用。此函数将释放应用程序启动时分配的资源。performTraining
:将在用户单击 UI 上的Train
按钮时调用。performInference
:将在用户单击 UI 上的Infer
按钮时调用。
b. 创建会话
此函数在应用程序启动时调用。这将使用训练工件 assets 来创建
C++
CheckpointState 和 TrainingSession 对象。这些对象将用于在设备上训练模型。createSession
的参数是checkpoint_path
:检查点工件的缓存路径。train_model_path
:训练模型工件的缓存路径。eval_model_path
:评估模型工件的缓存路径。optimizer_model_path
:优化器模型工件的缓存路径。cache_dir_path
:Android 设备上缓存目录的路径。缓存目录用作从 C++ 代码访问训练工件的方式。
该函数返回一个
long
,表示指向session_cache
对象的指针。每当我们需要访问训练会话时,此long
都可以强制转换为SessionCache
。extern "C" JNIEXPORT jlong JNICALL Java_com_example_ortpersonalize_MainActivity_createSession( JNIEnv *env, jobject /* this */, jstring checkpoint_path, jstring train_model_path, jstring eval_model_path, jstring optimizer_model_path, jstring cache_dir_path) { std::unique_ptr<SessionCache> session_cache = std::make_unique<SessionCache>( utils::JString2String(env, checkpoint_path), utils::JString2String(env, train_model_path), utils::JString2String(env, eval_model_path), utils::JString2String(env, optimizer_model_path), utils::JString2String(env, cache_dir_path)); return reinterpret_cast<long>(session_cache.release()); }
从上面的函数体可以看出,此函数创建一个指向
SessionCache
类对象的唯一指针。SessionCache
的定义如下所示。struct SessionCache { ArtifactPaths artifact_paths; Ort::Env ort_env; Ort::SessionOptions session_options; Ort::CheckpointState checkpoint_state; Ort::TrainingSession training_session; Ort::Session* inference_session; SessionCache(const std::string &checkpoint_path, const std::string &training_model_path, const std::string &eval_model_path, const std::string &optimizer_model_path, const std::string& cache_dir_path) : artifact_paths(checkpoint_path, training_model_path, eval_model_path, optimizer_model_path, cache_dir_path), ort_env(ORT_LOGGING_LEVEL_WARNING, "ort personalize"), session_options(), checkpoint_state(Ort::CheckpointState::LoadCheckpoint(artifact_paths.checkpoint_path.c_str())), training_session(session_options, checkpoint_state, artifact_paths.training_model_path.c_str(), artifact_paths.eval_model_path.c_str(), artifact_paths.optimizer_model_path.c_str()), inference_session(nullptr) {} };
ArtifactPaths
的定义是struct ArtifactPaths { std::string checkpoint_path; std::string training_model_path; std::string eval_model_path; std::string optimizer_model_path; std::string cache_dir_path; std::string inference_model_path; ArtifactPaths(const std::string &checkpoint_path, const std::string &training_model_path, const std::string &eval_model_path, const std::string &optimizer_model_path, const std::string& cache_dir_path) : checkpoint_path(checkpoint_path), training_model_path(training_model_path), eval_model_path(eval_model_path), optimizer_model_path(optimizer_model_path), cache_dir_path(cache_dir_path), inference_model_path(cache_dir_path + "/inference.onnx") {} };
c. 释放会话
此函数在应用程序即将关闭时调用。它释放应用程序启动时创建的资源,主要是 CheckpointState 和 TrainingSession。
releaseSession
的参数是session
:long
,表示SessionCache
对象。
extern "C" JNIEXPORT void JNICALL Java_com_example_ortpersonalize_MainActivity_releaseSession( JNIEnv *env, jobject /* this */, jlong session) { auto *session_cache = reinterpret_cast<SessionCache *>(session); delete session_cache->inference_session; delete session_cache; }
d. 执行训练
对于需要训练的每个批次,都会调用此函数。训练循环在 Kotlin 的应用程序端编写,并且在训练循环中,对于每个批次,都会调用
performTraining
函数。performTraining
的参数是session
:long
,表示SessionCache
对象。batch
:要传入以进行训练的输入图像,以浮点数组形式表示。labels
:与为训练提供的输入图像关联的标签,以整数数组形式表示。batch_size
:每个TrainStep
处理的图像数量。channels
:图像中的通道数。对于我们的示例,始终会使用值3
调用此值。frame_rows
:图像中的行数。对于我们的示例,始终会使用值224
调用此值。frame_cols
:图像中的列数。对于我们的示例,始终会使用值224
调用此值。
该函数返回一个
float
,表示此批次的训练损失。extern "C" JNIEXPORT float JNICALL Java_com_example_ortpersonalize_MainActivity_performTraining( JNIEnv *env, jobject /* this */, jlong session, jfloatArray batch, jintArray labels, jint batch_size, jint channels, jint frame_rows, jint frame_cols) { auto* session_cache = reinterpret_cast<SessionCache *>(session); if (session_cache->inference_session) { // Invalidate the inference session since we will be updating the model parameters // in train_step. // The next call to inference session will need to recreate the inference session. delete session_cache->inference_session; session_cache->inference_session = nullptr; } // Update the model parameters using this batch of inputs. return training::train_step(session_cache, env->GetFloatArrayElements(batch, nullptr), env->GetIntArrayElements(labels, nullptr), batch_size, channels, frame_rows, frame_cols); }
上面的函数利用了
train_step
函数。train_step
函数的定义如下所示namespace training { float train_step(SessionCache* session_cache, float *batches, int32_t *labels, int64_t batch_size, int64_t image_channels, int64_t image_rows, int64_t image_cols) { const std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols}); const std::vector<int64_t> labels_shape({batch_size}); Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); std::vector<Ort::Value> user_inputs; // {inputs, labels} // Inputs batched user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, batches, batch_size * image_channels * image_rows * image_cols * sizeof(float), input_shape.data(), input_shape.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); // Labels batched user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, labels, batch_size * sizeof(int32_t), labels_shape.data(), labels_shape.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)); // Run the train step and execute the forward + loss + backward. float loss = *(session_cache->training_session.TrainStep(user_inputs).front().GetTensorMutableData<float>()); // Update the model parameters by taking a step in the direction of the gradients computed above. session_cache->training_session.OptimizerStep(); // Reset the gradients now that the parameters have been updated. // New set of gradients can then be computed for the next round of inputs. session_cache->training_session.LazyResetGrad(); return loss; } } // namespace training
e. 执行推理
当用户想要执行推理时,会调用此函数。
performInference
的参数是session
:long
,表示SessionCache
对象。image_buffer
:要传入以进行训练的输入图像,以浮点数组形式表示。batch_size
:每次推理处理的图像数量。对于我们的示例,始终会使用值1
调用此值。image_channels
:图像中的通道数。对于我们的示例,始终会使用值3
调用此值。image_rows
:图像中的行数。对于我们的示例,始终会使用值224
调用此值。image_cols
:图像中的列数。对于我们的示例,始终会使用值224
调用此值。classes
:表示所有四个自定义类别的字符串列表。
该函数返回一个
string
,表示提供的四个自定义类别之一。这是模型的预测。extern "C" JNIEXPORT jstring JNICALL Java_com_example_ortpersonalize_MainActivity_performInference( JNIEnv *env, jobject /* this */, jlong session, jfloatArray image_buffer, jint batch_size, jint image_channels, jint image_rows, jint image_cols, jobjectArray classes) { std::vector<std::string> classes_str; for (int i = 0; i < env->GetArrayLength(classes); ++i) { // Access the current string element jstring elem = static_cast<jstring>(env->GetObjectArrayElement(classes, i)); classes_str.push_back(utils::JString2String(env, elem)); } auto* session_cache = reinterpret_cast<SessionCache *>(session); if (!session_cache->inference_session) { // The inference session does not exist, so create a new one. session_cache->training_session.ExportModelForInferencing( session_cache->artifact_paths.inference_model_path.c_str(), {"output"}); session_cache->inference_session = std::make_unique<Ort::Session>( session_cache->ort_env, session_cache->artifact_paths.inference_model_path.c_str(), session_cache->session_options).release(); } auto prediction = inference::classify( session_cache, env->GetFloatArrayElements(image_buffer, nullptr), batch_size, image_channels, image_rows, image_cols, classes_str); return env->NewStringUTF(prediction.first.c_str()); }
上面的函数调用
classify
。classify 的定义是namespace inference { std::pair<std::string, float> classify(SessionCache* session_cache, float *image_data, int64_t batch_size, int64_t image_channels, int64_t image_rows, int64_t image_cols, const std::vector<std::string>& classes) { std::vector<const char *> input_names = {"input"}; size_t input_count = 1; std::vector<const char *> output_names = {"output"}; size_t output_count = 1; std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols}); Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); std::vector<Ort::Value> input_values; // {input images} input_values.emplace_back(Ort::Value::CreateTensor(memory_info, image_data, batch_size * image_channels * image_rows * image_cols * sizeof(float), input_shape.data(), input_shape.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); std::vector<Ort::Value> output_values; output_values.emplace_back(nullptr); // get the logits session_cache->inference_session->Run(Ort::RunOptions(), input_names.data(), input_values.data(), input_count, output_names.data(), output_values.data(), output_count); float *output = output_values.front().GetTensorMutableData<float>(); // run softmax and get the probabilities of each class std::vector<float> probabilities = Softmax(output, classes.size()); size_t best_index = std::distance(probabilities.begin(), std::max_element(probabilities.begin(), probabilities.end())); return {classes[best_index], probabilities[best_index]}; } } // namespace inference
classify 函数调用另一个名为
Softmax
的函数。Softmax
的定义是std::vector<float> Softmax(float *logits, size_t num_logits) { std::vector<float> probabilities(num_logits, 0); float sum = 0; for (size_t i = 0; i < num_logits; ++i) { probabilities[i] = exp(logits[i]); sum += probabilities[i]; } if (sum != 0.0f) { for (size_t i = 0; i < num_logits; ++i) { probabilities[i] /= sum; } } return probabilities; }
-
a.
MobileNetV2
模型期望提供的输入图像是- 大小为
3 x 224 x 224
。 - 归一化图像,减去均值
(0.485, 0.456, 0.406)
并除以标准差(0.229, 0.224, 0.225)
此预处理在 Java/Kotlin 中使用 Android 提供的库完成。
让我们在
app/src/main/java/com/example/ortpersonalize
目录下创建一个名为ImageProcessingUtil.kt
的新文件。我们将在此文件中添加用于裁剪和调整大小以及归一化图像的实用程序方法。b. 裁剪和调整图像大小。
fun processBitmap(bitmap: Bitmap) : Bitmap { // This function processes the given bitmap by // - cropping along the longer dimension to get a square bitmap // If the width is larger than the height // ___+_________________+___ // | + + | // | + + | // | + + + | // | + + | // |__+_________________+__| // <-------- width --------> // <----- height ----> // <--> cropped <--> // // If the height is larger than the width // _________________________ ʌ ʌ // | | | cropped // |+++++++++++++++++++++++| | ʌ v // | | | | // | | | | // | + | height width // | | | | // | | | | // |+++++++++++++++++++++++| | v ʌ // | | | cropped // |_______________________| v v // // // // - resizing the cropped square image to be of size (3 x 224 x 224) as needed by the // mobilenetv2 model. lateinit var bitmapCropped: Bitmap if (bitmap.getWidth() >= bitmap.getHeight()) { // Since height is smaller than the width, we crop a square whose length is the height // So cropping happens along the width dimesion val width: Int = bitmap.getHeight() val height: Int = bitmap.getHeight() // left side of the cropped image must begin at (bitmap.getWidth() / 2 - bitmap.getHeight() / 2) // so that the cropped width contains equal portion of the width on either side of center // top side of the cropped image must begin at 0 since we are not cropping along the height // dimension val x: Int = bitmap.getWidth() / 2 - bitmap.getHeight() / 2 val y: Int = 0 bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height) } else { // Since width is smaller than the height, we crop a square whose length is the width // So cropping happens along the height dimesion val width: Int = bitmap.getWidth() val height: Int = bitmap.getWidth() // left side of the cropped image must begin at 0 since we are not cropping along the width // dimension // top side of the cropped image must begin at (bitmap.getHeight() / 2 - bitmap.getWidth() / 2) // so that the cropped height contains equal portion of the height on either side of center val x: Int = 0 val y: Int = bitmap.getHeight() / 2 - bitmap.getWidth() / 2 bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height) } // Resize the image to be channels x width x height as needed by the mobilenetv2 model val width: Int = 224 val height: Int = 224 val bitmapResized: Bitmap = Bitmap.createScaledBitmap(bitmapCropped, width, height, false) return bitmapResized }
c. 归一化图像。
fun processImage(bitmap: Bitmap, buffer: FloatBuffer, offset: Int) { // This function iterates over the image and performs the following // on the image pixels // - normalizes the pixel values to be between 0 and 1 // - substracts the mean (0.485, 0.456, 0.406) (derived from the mobilenetv2 model configuration) // from the pixel values // - divides by pixel values by the standard deviation (0.229, 0.224, 0.225) (derived from the // mobilenetv2 model configuration) // Values are written to the given buffer starting at the provided offset. // Values are written as follows // |____|____________________|__________________| <--- buffer // ʌ <--- offset // ʌ <--- offset + width * height * channels // |____|rrrrrr|_____________|__________________| <--- red channel read in column major order // |____|______|gggggg|______|__________________| <--- green channel read in column major order // |____|______|______|bbbbbb|__________________| <--- blue channel read in column major order val width: Int = bitmap.getWidth() val height: Int = bitmap.getHeight() val stride: Int = width * height for (x in 0 until width) { for (y in 0 until height) { val color: Int = bitmap.getPixel(y, x) val index = offset + (x * height + y) // Subtract the mean and divide by the standard deviation // Values for mean and standard deviation used for // the movilenetv2 model. buffer.put(index + stride * 0, ((Color.red(color).toFloat() / 255f) - 0.485f) / 0.229f) buffer.put(index + stride * 1, ((Color.green(color).toFloat() / 255f) - 0.456f) / 0.224f) buffer.put(index + stride * 2, ((Color.blue(color).toFloat() / 255f) - 0.406f) / 0.225f) } } }
d. 从 Uri 获取 Bitmap
fun bitmapFromUri(uri: Uri, contentResolver: ContentResolver): Bitmap { // This function reads the image file at the given uri and decodes it to a bitmap val source: ImageDecoder.Source = ImageDecoder.createSource(contentResolver, uri) return ImageDecoder.decodeBitmap(source).copy(Bitmap.Config.ARGB_8888, true) }
- 大小为
-
a. 对于本教程,我们将使用以下用户界面元素
- Train 和 Infer 按钮
- 类别按钮
- 状态消息文本
- 图像显示
- 进度对话框
b. 本教程不打算展示如何创建图形用户界面。因此,我们将简单地重用 GitHub 上提供的文件。
c. 将所有字符串定义从
strings.xml
复制到 Android Studio 本地的strings.xml
。d. 将内容从
activity_main.xml
复制到 Android Studio 本地的activity_main.xml
。e. 在
layout
文件夹下创建一个名为dialog.xml
的新文件。将内容从dialog.xml
复制到新创建的 Android Studio 本地的dialog.xml
。f. 本节的其余更改需要在 MainActivity.kt 文件中进行。
g. 启动应用程序
当应用程序启动时,会调用
onCreate
函数。此函数负责设置会话缓存和用户界面处理程序。有关代码,请参阅
MainActivity.kt
文件中的onCreate
函数。h. 自定义类别按钮处理程序 - 我们希望使用类别按钮让用户选择其自定义图像进行训练。我们需要为这些按钮添加侦听器才能执行此操作。这些侦听器将完全做到这一点。
请参阅
MainActivity.kt
中的这些按钮处理程序- onClassAClickedListener
- onClassBClickedListener
- onClassXClickedListener
- onClassYClickedListener
i. 个性化自定义类别标签
默认情况下,自定义类别标签为
[A, B, X, Y]
。但是,为了清晰起见,让我们允许用户重命名这些标签。这通过长按侦听器实现,即 (MainActivity.kt
中定义的)- onClassALongClickedListener
- onClassBLongClickedListener
- onClassXLongClickedListener
- onClassYLongClickedListener
j. 切换自定义类别。
当自定义类别切换关闭时,将运行预打包的动物数据集。当它打开时,用户应自带数据集进行训练。为了处理此转换,
MainActivity.kt
中实现了onCustomClassSettingChangedListener
开关处理程序。k. 训练处理程序
当每个类别至少有 1 张图像时,可以启用
Train
按钮。当单击Train
按钮时,将针对选定的图像开始训练。训练处理程序负责- 将训练图像收集到一个容器中。
- 打乱图像的顺序。
- 裁剪和调整图像大小。
- 归一化图像。
- 批处理图像。
- 执行训练循环(循环调用 C++
performTraining
函数)。
MainActivity.kt
中定义的onTrainButtonClickedListener
函数执行上述操作。l. 推理处理程序
训练完成后,用户可以单击
Infer
按钮来推理任何图像。推理处理程序负责- 收集推理图像。
- 裁剪和调整图像大小。
- 归一化图像。
- 调用 C++
performInference
函数。 - 向用户界面报告推断的输出。
这是通过
MainActivity.kt
中的onInferenceButtonClickedListener
函数实现的。m. 上述所有活动的处理程序
一旦选择了用于推理或自定义类别的图像,就需要对其进行处理。
MainActivity.kt
中定义的onActivityResult
函数执行此操作。n. 最后一件事。在
AndroidManifest.xml
文件中添加以下内容以使用摄像头<uses-permission android:name="android.permission.CAMERA" /> <uses-feature android:name="android.hardware.camera" />
训练阶段 - 在设备上运行应用程序
-
a. 让我们将 Android 设备连接到计算机并在设备上运行应用程序。
b. 在设备上启动应用程序应如下所示
-
a. 让我们开始使用预加载的动物设备进行训练,方法是在设备上启动应用程序。
b. 切换底部的
Custom classes
开关。c. 类别标签将更改为
Dog
、Cat
、Elephant
和Cow
。d. 运行
Training
并等待进度对话框消失(训练完成后)。e. 现在使用库中的任何动物图像进行推理。
从上图可以看出,模型正确预测了
Cow
。 -
a. 从 Web 下载 Tom Cruise、Leonardo DiCaprio、Ryan Reynolds 和 Brad Pitt 的图像。
b. 通过关闭应用程序并重新启动应用程序,确保启动应用程序的新会话。
c. 应用程序启动后,使用长按将四个类别分别重命名为
Tom
、Leo
、Ryan
、Brad
。d. 单击每个类别的按钮,并选择与该名人关联的图像。每个类别可以使用大约 10~15 张图像。
e. 点击
Train
按钮,让应用程序从提供的数据中学习。f. 训练完成后,我们可以点击
Infer
按钮并提供应用程序尚未见过的图像。g. 完成!希望应用程序正确分类了图像。
结论
恭喜!您已成功构建一个 Android 应用程序,该应用程序学习使用设备上的 ONNX Runtime 对图像进行分类。该应用程序也在 GitHub 上提供,网址为 onnxruntime-training-examples
。