设备端训练:构建 Android 应用程序

在本教程中,我们将探讨如何构建一个集成 ONNX Runtime 设备端训练解决方案的 Android 应用程序。设备端训练是指直接在边缘设备上训练机器学习模型的过程,而无需依赖云服务或外部服务器。

以下是本教程结束时应用程序的外观

an image classification app with Tom Cruise in the middle.

简介

我们将指导您完成创建 Android 应用程序的步骤,该应用程序可以使用设备端训练技术训练简单的图像分类模型。本教程展示了 迁移学习 技术,其中从在一个任务上训练模型获得的知识被用来提高模型在不同但相关的任务上的性能。迁移学习不是从头开始学习过程,而是允许我们将预训练模型学习到的知识或特征迁移到新任务。

在本教程中,我们将利用 MobileNetV2 模型,该模型已在 ImageNet 等大规模图像数据集(具有 1,000 个类别)上进行了训练。我们将使用此模型将自定义数据分类为四个类别之一。MobileNetV2 的初始层充当特征提取器,捕获适用于各种任务的通用视觉特征,只有最终的分类器层将针对手头的任务进行训练。

在本教程中,我们将使用数据学习

  • 使用预先打包的动物数据集将动物分类为四个类别之一。
  • 使用自定义名人数据集将名人分类为四个类别之一。

目录

先决条件

要学习本教程,您应该基本了解如何使用 Java 或 Kotlin 开发 Android 应用程序。熟悉 C++ 以及机器学习概念(如神经网络和图像分类)也将有所帮助。

  • 用于准备训练工件的 Python 开发环境
  • Android Studio 4.1+
  • Android SDK 29+
  • Android NDK r21+
  • 具有摄像头的 Android 设备,处于开发者模式并启用 USB 调试

注意 整个 Android 应用程序也在 onnxruntime-training-examples GitHub 存储库上提供。

