设备端训练:构建 Android 应用程序

在本教程中,我们将探讨如何构建一个集成 ONNX Runtime 设备端训练解决方案的 Android 应用程序。设备端训练是指直接在边缘设备上训练机器学习模型,而不依赖于云服务或外部服务器的过程。

在本教程结束时,应用程序将看起来像这样

an image classification app with Tom Cruise in the middle.

介绍

我们将指导您完成创建 Android 应用程序的步骤,该应用程序可以使用设备端训练技术训练一个简单的图像分类模型。本教程展示了迁移学习技术,该技术利用在一个任务上训练模型获得的知识来提高模型在不同但相关任务上的性能。迁移学习允许我们将预训练模型学习到的知识或特征转移到新任务,而不是从头开始学习过程。

在本教程中,我们将利用MobileNetV2模型,该模型已在 ImageNet 等大型图像数据集(包含 1,000 个类别)上进行过预训练。我们将使用此模型将自定义数据分类到四个类别之一。MobileNetV2 的初始层作为特征提取器,捕获适用于各种任务的通用视觉特征,而只有最终的分类器层将针对手头的任务进行训练。

在本教程中,我们将使用数据来学习

  • 使用预打包的动物数据集将动物分类到四个类别之一。
  • 使用自定义名人数据集将名人分类到四个类别之一。

目录

先决条件

要遵循本教程,您应该对使用 Java 或 Kotlin 进行 Android 应用程序开发有基本了解。熟悉 C++ 以及熟悉机器学习概念(如神经网络和图像分类)也会有所帮助。

  • 用于准备训练工件的 Python 开发环境
  • Android Studio 4.1+
  • Android SDK 29+
  • Android NDK r21+
  • 一台支持开发者模式并启用 USB 调试的带摄像头的 Android 设备

注意 完整的 Android 应用程序也可在 onnxruntime-training-examples GitHub 仓库中找到。