离线阶段 - 构建训练工件

  1. 将模型导出到 ONNX。

    我们从预训练的 PyTorch 模型开始,并将其导出到 ONNX。MobileNetV2 模型已在 imagenet 数据集上预训练,该数据集具有 1000 个类别的数据。对于我们的图像分类任务,我们只想将图像分类为 4 个类别。因此,我们将模型的最后一层更改为输出 4 个 logits 而不是 1,000 个。

    有关如何将 PyTorch 模型导出到 ONNX 的更多详细信息,请参见此处

    import torch
    import torchvision
    
    model = torchvision.models.mobilenet_v2(
       weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2)
    
    # The original model is trained on imagenet which has 1000 classes.
    # For our image classification scenario, we need to classify among 4 categories.
    # So we need to change the last layer of the model to have 4 outputs.
    model.classifier[1] = torch.nn.Linear(1280, 4)
    
    # Export the model to ONNX.
    model_name = "mobilenetv2"
    torch.onnx.export(model, torch.randn(1, 3, 224, 224),
                      f"training_artifacts/{model_name}.onnx",
                      input_names=["input"], output_names=["output"],
                      dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
    
  2. 定义可训练和不可训练的参数

    import onnx
    
    # Load the onnx model.
    onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx")
    
    # Define the parameters that require their gradients to be computed
    # (trainable parameters) and those that do not (frozen/non trainable parameters).
    requires_grad = ["classifier.1.weight", "classifier.1.bias"]
    frozen_params = [
       param.name
       for param in onnx_model.graph.initializer
       if param.name not in requires_grad
    ]
    
  3. 生成训练工件。

    在本教程中,我们将使用 CrossEntropyLoss 损失和 AdamW 优化器。有关工件生成的更多详细信息,请参见此处

    from onnxruntime.training import artifacts
    
    # Generate the training artifacts.
    artifacts.generate_artifacts(
       onnx_model,
       requires_grad=requires_grad,
       frozen_params=frozen_params,
       loss=artifacts.LossType.CrossEntropyLoss,
       optimizer=artifacts.OptimType.AdamW,
       artifact_directory="training_artifacts"
    )
    

    就这样!训练工件已在 training_artifacts 文件夹中生成。这标志着离线阶段的结束。这些工件已准备好部署到 Android 设备进行训练。

训练阶段 - Android 应用程序开发

  1. 在 Android Studio 中设置项目

    a. 打开 Android Studio 并单击 New Project Android Studio Setup - New Project

    b. 单击 Native C++ -> Next。填写 New Project 详细信息,如下所示

    • 名称 - ORT Personalize
    • 包名 - com.example.ortpersonalize
    • 语言 - Kotlin

    单击 Next

    Android Studio Setup - Project Name

    c. 选择 C++17 工具链 -> Finish

    Android Studio Setup - Project C++ ToolChain

    d. 完成!Android Studio 项目已设置完成。您现在应该能够看到带有某些样板代码的 Android Studio 编辑器。

  2. 添加 ONNX Runtime 依赖项

    a. 在 Android Studio 项目中的 cpp 目录下创建两个名为 libinclude\onnxruntime 的新文件夹。

    lib and include folder

    b. 前往 Maven Central。转到 Versions->Browse-> 并下载 onnxruntime-training-android 存档包(aar 文件)。

    c. 将 aar 扩展名重命名为 zip。因此 onnxruntime-training-android-1.15.0.aar 变为 onnxruntime-training-android-1.15.0.zip

    d. 解压缩 zip 文件的内容。

    e. 将 libonnxruntime.so 共享库从 jni\arm64-v8a 文件夹复制到 Android 项目中新创建的 lib 文件夹下。

    f. 将 headers 文件夹的内容复制到新创建的 include\onnxruntime 文件夹。

    g. 在 native-lib.cpp 文件中,包含训练 cxx 头文件。

    #include "onnxruntime_training_cxx_api.h"
    

    h. 将 abiFilters 添加到 build.gradle (Module) 文件,以便选择 arm64-v8a。此设置必须在 build.gradle 中的 defaultConfig 下添加

    ndk {
       abiFilters 'arm64-v8a'
    }
    

    请注意,build.gradle 文件的 defaultConfig 部分应如下所示

     defaultConfig {
        applicationId "com.example.ortpersonalize"
        minSdk 29
        targetSdk 33
        versionCode 1
        versionName "1.0"
    
        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
        externalNativeBuild {
           cmake {
                 cppFlags '-std=c++17'
           }
        }
    +   ndk {
    +       abiFilters 'arm64-v8a'
    +   }
       
     }
    

    i. 将 onnxruntime 共享库添加到 CMakeLists.txt,以便 cmake 可以找到并针对共享库进行构建。为此,在 CMakeLists.txt 中添加 ortpersonalize 库后,添加以下行

    add_library(onnxruntime SHARED IMPORTED)
    set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)
    

    通过在上面两行之后添加此行,让 CMake 知道 ONNX Runtime 头文件可以在哪里找到

    target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)
    

    通过将 onnxruntime 库添加到 target_link_libraries,将 Android C++ 项目链接到 onnxruntime

    target_link_libraries( # Specifies the target library.
         ortpersonalize
    
         # Links the target library to the log library
         # included in the NDK.
         ${log-lib}
    
         onnxruntime)
    

    请注意,CMakeLists.txt 文件应如下所示

    project("ortpersonalize")
    
    add_library( # Sets the name of the library.
          ortpersonalize
    
          # Sets the library as a shared library.
          SHARED
    
          # Provides a relative path to your source file(s).
          native-lib.cpp
    +     utils.cpp
    +     inference.cpp
    +     train.cpp)
    + add_library(onnxruntime SHARED IMPORTED)
    + set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)
    + target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)
    
    find_library( # Sets the name of the path variable.
          log-lib
    
          # Specifies the name of the NDK library that
          # you want CMake to locate.
          log)
    
    target_link_libraries( # Specifies the target library.
          ortpersonalize
    
          # Links the target library to the log library
          # included in the NDK.
          ${log-lib}
    +     onnxruntime)
    
    

    j. 构建应用程序并等待成功,以确认应用程序已包含 ONNX Runtime 头文件,并且可以成功链接到共享 onnxruntime 库。

  3. 打包预构建的训练工件和数据集

    a. 通过右键单击 app -> New -> Folder -> Assets Folder 并将其放在 main 下,在 Android Studio 项目的左侧窗格中创建一个新的 assets 文件夹。

    b. 将步骤 2 中生成的训练工件复制到此文件夹。

    c. 现在,前往 onnxruntime-training-examples 存储库并下载数据集 (images.zip) 到您的计算机并解压缩。此数据集是从 Kaggle 上提供的原始 animals-10 数据集修改而来的,由 Corrado Alessio 创建。

    d. 将下载的 images 文件夹复制到 Android Studio 中的 assets/images 目录。

    项目的左侧窗格应如下所示

    Project Assets

  4. 与 ONNX Runtime 交互 - C++ 代码

    a. 我们将在 C++ 中实现以下四个函数,这些函数将从应用程序中调用

    • createSession:将在应用程序启动时调用。它将创建新的 CheckpointStateTrainingSession 对象。
    • releaseSession:将在应用程序即将关闭时调用。此函数将释放应用程序启动时分配的资源。
    • performTraining:将在用户单击 UI 上的 Train 按钮时调用。
    • performInference:将在用户单击 UI 上的 Infer 按钮时调用。

    b. 创建会话

    此函数在应用程序启动时调用。这将使用训练工件 assets 来创建 C++ CheckpointState 和 TrainingSession 对象。这些对象将用于在设备上训练模型。

    createSession 的参数是

    • checkpoint_path:检查点工件的缓存路径。
    • train_model_path:训练模型工件的缓存路径。
    • eval_model_path:评估模型工件的缓存路径。
    • optimizer_model_path:优化器模型工件的缓存路径。
    • cache_dir_path:Android 设备上缓存目录的路径。缓存目录用作从 C++ 代码访问训练工件的方式。

    该函数返回一个 long,表示指向 session_cache 对象的指针。每当我们需要访问训练会话时,此 long 都可以强制转换为 SessionCache

    extern "C" JNIEXPORT jlong JNICALL
    Java_com_example_ortpersonalize_MainActivity_createSession(
          JNIEnv *env, jobject /* this */,
          jstring checkpoint_path, jstring train_model_path, jstring eval_model_path,
          jstring optimizer_model_path, jstring cache_dir_path)
    {
       std::unique_ptr<SessionCache> session_cache = std::make_unique<SessionCache>(
                utils::JString2String(env, checkpoint_path),
                utils::JString2String(env, train_model_path),
                utils::JString2String(env, eval_model_path),
                utils::JString2String(env, optimizer_model_path),
                utils::JString2String(env, cache_dir_path));
       return reinterpret_cast<long>(session_cache.release());
    }
    

    从上面的函数体可以看出,此函数创建一个指向 SessionCache 类对象的唯一指针。SessionCache 的定义如下所示。

    struct SessionCache {
       ArtifactPaths artifact_paths;
       Ort::Env ort_env;
       Ort::SessionOptions session_options;
       Ort::CheckpointState checkpoint_state;
       Ort::TrainingSession training_session;
       Ort::Session* inference_session;
    
       SessionCache(const std::string &checkpoint_path, const std::string &training_model_path,
                   const std::string &eval_model_path, const std::string &optimizer_model_path,
                   const std::string& cache_dir_path) :
       artifact_paths(checkpoint_path, training_model_path, eval_model_path, optimizer_model_path, cache_dir_path),
       ort_env(ORT_LOGGING_LEVEL_WARNING, "ort personalize"), session_options(),
       checkpoint_state(Ort::CheckpointState::LoadCheckpoint(artifact_paths.checkpoint_path.c_str())),
       training_session(session_options, checkpoint_state, artifact_paths.training_model_path.c_str(),
                         artifact_paths.eval_model_path.c_str(), artifact_paths.optimizer_model_path.c_str()),
       inference_session(nullptr) {}
    };
    

    ArtifactPaths 的定义是

    struct ArtifactPaths {
       std::string checkpoint_path;
       std::string training_model_path;
       std::string eval_model_path;
       std::string optimizer_model_path;
       std::string cache_dir_path;
       std::string inference_model_path;
    
       ArtifactPaths(const std::string &checkpoint_path, const std::string &training_model_path,
                      const std::string &eval_model_path, const std::string &optimizer_model_path,
                      const std::string& cache_dir_path) :
       checkpoint_path(checkpoint_path), training_model_path(training_model_path),
       eval_model_path(eval_model_path), optimizer_model_path(optimizer_model_path),
       cache_dir_path(cache_dir_path), inference_model_path(cache_dir_path + "/inference.onnx") {}
    };
    

    c. 释放会话

    此函数在应用程序即将关闭时调用。它释放应用程序启动时创建的资源,主要是 CheckpointState 和 TrainingSession。

    releaseSession 的参数是

    • sessionlong,表示 SessionCache 对象。
    extern "C" JNIEXPORT void JNICALL
    Java_com_example_ortpersonalize_MainActivity_releaseSession(
          JNIEnv *env, jobject /* this */,
          jlong session) {
       auto *session_cache = reinterpret_cast<SessionCache *>(session);
       delete session_cache->inference_session;
       delete session_cache;
    }
    

    d. 执行训练

    对于需要训练的每个批次,都会调用此函数。训练循环在 Kotlin 的应用程序端编写,并且在训练循环中,对于每个批次,都会调用 performTraining 函数。

    performTraining 的参数是

    • sessionlong,表示 SessionCache 对象。
    • batch:要传入以进行训练的输入图像,以浮点数组形式表示。
    • labels:与为训练提供的输入图像关联的标签,以整数数组形式表示。
    • batch_size:每个 TrainStep 处理的图像数量。
    • channels:图像中的通道数。对于我们的示例,始终会使用值 3 调用此值。
    • frame_rows:图像中的行数。对于我们的示例,始终会使用值 224 调用此值。
    • frame_cols:图像中的列数。对于我们的示例,始终会使用值 224 调用此值。

    该函数返回一个 float,表示此批次的训练损失。

    extern "C"
    JNIEXPORT float JNICALL
    Java_com_example_ortpersonalize_MainActivity_performTraining(
          JNIEnv *env, jobject /* this */,
          jlong session, jfloatArray batch, jintArray labels, jint batch_size,
          jint channels, jint frame_rows, jint frame_cols) {
       auto* session_cache = reinterpret_cast<SessionCache *>(session);
    
       if (session_cache->inference_session) {
          // Invalidate the inference session since we will be updating the model parameters
          // in train_step.
          // The next call to inference session will need to recreate the inference session.
          delete session_cache->inference_session;
          session_cache->inference_session = nullptr;
       }
    
       // Update the model parameters using this batch of inputs.
       return training::train_step(session_cache, env->GetFloatArrayElements(batch, nullptr),
                                  env->GetIntArrayElements(labels, nullptr), batch_size,
                                  channels, frame_rows, frame_cols);
    }
    

    上面的函数利用了 train_step 函数。train_step 函数的定义如下所示

    namespace training {
    
       float train_step(SessionCache* session_cache, float *batches, int32_t *labels,
                         int64_t batch_size, int64_t image_channels, int64_t image_rows,
                         int64_t image_cols) {
          const std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols});
          const std::vector<int64_t> labels_shape({batch_size});
    
          Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
          std::vector<Ort::Value> user_inputs; // {inputs, labels}
          // Inputs batched
          user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, batches,
                                                             batch_size * image_channels * image_rows * image_cols * sizeof(float),
                                                             input_shape.data(), input_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
    
          // Labels batched
          user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, labels,
                                                             batch_size * sizeof(int32_t),
                                                             labels_shape.data(), labels_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32));
    
          // Run the train step and execute the forward + loss + backward.
          float loss = *(session_cache->training_session.TrainStep(user_inputs).front().GetTensorMutableData<float>());
    
          // Update the model parameters by taking a step in the direction of the gradients computed above.
          session_cache->training_session.OptimizerStep();
    
          // Reset the gradients now that the parameters have been updated.
          // New set of gradients can then be computed for the next round of inputs.
          session_cache->training_session.LazyResetGrad();
    
          return loss;
       }
    
    } // namespace training
    

    e. 执行推理

    当用户想要执行推理时,会调用此函数。

    performInference 的参数是

    • sessionlong,表示 SessionCache 对象。
    • image_buffer:要传入以进行训练的输入图像,以浮点数组形式表示。
    • batch_size:每次推理处理的图像数量。对于我们的示例,始终会使用值 1 调用此值。
    • image_channels:图像中的通道数。对于我们的示例,始终会使用值 3 调用此值。
    • image_rows:图像中的行数。对于我们的示例,始终会使用值 224 调用此值。
    • image_cols:图像中的列数。对于我们的示例,始终会使用值 224 调用此值。
    • classes:表示所有四个自定义类别的字符串列表。

    该函数返回一个 string,表示提供的四个自定义类别之一。这是模型的预测。

    extern "C"
    JNIEXPORT jstring JNICALL
    Java_com_example_ortpersonalize_MainActivity_performInference(
          JNIEnv *env, jobject  /* this */,
          jlong session, jfloatArray image_buffer, jint batch_size, jint image_channels, jint image_rows,
          jint image_cols, jobjectArray classes) {
    
       std::vector<std::string> classes_str;
       for (int i = 0; i < env->GetArrayLength(classes); ++i) {
          // Access the current string element
          jstring elem = static_cast<jstring>(env->GetObjectArrayElement(classes, i));
          classes_str.push_back(utils::JString2String(env, elem));
       }
    
       auto* session_cache = reinterpret_cast<SessionCache *>(session);
       if (!session_cache->inference_session) {
          // The inference session does not exist, so create a new one.
          session_cache->training_session.ExportModelForInferencing(
                   session_cache->artifact_paths.inference_model_path.c_str(), {"output"});
          session_cache->inference_session = std::make_unique<Ort::Session>(
                   session_cache->ort_env, session_cache->artifact_paths.inference_model_path.c_str(),
                   session_cache->session_options).release();
       }
    
       auto prediction = inference::classify(
                session_cache, env->GetFloatArrayElements(image_buffer, nullptr),
                batch_size, image_channels, image_rows, image_cols, classes_str);
    
       return env->NewStringUTF(prediction.first.c_str());
    }
    

    上面的函数调用 classify。classify 的定义是

    namespace inference {
    
       std::pair<std::string, float> classify(SessionCache* session_cache, float *image_data,
                                              int64_t batch_size, int64_t image_channels,
                                              int64_t image_rows, int64_t image_cols,
                                              const std::vector<std::string>& classes) {
          std::vector<const char *> input_names = {"input"};
          size_t input_count = 1;
    
          std::vector<const char *> output_names = {"output"};
          size_t output_count = 1;
    
          std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols});
    
          Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
          std::vector<Ort::Value> input_values; // {input images}
          input_values.emplace_back(Ort::Value::CreateTensor(memory_info, image_data,
                                                             batch_size * image_channels * image_rows * image_cols * sizeof(float),
                                                             input_shape.data(), input_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
    
    
          std::vector<Ort::Value> output_values;
          output_values.emplace_back(nullptr);
    
          // get the logits
          session_cache->inference_session->Run(Ort::RunOptions(), input_names.data(), input_values.data(),
                                                 input_count, output_names.data(), output_values.data(), output_count);
    
          float *output = output_values.front().GetTensorMutableData<float>();
    
          // run softmax and get the probabilities of each class
          std::vector<float> probabilities = Softmax(output, classes.size());
          size_t best_index = std::distance(probabilities.begin(), std::max_element(probabilities.begin(), probabilities.end()));
    
          return {classes[best_index], probabilities[best_index]};
       }
    
    } // namespace inference
    

    classify 函数调用另一个名为 Softmax 的函数。Softmax 的定义是

    std::vector<float> Softmax(float *logits, size_t num_logits) {
       std::vector<float> probabilities(num_logits, 0);
       float sum = 0;
       for (size_t i = 0; i < num_logits; ++i) {
             probabilities[i] = exp(logits[i]);
             sum += probabilities[i];
       }
    
       if (sum != 0.0f) {
             for (size_t i = 0; i < num_logits; ++i) {
                probabilities[i] /= sum;
             }
       }
    
       return probabilities;
    }
    
  5. 图像预处理

    a. MobileNetV2 模型期望提供的输入图像是

    • 大小为 3 x 224 x 224
    • 归一化图像,减去均值 (0.485, 0.456, 0.406) 并除以标准差 (0.229, 0.224, 0.225)

    此预处理在 Java/Kotlin 中使用 Android 提供的库完成。

    让我们在 app/src/main/java/com/example/ortpersonalize 目录下创建一个名为 ImageProcessingUtil.kt 的新文件。我们将在此文件中添加用于裁剪和调整大小以及归一化图像的实用程序方法。

    b. 裁剪和调整图像大小。

    fun processBitmap(bitmap: Bitmap) : Bitmap {
       // This function processes the given bitmap by
       //   - cropping along the longer dimension to get a square bitmap
       //     If the width is larger than the height
       //     ___+_________________+___
       //     |  +                 +  |
       //     |  +                 +  |
       //     |  +        +        +  |
       //     |  +                 +  |
       //     |__+_________________+__|
       //     <-------- width -------->
       //        <----- height ---->
       //     <-->      cropped    <-->
       //
       //     If the height is larger than the width
       //     _________________________   ʌ            ʌ
       //     |                       |   |         cropped
       //     |+++++++++++++++++++++++|   |      ʌ     v
       //     |                       |   |      |
       //     |                       |   |      |
       //     |           +           | height width
       //     |                       |   |      |
       //     |                       |   |      |
       //     |+++++++++++++++++++++++|   |      v     ʌ
       //     |                       |   |         cropped
       //     |_______________________|   v            v
       //
       //
       //
       //   - resizing the cropped square image to be of size (3 x 224 x 224) as needed by the
       //     mobilenetv2 model.
       lateinit var bitmapCropped: Bitmap
       if (bitmap.getWidth() >= bitmap.getHeight()) {
          // Since height is smaller than the width, we crop a square whose length is the height
          // So cropping happens along the width dimesion
          val width: Int = bitmap.getHeight()
          val height: Int = bitmap.getHeight()
    
          // left side of the cropped image must begin at (bitmap.getWidth() / 2 - bitmap.getHeight() / 2)
          // so that the cropped width contains equal portion of the width on either side of center
          // top side of the cropped image must begin at 0 since we are not cropping along the height
          // dimension
          val x: Int = bitmap.getWidth() / 2 - bitmap.getHeight() / 2
          val y: Int = 0
          bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height)
       } else {
          // Since width is smaller than the height, we crop a square whose length is the width
          // So cropping happens along the height dimesion
          val width: Int = bitmap.getWidth()
          val height: Int = bitmap.getWidth()
    
          // left side of the cropped image must begin at 0 since we are not cropping along the width
          // dimension
          // top side of the cropped image must begin at (bitmap.getHeight() / 2 - bitmap.getWidth() / 2)
          // so that the cropped height contains equal portion of the height on either side of center
          val x: Int = 0
          val y: Int = bitmap.getHeight() / 2 - bitmap.getWidth() / 2
          bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height)
       }
    
       // Resize the image to be channels x width x height as needed by the mobilenetv2 model
       val width: Int = 224
       val height: Int = 224
       val bitmapResized: Bitmap = Bitmap.createScaledBitmap(bitmapCropped, width, height, false)
    
       return bitmapResized
    }
    

    c. 归一化图像。

    fun processImage(bitmap: Bitmap, buffer: FloatBuffer, offset: Int) {
       // This function iterates over the image and performs the following
       // on the image pixels
       //   - normalizes the pixel values to be between 0 and 1
       //   - substracts the mean (0.485, 0.456, 0.406) (derived from the mobilenetv2 model configuration)
       //     from the pixel values
       //   - divides by pixel values by the standard deviation (0.229, 0.224, 0.225) (derived from the
       //     mobilenetv2 model configuration)
       // Values are written to the given buffer starting at the provided offset.
       // Values are written as follows
       // |____|____________________|__________________| <--- buffer
       //      ʌ                                         <--- offset
       //                           ʌ                    <--- offset + width * height * channels
       // |____|rrrrrr|_____________|__________________| <--- red channel read in column major order
       // |____|______|gggggg|______|__________________| <--- green channel read in column major order
       // |____|______|______|bbbbbb|__________________| <--- blue channel read in column major order
    
       val width: Int = bitmap.getWidth()
       val height: Int = bitmap.getHeight()
       val stride: Int = width * height
    
       for (x in 0 until width) {
          for (y in 0 until height) {
                val color: Int = bitmap.getPixel(y, x)
                val index = offset + (x * height + y)
    
                // Subtract the mean and divide by the standard deviation
                // Values for mean and standard deviation used for
                // the movilenetv2 model.
                buffer.put(index + stride * 0, ((Color.red(color).toFloat() / 255f) - 0.485f) / 0.229f)
                buffer.put(index + stride * 1, ((Color.green(color).toFloat() / 255f) - 0.456f) / 0.224f)
                buffer.put(index + stride * 2, ((Color.blue(color).toFloat() / 255f) - 0.406f) / 0.225f)
          }
       }
    }
    

    d. 从 Uri 获取 Bitmap

    fun bitmapFromUri(uri: Uri, contentResolver: ContentResolver): Bitmap {
       // This function reads the image file at the given uri and decodes it to a bitmap
       val source: ImageDecoder.Source = ImageDecoder.createSource(contentResolver, uri)
       return ImageDecoder.decodeBitmap(source).copy(Bitmap.Config.ARGB_8888, true)
    }
    
  6. 应用程序前端

    a. 对于本教程,我们将使用以下用户界面元素

    • Train 和 Infer 按钮
    • 类别按钮
    • 状态消息文本
    • 图像显示
    • 进度对话框

    b. 本教程不打算展示如何创建图形用户界面。因此,我们将简单地重用 GitHub 上提供的文件。

    c. 将所有字符串定义从 strings.xml 复制到 Android Studio 本地的 strings.xml

    d. 将内容从 activity_main.xml 复制到 Android Studio 本地的 activity_main.xml

    e. 在 layout 文件夹下创建一个名为 dialog.xml 的新文件。将内容从 dialog.xml 复制到新创建的 Android Studio 本地的 dialog.xml

    f. 本节的其余更改需要在 MainActivity.kt 文件中进行。

    g. 启动应用程序

    当应用程序启动时,会调用 onCreate 函数。此函数负责设置会话缓存和用户界面处理程序。

    有关代码,请参阅 MainActivity.kt 文件中的 onCreate 函数。

    h. 自定义类别按钮处理程序 - 我们希望使用类别按钮让用户选择其自定义图像进行训练。我们需要为这些按钮添加侦听器才能执行此操作。这些侦听器将完全做到这一点。

    请参阅 MainActivity.kt 中的这些按钮处理程序

    • onClassAClickedListener
    • onClassBClickedListener
    • onClassXClickedListener
    • onClassYClickedListener

    i. 个性化自定义类别标签

    默认情况下,自定义类别标签为 [A, B, X, Y]。但是,为了清晰起见,让我们允许用户重命名这些标签。这通过长按侦听器实现,即 (MainActivity.kt 中定义的)

    • onClassALongClickedListener
    • onClassBLongClickedListener
    • onClassXLongClickedListener
    • onClassYLongClickedListener

    j. 切换自定义类别。

    当自定义类别切换关闭时,将运行预打包的动物数据集。当它打开时,用户应自带数据集进行训练。为了处理此转换,MainActivity.kt 中实现了 onCustomClassSettingChangedListener 开关处理程序。

    k. 训练处理程序

    当每个类别至少有 1 张图像时,可以启用 Train 按钮。当单击 Train 按钮时,将针对选定的图像开始训练。训练处理程序负责

    • 将训练图像收集到一个容器中。
    • 打乱图像的顺序。
    • 裁剪和调整图像大小。
    • 归一化图像。
    • 批处理图像。
    • 执行训练循环(循环调用 C++ performTraining 函数)。

    MainActivity.kt 中定义的 onTrainButtonClickedListener 函数执行上述操作。

    l. 推理处理程序

    训练完成后,用户可以单击 Infer 按钮来推理任何图像。推理处理程序负责

    • 收集推理图像。
    • 裁剪和调整图像大小。
    • 归一化图像。
    • 调用 C++ performInference 函数。
    • 向用户界面报告推断的输出。

    这是通过 MainActivity.kt 中的 onInferenceButtonClickedListener 函数实现的。

    m. 上述所有活动的处理程序

    一旦选择了用于推理或自定义类别的图像,就需要对其进行处理。MainActivity.kt 中定义的 onActivityResult 函数执行此操作。

    n. 最后一件事。在 AndroidManifest.xml 文件中添加以下内容以使用摄像头

    <uses-permission android:name="android.permission.CAMERA" />
    <uses-feature android:name="android.hardware.camera" />
    