离线阶段 - 构建训练工件

  1. 将模型导出为 ONNX。

    我们从一个预训练的 PyTorch 模型开始,并将其导出到 ONNX。MobileNetV2 模型已在 imagenet 数据集上进行了预训练,该数据集包含 1000 个类别的数据。对于我们的图像分类任务,我们只想对 4 个类别的图像进行分类。因此,我们将模型的最后一层更改为输出 4 个 logits,而不是 1,000 个。

    有关如何将 PyTorch 模型导出到 ONNX 的更多详细信息,请参见此处

    import torch
    import torchvision
    
    model = torchvision.models.mobilenet_v2(
       weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2)
    
    # The original model is trained on imagenet which has 1000 classes.
    # For our image classification scenario, we need to classify among 4 categories.
    # So we need to change the last layer of the model to have 4 outputs.
    model.classifier[1] = torch.nn.Linear(1280, 4)
    
    # Export the model to ONNX.
    model_name = "mobilenetv2"
    torch.onnx.export(model, torch.randn(1, 3, 224, 224),
                      f"training_artifacts/{model_name}.onnx",
                      input_names=["input"], output_names=["output"],
                      dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
    
  2. 定义可训练和不可训练的参数

    import onnx
    
    # Load the onnx model.
    onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx")
    
    # Define the parameters that require their gradients to be computed
    # (trainable parameters) and those that do not (frozen/non trainable parameters).
    requires_grad = ["classifier.1.weight", "classifier.1.bias"]
    frozen_params = [
       param.name
       for param in onnx_model.graph.initializer
       if param.name not in requires_grad
    ]
    
  3. 生成训练工件。

    在本教程中,我们将使用CrossEntropyLoss损失和AdamW优化器。有关工件生成的更多详细信息,请参见此处

    from onnxruntime.training import artifacts
    
    # Generate the training artifacts.
    artifacts.generate_artifacts(
       onnx_model,
       requires_grad=requires_grad,
       frozen_params=frozen_params,
       loss=artifacts.LossType.CrossEntropyLoss,
       optimizer=artifacts.OptimType.AdamW,
       artifact_directory="training_artifacts"
    )
    

    就这些!训练工件已生成到training_artifacts文件夹中。这标志着离线阶段的结束。这些工件已准备好部署到 Android 设备上进行训练。

训练阶段 - Android 应用程序开发

  1. 在 Android Studio 中设置项目

    a. 打开 Android Studio 并点击New Project Android Studio Setup - New Project

    b. 点击Native C++ -> Next。填写New Project详细信息如下

    • 名称 - ORT Personalize
    • 包名 - com.example.ortpersonalize
    • 语言 - Kotlin

    点击Next

    Android Studio Setup - Project Name

    c. 选择C++17工具链 -> Finish

    Android Studio Setup - Project C++ ToolChain

    d. 就这样!Android Studio 项目已设置完毕。您现在应该能看到带有一些样板代码的 Android Studio 编辑器。

  2. 添加 ONNX Runtime 依赖项

    a. 在 Android Studio 项目的 cpp 目录下创建两个新文件夹,分别命名为libinclude\onnxruntime

    lib and include folder

    b. 前往Maven Central。转到Versions->Browse-> 并下载onnxruntime-training-android归档包(aar 文件)。

    c. 将aar扩展名重命名为zip。例如,onnxruntime-training-android-1.15.0.aar变为onnxruntime-training-android-1.15.0.zip

    d. 解压 zip 文件内容。

    e. 将libonnxruntime.so共享库从jni\arm64-v8a文件夹复制到您的 Android 项目中新创建的lib文件夹下。

    f. 将headers文件夹的内容复制到新创建的include\onnxruntime文件夹中。

    g. 在native-lib.cpp文件中,包含训练 cxx 头文件。

    #include "onnxruntime_training_cxx_api.h"
    

    h. 将abiFilters添加到build.gradle (Module)文件中,以选择arm64-v8a。此设置必须添加到build.gradle中的defaultConfig下。

    ndk {
       abiFilters 'arm64-v8a'
    }
    

    请注意,build.gradle文件中的defaultConfig部分应如下所示

     defaultConfig {
        applicationId "com.example.ortpersonalize"
        minSdk 29
        targetSdk 33
        versionCode 1
        versionName "1.0"
    
        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
        externalNativeBuild {
           cmake {
                 cppFlags '-std=c++17'
           }
        }
    +   ndk {
    +       abiFilters 'arm64-v8a'
    +   }
       
     }
    

    i. 将onnxruntime共享库添加到CMakeLists.txt中,以便cmake能够找到并针对该共享库进行构建。为此,在CMakeLists.txt中添加ortpersonalize库之后添加以下几行

    add_library(onnxruntime SHARED IMPORTED)
    set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)
    

    通过紧接着上面两行添加以下行,让CMake知道 ONNX Runtime 头文件在哪里

    target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)
    

    通过将onnxruntime库添加到target_link_libraries中,将 Android C++ 项目与onnxruntime库链接。

    target_link_libraries( # Specifies the target library.
         ortpersonalize
    
         # Links the target library to the log library
         # included in the NDK.
         ${log-lib}
    
         onnxruntime)
    

    请注意,CMakeLists.txt文件应如下所示

    project("ortpersonalize")
    
    add_library( # Sets the name of the library.
          ortpersonalize
    
          # Sets the library as a shared library.
          SHARED
    
          # Provides a relative path to your source file(s).
          native-lib.cpp
    +     utils.cpp
    +     inference.cpp
    +     train.cpp)
    + add_library(onnxruntime SHARED IMPORTED)
    + set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)
    + target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)
    
    find_library( # Sets the name of the path variable.
          log-lib
    
          # Specifies the name of the NDK library that
          # you want CMake to locate.
          log)
    
    target_link_libraries( # Specifies the target library.
          ortpersonalize
    
          # Links the target library to the log library
          # included in the NDK.
          ${log-lib}
    +     onnxruntime)
    
    

    j. 构建应用程序并等待成功,以确认应用程序已包含 ONNX Runtime 头文件并能成功链接到 ONNX Runtime 共享库。

  3. 打包预构建的训练工件和数据集

    a. 在 Android Studio 项目左侧面板的app文件夹内,通过右键点击 app -> New -> Folder -> Assets Folder 创建一个新的assets文件夹,并将其放在 main 文件夹下。

    b. 将步骤 2 中生成的训练工件复制到此文件夹中。

    c. 现在,前往onnxruntime-training-examples仓库,将数据集(images.zip)下载到您的机器并解压。该数据集修改自 Kaggle 上由 Corrado Alessio 创建的原始 animals-10 数据集。

    d. 将下载的images文件夹复制到 Android Studio 中的assets/images目录下。

    项目左侧面板应如下所示

    Project Assets

  4. 与 ONNX Runtime 交互 - C++ 代码

    a. 我们将在 C++ 中实现以下四个将从应用程序调用的函数

    • createSession: 将在应用程序启动时调用。它将创建一个新的CheckpointStateTrainingSession对象。
    • releaseSession: 将在应用程序即将关闭时调用。此函数将释放应用程序启动时分配的资源。
    • performTraining: 将在用户点击 UI 上的Train按钮时调用。
    • performInference: 将在用户点击 UI 上的Infer按钮时调用。

    b. 创建 Session

    此函数在应用程序启动时调用。它将使用训练工件资源来创建 C++ CheckpointState 和 TrainingSession 对象。这些对象将用于在设备上训练模型。

    createSession的参数是

    • checkpoint_path: 检查点工件的缓存路径。
    • train_model_path: 训练模型工件的缓存路径。
    • eval_model_path: 评估模型工件的缓存路径。
    • optimizer_model_path: 优化器模型工件的缓存路径。
    • cache_dir_path: Android 设备上缓存目录的路径。缓存目录用作从 C++ 代码访问训练工件的一种方式。

    该函数返回一个long值,表示指向session_cache对象的指针。每当需要访问训练会话时,可以将此long值强制转换为SessionCache

    extern "C" JNIEXPORT jlong JNICALL
    Java_com_example_ortpersonalize_MainActivity_createSession(
          JNIEnv *env, jobject /* this */,
          jstring checkpoint_path, jstring train_model_path, jstring eval_model_path,
          jstring optimizer_model_path, jstring cache_dir_path)
    {
       std::unique_ptr<SessionCache> session_cache = std::make_unique<SessionCache>(
                utils::JString2String(env, checkpoint_path),
                utils::JString2String(env, train_model_path),
                utils::JString2String(env, eval_model_path),
                utils::JString2String(env, optimizer_model_path),
                utils::JString2String(env, cache_dir_path));
       return reinterpret_cast<long>(session_cache.release());
    }
    

    从上面的函数体可以看出,此函数创建指向SessionCache类对象的唯一指针。SessionCache的定义如下。

    struct SessionCache {
       ArtifactPaths artifact_paths;
       Ort::Env ort_env;
       Ort::SessionOptions session_options;
       Ort::CheckpointState checkpoint_state;
       Ort::TrainingSession training_session;
       Ort::Session* inference_session;
    
       SessionCache(const std::string &checkpoint_path, const std::string &training_model_path,
                   const std::string &eval_model_path, const std::string &optimizer_model_path,
                   const std::string& cache_dir_path) :
       artifact_paths(checkpoint_path, training_model_path, eval_model_path, optimizer_model_path, cache_dir_path),
       ort_env(ORT_LOGGING_LEVEL_WARNING, "ort personalize"), session_options(),
       checkpoint_state(Ort::CheckpointState::LoadCheckpoint(artifact_paths.checkpoint_path.c_str())),
       training_session(session_options, checkpoint_state, artifact_paths.training_model_path.c_str(),
                         artifact_paths.eval_model_path.c_str(), artifact_paths.optimizer_model_path.c_str()),
       inference_session(nullptr) {}
    };
    

    ArtifactPaths的定义是

    struct ArtifactPaths {
       std::string checkpoint_path;
       std::string training_model_path;
       std::string eval_model_path;
       std::string optimizer_model_path;
       std::string cache_dir_path;
       std::string inference_model_path;
    
       ArtifactPaths(const std::string &checkpoint_path, const std::string &training_model_path,
                      const std::string &eval_model_path, const std::string &optimizer_model_path,
                      const std::string& cache_dir_path) :
       checkpoint_path(checkpoint_path), training_model_path(training_model_path),
       eval_model_path(eval_model_path), optimizer_model_path(optimizer_model_path),
       cache_dir_path(cache_dir_path), inference_model_path(cache_dir_path + "/inference.onnx") {}
    };
    

    c. 释放 Session

    此函数在应用程序即将关闭时调用。它释放应用程序启动时创建的资源,主要是 CheckpointState 和 TrainingSession。

    releaseSession的参数是

    • session: 表示SessionCache对象的long值。
    extern "C" JNIEXPORT void JNICALL
    Java_com_example_ortpersonalize_MainActivity_releaseSession(
          JNIEnv *env, jobject /* this */,
          jlong session) {
       auto *session_cache = reinterpret_cast<SessionCache *>(session);
       delete session_cache->inference_session;
       delete session_cache;
    }
    

    d. 执行训练

    对于每个需要训练的批次,都会调用此函数。训练循环使用 Kotlin 在应用程序端编写,在训练循环中,performTraining函数会针对每个批次调用。

    performTraining的参数是

    • session: 表示SessionCache对象的long值。
    • batch: 作为浮点数组传递进来用于训练的输入图像。
    • labels: 与为训练提供的输入图像关联的整型数组标签。
    • batch_size: 每次TrainStep处理的图像数量。
    • channels: 图像中的通道数。对于我们的示例,此值始终为3
    • frame_rows: 图像中的行数。对于我们的示例,此值始终为224
    • frame_cols: 图像中的列数。对于我们的示例,此值始终为224

    该函数返回一个float值,表示此批次的训练损失。

    extern "C"
    JNIEXPORT float JNICALL
    Java_com_example_ortpersonalize_MainActivity_performTraining(
          JNIEnv *env, jobject /* this */,
          jlong session, jfloatArray batch, jintArray labels, jint batch_size,
          jint channels, jint frame_rows, jint frame_cols) {
       auto* session_cache = reinterpret_cast<SessionCache *>(session);
    
       if (session_cache->inference_session) {
          // Invalidate the inference session since we will be updating the model parameters
          // in train_step.
          // The next call to inference session will need to recreate the inference session.
          delete session_cache->inference_session;
          session_cache->inference_session = nullptr;
       }
    
       // Update the model parameters using this batch of inputs.
       return training::train_step(session_cache, env->GetFloatArrayElements(batch, nullptr),
                                  env->GetIntArrayElements(labels, nullptr), batch_size,
                                  channels, frame_rows, frame_cols);
    }
    

    上面的函数利用了train_step函数。train_step函数的定义如下

    namespace training {
    
       float train_step(SessionCache* session_cache, float *batches, int32_t *labels,
                         int64_t batch_size, int64_t image_channels, int64_t image_rows,
                         int64_t image_cols) {
          const std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols});
          const std::vector<int64_t> labels_shape({batch_size});
    
          Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
          std::vector<Ort::Value> user_inputs; // {inputs, labels}
          // Inputs batched
          user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, batches,
                                                             batch_size * image_channels * image_rows * image_cols * sizeof(float),
                                                             input_shape.data(), input_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
    
          // Labels batched
          user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, labels,
                                                             batch_size * sizeof(int32_t),
                                                             labels_shape.data(), labels_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32));
    
          // Run the train step and execute the forward + loss + backward.
          float loss = *(session_cache->training_session.TrainStep(user_inputs).front().GetTensorMutableData<float>());
    
          // Update the model parameters by taking a step in the direction of the gradients computed above.
          session_cache->training_session.OptimizerStep();
    
          // Reset the gradients now that the parameters have been updated.
          // New set of gradients can then be computed for the next round of inputs.
          session_cache->training_session.LazyResetGrad();
    
          return loss;
       }
    
    } // namespace training
    

    e. 执行推理

    当用户想要执行推理时,将调用此函数。

    performInference的参数是

    • session: 表示SessionCache对象的long值。
    • image_buffer: 作为浮点数组传递进来用于训练的输入图像。
    • batch_size: 每次推理处理的图像数量。对于我们的示例,此值始终为1
    • image_channels: 图像中的通道数。对于我们的示例,此值始终为3
    • image_rows: 图像中的行数。对于我们的示例,此值始终为224
    • image_cols: 图像中的列数。对于我们的示例,此值始终为224
    • classes: 表示所有四个自定义类别的字符串列表。

    该函数返回一个string,表示提供的四个自定义类别之一。这是模型的预测结果。

    extern "C"
    JNIEXPORT jstring JNICALL
    Java_com_example_ortpersonalize_MainActivity_performInference(
          JNIEnv *env, jobject  /* this */,
          jlong session, jfloatArray image_buffer, jint batch_size, jint image_channels, jint image_rows,
          jint image_cols, jobjectArray classes) {
    
       std::vector<std::string> classes_str;
       for (int i = 0; i < env->GetArrayLength(classes); ++i) {
          // Access the current string element
          jstring elem = static_cast<jstring>(env->GetObjectArrayElement(classes, i));
          classes_str.push_back(utils::JString2String(env, elem));
       }
    
       auto* session_cache = reinterpret_cast<SessionCache *>(session);
       if (!session_cache->inference_session) {
          // The inference session does not exist, so create a new one.
          session_cache->training_session.ExportModelForInferencing(
                   session_cache->artifact_paths.inference_model_path.c_str(), {"output"});
          session_cache->inference_session = std::make_unique<Ort::Session>(
                   session_cache->ort_env, session_cache->artifact_paths.inference_model_path.c_str(),
                   session_cache->session_options).release();
       }
    
       auto prediction = inference::classify(
                session_cache, env->GetFloatArrayElements(image_buffer, nullptr),
                batch_size, image_channels, image_rows, image_cols, classes_str);
    
       return env->NewStringUTF(prediction.first.c_str());
    }
    

    上面的函数调用classifyclassify的定义是

    namespace inference {
    
       std::pair<std::string, float> classify(SessionCache* session_cache, float *image_data,
                                              int64_t batch_size, int64_t image_channels,
                                              int64_t image_rows, int64_t image_cols,
                                              const std::vector<std::string>& classes) {
          std::vector<const char *> input_names = {"input"};
          size_t input_count = 1;
    
          std::vector<const char *> output_names = {"output"};
          size_t output_count = 1;
    
          std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols});
    
          Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
          std::vector<Ort::Value> input_values; // {input images}
          input_values.emplace_back(Ort::Value::CreateTensor(memory_info, image_data,
                                                             batch_size * image_channels * image_rows * image_cols * sizeof(float),
                                                             input_shape.data(), input_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
    
    
          std::vector<Ort::Value> output_values;
          output_values.emplace_back(nullptr);
    
          // get the logits
          session_cache->inference_session->Run(Ort::RunOptions(), input_names.data(), input_values.data(),
                                                 input_count, output_names.data(), output_values.data(), output_count);
    
          float *output = output_values.front().GetTensorMutableData<float>();
    
          // run softmax and get the probabilities of each class
          std::vector<float> probabilities = Softmax(output, classes.size());
          size_t best_index = std::distance(probabilities.begin(), std::max_element(probabilities.begin(), probabilities.end()));
    
          return {classes[best_index], probabilities[best_index]};
       }
    
    } // namespace inference
    

    classify函数调用另一个函数SoftmaxSoftmax的定义是

    std::vector<float> Softmax(float *logits, size_t num_logits) {
       std::vector<float> probabilities(num_logits, 0);
       float sum = 0;
       for (size_t i = 0; i < num_logits; ++i) {
             probabilities[i] = exp(logits[i]);
             sum += probabilities[i];
       }
    
       if (sum != 0.0f) {
             for (size_t i = 0; i < num_logits; ++i) {
                probabilities[i] /= sum;
             }
       }
    
       return probabilities;
    }
    
  5. 图像预处理

    a. MobileNetV2 模型期望提供的输入图像

    • 尺寸为3 x 224 x 224
    • 是减去均值(0.485, 0.456, 0.406)并除以标准差(0.229, 0.224, 0.225)的归一化图像

    这种预处理使用 Android 提供的库在 Java/Kotlin 中完成。

    让我们在app/src/main/java/com/example/ortpersonalize目录下创建一个名为ImageProcessingUtil.kt的新文件。我们将在此文件中添加用于裁剪、调整大小和归一化图像的实用方法。

    b. 裁剪和调整图像大小。

    fun processBitmap(bitmap: Bitmap) : Bitmap {
       // This function processes the given bitmap by
       //   - cropping along the longer dimension to get a square bitmap
       //     If the width is larger than the height
       //     ___+_________________+___
       //     |  +                 +  |
       //     |  +                 +  |
       //     |  +        +        +  |
       //     |  +                 +  |
       //     |__+_________________+__|
       //     <-------- width -------->
       //        <----- height ---->
       //     <-->      cropped    <-->
       //
       //     If the height is larger than the width
       //     _________________________   ʌ            ʌ
       //     |                       |   |         cropped
       //     |+++++++++++++++++++++++|   |      ʌ     v
       //     |                       |   |      |
       //     |                       |   |      |
       //     |           +           | height width
       //     |                       |   |      |
       //     |                       |   |      |
       //     |+++++++++++++++++++++++|   |      v     ʌ
       //     |                       |   |         cropped
       //     |_______________________|   v            v
       //
       //
       //
       //   - resizing the cropped square image to be of size (3 x 224 x 224) as needed by the
       //     mobilenetv2 model.
       lateinit var bitmapCropped: Bitmap
       if (bitmap.getWidth() >= bitmap.getHeight()) {
          // Since height is smaller than the width, we crop a square whose length is the height
          // So cropping happens along the width dimesion
          val width: Int = bitmap.getHeight()
          val height: Int = bitmap.getHeight()
    
          // left side of the cropped image must begin at (bitmap.getWidth() / 2 - bitmap.getHeight() / 2)
          // so that the cropped width contains equal portion of the width on either side of center
          // top side of the cropped image must begin at 0 since we are not cropping along the height
          // dimension
          val x: Int = bitmap.getWidth() / 2 - bitmap.getHeight() / 2
          val y: Int = 0
          bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height)
       } else {
          // Since width is smaller than the height, we crop a square whose length is the width
          // So cropping happens along the height dimesion
          val width: Int = bitmap.getWidth()
          val height: Int = bitmap.getWidth()
    
          // left side of the cropped image must begin at 0 since we are not cropping along the width
          // dimension
          // top side of the cropped image must begin at (bitmap.getHeight() / 2 - bitmap.getWidth() / 2)
          // so that the cropped height contains equal portion of the height on either side of center
          val x: Int = 0
          val y: Int = bitmap.getHeight() / 2 - bitmap.getWidth() / 2
          bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height)
       }
    
       // Resize the image to be channels x width x height as needed by the mobilenetv2 model
       val width: Int = 224
       val height: Int = 224
       val bitmapResized: Bitmap = Bitmap.createScaledBitmap(bitmapCropped, width, height, false)
    
       return bitmapResized
    }
    

    c. 归一化图像。

    fun processImage(bitmap: Bitmap, buffer: FloatBuffer, offset: Int) {
       // This function iterates over the image and performs the following
       // on the image pixels
       //   - normalizes the pixel values to be between 0 and 1
       //   - substracts the mean (0.485, 0.456, 0.406) (derived from the mobilenetv2 model configuration)
       //     from the pixel values
       //   - divides by pixel values by the standard deviation (0.229, 0.224, 0.225) (derived from the
       //     mobilenetv2 model configuration)
       // Values are written to the given buffer starting at the provided offset.
       // Values are written as follows
       // |____|____________________|__________________| <--- buffer
       //      ʌ                                         <--- offset
       //                           ʌ                    <--- offset + width * height * channels
       // |____|rrrrrr|_____________|__________________| <--- red channel read in column major order
       // |____|______|gggggg|______|__________________| <--- green channel read in column major order
       // |____|______|______|bbbbbb|__________________| <--- blue channel read in column major order
    
       val width: Int = bitmap.getWidth()
       val height: Int = bitmap.getHeight()
       val stride: Int = width * height
    
       for (x in 0 until width) {
          for (y in 0 until height) {
                val color: Int = bitmap.getPixel(y, x)
                val index = offset + (x * height + y)
    
                // Subtract the mean and divide by the standard deviation
                // Values for mean and standard deviation used for
                // the movilenetv2 model.
                buffer.put(index + stride * 0, ((Color.red(color).toFloat() / 255f) - 0.485f) / 0.229f)
                buffer.put(index + stride * 1, ((Color.green(color).toFloat() / 255f) - 0.456f) / 0.224f)
                buffer.put(index + stride * 2, ((Color.blue(color).toFloat() / 255f) - 0.406f) / 0.225f)
          }
       }
    }
    

    d. 从 Uri 获取 Bitmap

    fun bitmapFromUri(uri: Uri, contentResolver: ContentResolver): Bitmap {
       // This function reads the image file at the given uri and decodes it to a bitmap
       val source: ImageDecoder.Source = ImageDecoder.createSource(contentResolver, uri)
       return ImageDecoder.decodeBitmap(source).copy(Bitmap.Config.ARGB_8888, true)
    }
    
  6. 应用程序前端

    a. 在本教程中,我们将使用以下用户界面元素

    • 训练和推理按钮
    • 类别按钮
    • 状态消息文本
    • 图像显示
    • 进度对话框

    b. 本教程不打算展示如何创建图形用户界面。因此,我们将直接重用 GitHub 上提供的文件。

    c. 将 strings.xml 中的所有字符串定义复制到您的 Android Studio 本地的strings.xml文件中。

    d. 将 activity_main.xml 中的内容复制到您的 Android Studio 本地的activity_main.xml文件中。

    e. 在layout文件夹下创建一个名为dialog.xml的新文件。将 dialog.xml 中的内容复制到您的 Android Studio 本地新创建的dialog.xml文件中。

    f. 本节的其余更改需要在 MainActivity.kt 文件中进行。

    g. 启动应用程序

    应用程序启动时,会调用onCreate函数。此函数负责设置会话缓存和用户界面处理程序。

    请参考MainActivity.kt文件中的onCreate函数获取代码。

    h. 自定义类别按钮处理程序 - 我们希望使用类别按钮供用户选择用于训练的自定义图像。我们需要为这些按钮添加监听器来完成此操作。这些监听器将完全实现这一点。

    请参考MainActivity.kt中的这些按钮处理程序

    • onClassAClickedListener
    • onClassBClickedListener
    • onClassXClickedListener
    • onClassYClickedListener

    i. 个性化自定义类别标签

    默认情况下,自定义类别标签是[A, B, X, Y]。但是,让我们允许用户重命名这些标签以提高清晰度。这通过长按监听器实现,即(定义在MainActivity.kt中)

    • onClassALongClickedListener
    • onClassBLongClickedListener
    • onClassXLongClickedListener
    • onClassYLongClickedListener

    j. 切换自定义类别。

    当自定义类别开关关闭时,运行预打包的动物数据集。当开关打开时,用户需要提供自己的数据集进行训练。为了处理这种切换,onCustomClassSettingChangedListener开关处理程序在MainActivity.kt中实现。

    k. 训练处理程序

    当每个类别至少有一张图像时,Train按钮可以启用。当Train按钮被点击时,对选定的图像启动训练。训练处理程序负责

    • 将训练图像收集到一个容器中。
    • 打乱图像的顺序。
    • 裁剪和调整图像大小。
    • 归一化图像。
    • 对图像进行批处理。
    • 执行训练循环(在循环中调用 C++ performTraining 函数)。

    MainActivity.kt中定义的onTrainButtonClickedListener函数完成了上述操作。

    l. 推理处理程序

    训练完成后,用户可以点击Infer按钮来推理任何图像。推理处理程序负责

    • 收集推理图像。
    • 裁剪和调整图像大小。
    • 归一化图像。
    • 调用 C++ performInference 函数。
    • 将推理结果报告给用户界面。

    这由MainActivity.kt中的onInferenceButtonClickedListener函数实现。

    m. 上述所有活动的处理程序

    一旦图像被选定用于推理或自定义类别,就需要对其进行处理。MainActivity.kt中定义的onActivityResult函数完成了此操作。

    n. 最后一点。在AndroidManifest.xml文件中添加以下内容以使用摄像头

    <uses-permission android:name="android.permission.CAMERA" />
    <uses-feature android:name="android.hardware.camera" />
    

训练阶段 - 在设备上运行应用程序

  1. 在设备上运行应用程序

    a. 让我们将 Android 设备连接到机器并在设备上运行应用程序。

    b. 在设备上启动应用程序应如下图所示

    Barebones ORT Personalize app

  2. 使用预加载数据集进行训练 - 动物

    a. 让我们通过在设备上启动应用程序,开始使用预加载的动物数据集进行训练。

    b. 切换底部的Custom classes开关。

    c. 类别标签将变为Dog, Cat, ElephantCow

    d. 运行Training并等待进度对话框消失(训练完成后)。

    e. 现在使用您图库中的任意动物图像进行推理。

    ORT Personalize app with an image of a cow

    从上图可以看出,模型正确预测为Cow

  3. 使用自定义数据集进行训练 - 名人

    a. 从网上下载汤姆·克鲁斯、莱昂纳多·迪卡普里奥、瑞安·雷诺兹和布拉德·皮特的图片。

    b. 请确保通过关闭并重新启动应用程序来启动一个新的应用会话。

    c. 应用程序启动后,通过长按将四个类别分别重命名为Tom, Leo, Ryan, Brad

    d. 点击每个类别对应的按钮,并选择与该名人相关的图片。每个类别可以使用大约 10~15 张图片。

    e. 点击Train按钮,让应用程序从提供的数据中学习。

    f. 训练完成后,我们可以点击Infer按钮并提供一张应用程序尚未见过的图片。

    g. 就这样!希望应用程序能够正确地对图片进行分类。

    an image classification app with Tom Cruise in the middle.

结论

恭喜!您已成功构建了一个使用 ONNX Runtime 在设备上学习图像分类的 Android 应用程序。该应用程序也可在 GitHub 上找到:onnxruntime-training-examples