训练阶段 - 在设备上运行应用程序

  1. 在设备上运行应用程序

    a. 让我们将 Android 设备连接到计算机并在设备上运行应用程序。

    b. 在设备上启动应用程序应如下所示

    Barebones ORT Personalize app

  2. 使用预加载的数据集进行训练 - 动物

    a. 让我们开始使用预加载的动物设备进行训练,方法是在设备上启动应用程序。

    b. 切换底部的 Custom classes 开关。

    c. 类别标签将更改为 DogCatElephantCow

    d. 运行 Training 并等待进度对话框消失(训练完成后)。

    e. 现在使用库中的任何动物图像进行推理。

    ORT Personalize app with an image of a cow

    从上图可以看出,模型正确预测了 Cow

  3. 使用自定义数据集进行训练 - 名人

    a. 从 Web 下载 Tom Cruise、Leonardo DiCaprio、Ryan Reynolds 和 Brad Pitt 的图像。

    b. 通过关闭应用程序并重新启动应用程序,确保启动应用程序的新会话。

    c. 应用程序启动后,使用长按将四个类别分别重命名为 TomLeoRyanBrad

    d. 单击每个类别的按钮,并选择与该名人关联的图像。每个类别可以使用大约 10~15 张图像。

    e. 点击 Train 按钮,让应用程序从提供的数据中学习。

    f. 训练完成后,我们可以点击 Infer 按钮并提供应用程序尚未见过的图像。

    g. 完成!希望应用程序正确分类了图像。

    an image classification app with Tom Cruise in the middle.

结论

恭喜!您已成功构建一个 Android 应用程序,该应用程序学习使用设备上的 ONNX Runtime 对图像进行分类。该应用程序也在 GitHub 上提供,网址为 onnxruntime-training-examples