设备端训练:构建 Android 应用程序

在本教程中,我们将探讨如何构建一个集成 ONNX Runtime 设备端训练解决方案的 Android 应用程序。设备端训练是指直接在边缘设备上训练机器学习模型的过程,不依赖云服务或外部服务器。

本教程结束时,应用程序将如下图所示

an image classification app with Tom Cruise in the middle.

简介

我们将指导您完成创建 Android 应用程序的步骤,该应用程序可以使用设备端训练技术训练一个简单的图像分类模型。本教程展示了迁移学习技术,即利用从一项任务中训练模型获得的知识来提高模型在不同但相关任务上的性能。迁移学习允许我们将预训练模型学到的知识或特征迁移到新任务,而不是从头开始学习过程。

在本教程中,我们将利用 MobileNetV2 模型,该模型已在 ImageNet 等大型图像数据集(包含 1,000 个类别)上进行训练。我们将使用此模型将自定义数据分类到四个类别中的一个。MobileNetV2 的初始层作为特征提取器,捕获适用于各种任务的通用视觉特征,而只有最终的分类器层将针对手头的任务进行训练。

在本教程中,我们将使用数据来学习

  • 使用预打包的动物数据集将动物分类到四个类别中的一个。
  • 使用自定义名人数据集将名人分类到四个类别中的一个。

目录

先决条件

要学习本教程,您应该对使用 Java 或 Kotlin 进行 Android 应用程序开发有基本的了解。熟悉 C++ 以及机器学习概念(如神经网络和图像分类)也将有所帮助。

  • 用于准备训练工件的 Python 开发环境
  • Android Studio 4.1+
  • Android SDK 29+
  • Android NDK r21+
  • 一台已启用 开发者模式 并开启 USB 调试的 Android 设备,带摄像头

注意 整个 Android 应用程序也已在 onnxruntime-training-examples GitHub 仓库中提供。

离线阶段 - 构建训练工件

  1. 将模型导出为 ONNX。

    我们从预训练的 PyTorch 模型开始,并将其导出为 ONNX。MobileNetV2 模型已在 Imagenet 数据集上进行了预训练,该数据集包含 1000 个类别的数据。对于我们的图像分类任务,我们只想对 4 个类别中的图像进行分类。因此,我们将模型的最后一层更改为输出 4 个 logits,而不是 1,000 个。

    有关如何将 PyTorch 模型导出为 ONNX 的更多详细信息,请参见此处

    import torch
    import torchvision
    
    model = torchvision.models.mobilenet_v2(
       weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2)
    
    # The original model is trained on imagenet which has 1000 classes.
    # For our image classification scenario, we need to classify among 4 categories.
    # So we need to change the last layer of the model to have 4 outputs.
    model.classifier[1] = torch.nn.Linear(1280, 4)
    
    # Export the model to ONNX.
    model_name = "mobilenetv2"
    torch.onnx.export(model, torch.randn(1, 3, 224, 224),
                      f"training_artifacts/{model_name}.onnx",
                      input_names=["input"], output_names=["output"],
                      dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
    
  2. 定义可训练和不可训练的参数

    import onnx
    
    # Load the onnx model.
    onnx_model = onnx.load(f"training_artifacts/{model_name}.onnx")
    
    # Define the parameters that require their gradients to be computed
    # (trainable parameters) and those that do not (frozen/non trainable parameters).
    requires_grad = ["classifier.1.weight", "classifier.1.bias"]
    frozen_params = [
       param.name
       for param in onnx_model.graph.initializer
       if param.name not in requires_grad
    ]
    
  3. 生成训练工件。

    在本教程中,我们将使用 CrossEntropyLoss 损失函数和 AdamW 优化器。有关工件生成的更多详细信息,请参见此处

    from onnxruntime.training import artifacts
    
    # Generate the training artifacts.
    artifacts.generate_artifacts(
       onnx_model,
       requires_grad=requires_grad,
       frozen_params=frozen_params,
       loss=artifacts.LossType.CrossEntropyLoss,
       optimizer=artifacts.OptimType.AdamW,
       artifact_directory="training_artifacts"
    )
    

    就这样!训练工件已生成到 training_artifacts 文件夹中。这标志着离线阶段的结束。这些工件已准备好部署到 Android 设备进行训练。

训练阶段 - Android 应用程序开发

  1. 在 Android Studio 中设置项目

    a. 打开 Android Studio,然后单击 New Project Android Studio 设置 - 新建项目

    b. 单击 Native C++ -> Next。填写 New Project 详细信息如下

    • 名称 - ORT Personalize
    • 包名 - com.example.ortpersonalize
    • 语言 - Kotlin

    单击 Next

    Android Studio Setup - Project Name

    c. 选择 C++17 工具链 -> Finish

    Android Studio Setup - Project C++ ToolChain

    d. 完成!Android Studio 项目已设置完毕。您现在应该能够看到带有样板代码的 Android Studio 编辑器。

  2. 添加 ONNX Runtime 依赖项

    a. 在 Android Studio 项目的 cpp 目录下创建两个新文件夹:libinclude\onnxruntime

    lib and include folder

    b. 访问 Maven Central。前往 Versions->Browse-> 并下载 onnxruntime-training-android 归档包 (aar 文件)。

    c. 将 aar 扩展名重命名为 zip。因此,onnxruntime-training-android-1.15.0.aar 变为 onnxruntime-training-android-1.15.0.zip

    d. 解压 zip 文件的内容。

    e. 将 jni\arm64-v8a 文件夹中的 libonnxruntime.so 共享库复制到新创建的 lib 文件夹下的 Android 项目中。

    f. 将 headers 文件夹的内容复制到新创建的 include\onnxruntime 文件夹中。

    g. 在 native-lib.cpp 文件中,包含训练 cxx 头文件。

    #include "onnxruntime_training_cxx_api.h"
    

    h. 将 abiFilters 添加到 build.gradle (Module) 文件中,以选择 arm64-v8a。此设置必须添加到 build.gradle 中的 defaultConfig 下。

    ndk {
       abiFilters 'arm64-v8a'
    }
    

    请注意,build.gradle 文件的 defaultConfig 部分应如下所示

     defaultConfig {
        applicationId "com.example.ortpersonalize"
        minSdk 29
        targetSdk 33
        versionCode 1
        versionName "1.0"
    
        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
        externalNativeBuild {
           cmake {
                 cppFlags '-std=c++17'
           }
        }
    +   ndk {
    +       abiFilters 'arm64-v8a'
    +   }
       
     }
    

    i. 将 onnxruntime 共享库添加到 CMakeLists.txt,以便 cmake 可以找到并针对该共享库进行构建。为此,请在 CMakeLists.txt 中添加 ortpersonalize 库之后添加以下行

    add_library(onnxruntime SHARED IMPORTED)
    set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)
    

    通过在上述两行之后添加以下行,让 CMake 知道 ONNX Runtime 头文件在哪里可以找到

    target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)
    

    通过将 onnxruntime 库添加到 target_link_libraries 来将 Android C++ 项目与 onnxruntime 库链接

    target_link_libraries( # Specifies the target library.
         ortpersonalize
    
         # Links the target library to the log library
         # included in the NDK.
         ${log-lib}
    
         onnxruntime)
    

    请注意,CMakeLists.txt 文件应如下所示

    project("ortpersonalize")
    
    add_library( # Sets the name of the library.
          ortpersonalize
    
          # Sets the library as a shared library.
          SHARED
    
          # Provides a relative path to your source file(s).
          native-lib.cpp
    +     utils.cpp
    +     inference.cpp
    +     train.cpp)
    + add_library(onnxruntime SHARED IMPORTED)
    + set_target_properties(onnxruntime PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/lib/libonnxruntime.so)
    + target_include_directories(ortpersonalize PRIVATE ${CMAKE_SOURCE_DIR}/include/onnxruntime)
    
    find_library( # Sets the name of the path variable.
          log-lib
    
          # Specifies the name of the NDK library that
          # you want CMake to locate.
          log)
    
    target_link_libraries( # Specifies the target library.
          ortpersonalize
    
          # Links the target library to the log library
          # included in the NDK.
          ${log-lib}
    +     onnxruntime)
    
    

    j. 构建应用程序并等待成功,以确认应用程序已包含 ONNX Runtime 头文件并能成功链接到共享的 onnxruntime 库。

  3. 打包预构建的训练工件和数据集

    a. 在 Android Studio 项目的左侧面板中,右键单击 app -> New -> Folder -> Assets Folder,并在 main 下创建一个名为 assets 的新文件夹。

    b. 将步骤 2 中生成的训练工件复制到此文件夹。

    c. 现在,前往 onnxruntime-training-examples 仓库,将数据集 (images.zip) 下载到您的机器并解压。此数据集是在 Kaggle 上由 Corrado Alessio 创建的原始 animals-10 数据集的基础上修改的。

    d. 将下载的 images 文件夹复制到 Android Studio 的 assets/images 目录中。

    项目的左侧面板应如下所示

    Project Assets

  4. 与 ONNX Runtime 交互 - C++ 代码

    a. 我们将在 C++ 中实现以下四个函数,这些函数将从应用程序中调用

    • createSession:将在应用程序启动时调用。它将创建新的 CheckpointStateTrainingSession 对象。
    • releaseSession:将在应用程序即将关闭时调用。此函数将释放应用程序启动时分配的资源。
    • performTraining:将在用户单击 UI 上的 Train 按钮时调用。
    • performInference:将在用户单击 UI 上的 Infer 按钮时调用。

    b. 创建会话

    此函数在应用程序启动时调用。它将使用训练工件资源创建 C++ CheckpointState 和 TrainingSession 对象。这些对象将用于在设备上训练模型。

    createSession 的参数是

    • checkpoint_path:检查点工件的缓存路径。
    • train_model_path:训练模型工件的缓存路径。
    • eval_model_path:评估模型工件的缓存路径。
    • optimizer_model_path:优化器模型工件的缓存路径。
    • cache_dir_path:Android 设备上缓存目录的路径。缓存目录用作从 C++ 代码访问训练工件的一种方式。

    该函数返回一个 long,它表示指向 session_cache 对象的指针。每当我们想访问训练会话时,这个 long 都可以强制转换为 SessionCache

    extern "C" JNIEXPORT jlong JNICALL
    Java_com_example_ortpersonalize_MainActivity_createSession(
          JNIEnv *env, jobject /* this */,
          jstring checkpoint_path, jstring train_model_path, jstring eval_model_path,
          jstring optimizer_model_path, jstring cache_dir_path)
    {
       std::unique_ptr<SessionCache> session_cache = std::make_unique<SessionCache>(
                utils::JString2String(env, checkpoint_path),
                utils::JString2String(env, train_model_path),
                utils::JString2String(env, eval_model_path),
                utils::JString2String(env, optimizer_model_path),
                utils::JString2String(env, cache_dir_path));
       return reinterpret_cast<long>(session_cache.release());
    }
    

    从上面的函数体可以看出,这个函数创建了一个 SessionCache 类的对象的唯一指针。SessionCache 的定义如下。

    struct SessionCache {
       ArtifactPaths artifact_paths;
       Ort::Env ort_env;
       Ort::SessionOptions session_options;
       Ort::CheckpointState checkpoint_state;
       Ort::TrainingSession training_session;
       Ort::Session* inference_session;
    
       SessionCache(const std::string &checkpoint_path, const std::string &training_model_path,
                   const std::string &eval_model_path, const std::string &optimizer_model_path,
                   const std::string& cache_dir_path) :
       artifact_paths(checkpoint_path, training_model_path, eval_model_path, optimizer_model_path, cache_dir_path),
       ort_env(ORT_LOGGING_LEVEL_WARNING, "ort personalize"), session_options(),
       checkpoint_state(Ort::CheckpointState::LoadCheckpoint(artifact_paths.checkpoint_path.c_str())),
       training_session(session_options, checkpoint_state, artifact_paths.training_model_path.c_str(),
                         artifact_paths.eval_model_path.c_str(), artifact_paths.optimizer_model_path.c_str()),
       inference_session(nullptr) {}
    };
    

    ArtifactPaths 的定义是

    struct ArtifactPaths {
       std::string checkpoint_path;
       std::string training_model_path;
       std::string eval_model_path;
       std::string optimizer_model_path;
       std::string cache_dir_path;
       std::string inference_model_path;
    
       ArtifactPaths(const std::string &checkpoint_path, const std::string &training_model_path,
                      const std::string &eval_model_path, const std::string &optimizer_model_path,
                      const std::string& cache_dir_path) :
       checkpoint_path(checkpoint_path), training_model_path(training_model_path),
       eval_model_path(eval_model_path), optimizer_model_path(optimizer_model_path),
       cache_dir_path(cache_dir_path), inference_model_path(cache_dir_path + "/inference.onnx") {}
    };
    

    c. 释放会话

    此函数在应用程序即将关闭时调用。它释放应用程序启动时创建的资源,主要是 CheckpointState 和 TrainingSession。

    releaseSession 的参数是

    • session:代表 SessionCache 对象的 long 值。
    extern "C" JNIEXPORT void JNICALL
    Java_com_example_ortpersonalize_MainActivity_releaseSession(
          JNIEnv *env, jobject /* this */,
          jlong session) {
       auto *session_cache = reinterpret_cast<SessionCache *>(session);
       delete session_cache->inference_session;
       delete session_cache;
    }
    

    d. 执行训练

    此函数针对每个需要训练的批次调用。训练循环用 Kotlin 编写在应用程序端,并且在训练循环中,每次批次都会调用 performTraining 函数。

    performTraining 的参数是

    • session:代表 SessionCache 对象的 long 值。
    • batch:作为浮点数组的输入图像,用于训练。
    • labels:与提供用于训练的输入图像关联的整数数组形式的标签。
    • batch_size:每个 TrainStep 要处理的图像数量。
    • channels:图像中的通道数。在我们的示例中,此值始终为 3
    • frame_rows:图像中的行数。在我们的示例中,此值始终为 224
    • frame_cols:图像中的列数。在我们的示例中,此值始终为 224

    该函数返回一个 float,表示该批次的训练损失。

    extern "C"
    JNIEXPORT float JNICALL
    Java_com_example_ortpersonalize_MainActivity_performTraining(
          JNIEnv *env, jobject /* this */,
          jlong session, jfloatArray batch, jintArray labels, jint batch_size,
          jint channels, jint frame_rows, jint frame_cols) {
       auto* session_cache = reinterpret_cast<SessionCache *>(session);
    
       if (session_cache->inference_session) {
          // Invalidate the inference session since we will be updating the model parameters
          // in train_step.
          // The next call to inference session will need to recreate the inference session.
          delete session_cache->inference_session;
          session_cache->inference_session = nullptr;
       }
    
       // Update the model parameters using this batch of inputs.
       return training::train_step(session_cache, env->GetFloatArrayElements(batch, nullptr),
                                  env->GetIntArrayElements(labels, nullptr), batch_size,
                                  channels, frame_rows, frame_cols);
    }
    

    上述函数利用了 train_step 函数。train_step 函数的定义如下

    namespace training {
    
       float train_step(SessionCache* session_cache, float *batches, int32_t *labels,
                         int64_t batch_size, int64_t image_channels, int64_t image_rows,
                         int64_t image_cols) {
          const std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols});
          const std::vector<int64_t> labels_shape({batch_size});
    
          Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
          std::vector<Ort::Value> user_inputs; // {inputs, labels}
          // Inputs batched
          user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, batches,
                                                             batch_size * image_channels * image_rows * image_cols * sizeof(float),
                                                             input_shape.data(), input_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
    
          // Labels batched
          user_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, labels,
                                                             batch_size * sizeof(int32_t),
                                                             labels_shape.data(), labels_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32));
    
          // Run the train step and execute the forward + loss + backward.
          float loss = *(session_cache->training_session.TrainStep(user_inputs).front().GetTensorMutableData<float>());
    
          // Update the model parameters by taking a step in the direction of the gradients computed above.
          session_cache->training_session.OptimizerStep();
    
          // Reset the gradients now that the parameters have been updated.
          // New set of gradients can then be computed for the next round of inputs.
          session_cache->training_session.LazyResetGrad();
    
          return loss;
       }
    
    } // namespace training
    

    e. 执行推理

    当用户想要执行推理时,将调用此函数。

    performInference 的参数是

    • session:代表 SessionCache 对象的 long 值。
    • image_buffer:作为浮点数组的输入图像,用于训练。
    • batch_size:每次推理要处理的图像数量。在我们的示例中,此值始终为 1
    • image_channels:图像中的通道数。在我们的示例中,此值始终为 3
    • image_rows:图像中的行数。在我们的示例中,此值始终为 224
    • image_cols:图像中的列数。在我们的示例中,此值始终为 224
    • classes:表示所有四个自定义类别的字符串列表。

    该函数返回一个 string,表示所提供的四个自定义类别之一。这是模型的预测结果。

    extern "C"
    JNIEXPORT jstring JNICALL
    Java_com_example_ortpersonalize_MainActivity_performInference(
          JNIEnv *env, jobject  /* this */,
          jlong session, jfloatArray image_buffer, jint batch_size, jint image_channels, jint image_rows,
          jint image_cols, jobjectArray classes) {
    
       std::vector<std::string> classes_str;
       for (int i = 0; i < env->GetArrayLength(classes); ++i) {
          // Access the current string element
          jstring elem = static_cast<jstring>(env->GetObjectArrayElement(classes, i));
          classes_str.push_back(utils::JString2String(env, elem));
       }
    
       auto* session_cache = reinterpret_cast<SessionCache *>(session);
       if (!session_cache->inference_session) {
          // The inference session does not exist, so create a new one.
          session_cache->training_session.ExportModelForInferencing(
                   session_cache->artifact_paths.inference_model_path.c_str(), {"output"});
          session_cache->inference_session = std::make_unique<Ort::Session>(
                   session_cache->ort_env, session_cache->artifact_paths.inference_model_path.c_str(),
                   session_cache->session_options).release();
       }
    
       auto prediction = inference::classify(
                session_cache, env->GetFloatArrayElements(image_buffer, nullptr),
                batch_size, image_channels, image_rows, image_cols, classes_str);
    
       return env->NewStringUTF(prediction.first.c_str());
    }
    

    上述函数调用了 classify。classify 的定义是

    namespace inference {
    
       std::pair<std::string, float> classify(SessionCache* session_cache, float *image_data,
                                              int64_t batch_size, int64_t image_channels,
                                              int64_t image_rows, int64_t image_cols,
                                              const std::vector<std::string>& classes) {
          std::vector<const char *> input_names = {"input"};
          size_t input_count = 1;
    
          std::vector<const char *> output_names = {"output"};
          size_t output_count = 1;
    
          std::vector<int64_t> input_shape({batch_size, image_channels, image_rows, image_cols});
    
          Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
          std::vector<Ort::Value> input_values; // {input images}
          input_values.emplace_back(Ort::Value::CreateTensor(memory_info, image_data,
                                                             batch_size * image_channels * image_rows * image_cols * sizeof(float),
                                                             input_shape.data(), input_shape.size(),
                                                             ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT));
    
    
          std::vector<Ort::Value> output_values;
          output_values.emplace_back(nullptr);
    
          // get the logits
          session_cache->inference_session->Run(Ort::RunOptions(), input_names.data(), input_values.data(),
                                                 input_count, output_names.data(), output_values.data(), output_count);
    
          float *output = output_values.front().GetTensorMutableData<float>();
    
          // run softmax and get the probabilities of each class
          std::vector<float> probabilities = Softmax(output, classes.size());
          size_t best_index = std::distance(probabilities.begin(), std::max_element(probabilities.begin(), probabilities.end()));
    
          return {classes[best_index], probabilities[best_index]};
       }
    
    } // namespace inference
    

    classify 函数调用另一个名为 Softmax 的函数。Softmax 的定义是

    std::vector<float> Softmax(float *logits, size_t num_logits) {
       std::vector<float> probabilities(num_logits, 0);
       float sum = 0;
       for (size_t i = 0; i < num_logits; ++i) {
             probabilities[i] = exp(logits[i]);
             sum += probabilities[i];
       }
    
       if (sum != 0.0f) {
             for (size_t i = 0; i < num_logits; ++i) {
                probabilities[i] /= sum;
             }
       }
    
       return probabilities;
    }
    
  5. 图像预处理

    a. MobileNetV2 模型要求输入的图像是

    • 大小为 3 x 224 x 224
    • 一个标准化图像,减去平均值 (0.485, 0.456, 0.406) 并除以标准差 (0.229, 0.224, 0.225)

    此预处理在 Java/Kotlin 中使用 Android 提供的库完成。

    让我们在 app/src/main/java/com/example/ortpersonalize 目录下创建一个名为 ImageProcessingUtil.kt 的新文件。我们将在该文件中添加用于裁剪、调整大小和规范化图像的实用方法。

    b. 裁剪和调整图像大小。

    fun processBitmap(bitmap: Bitmap) : Bitmap {
       // This function processes the given bitmap by
       //   - cropping along the longer dimension to get a square bitmap
       //     If the width is larger than the height
       //     ___+_________________+___
       //     |  +                 +  |
       //     |  +                 +  |
       //     |  +        +        +  |
       //     |  +                 +  |
       //     |__+_________________+__|
       //     <-------- width -------->
       //        <----- height ---->
       //     <-->      cropped    <-->
       //
       //     If the height is larger than the width
       //     _________________________   ʌ            ʌ
       //     |                       |   |         cropped
       //     |+++++++++++++++++++++++|   |      ʌ     v
       //     |                       |   |      |
       //     |                       |   |      |
       //     |           +           | height width
       //     |                       |   |      |
       //     |                       |   |      |
       //     |+++++++++++++++++++++++|   |      v     ʌ
       //     |                       |   |         cropped
       //     |_______________________|   v            v
       //
       //
       //
       //   - resizing the cropped square image to be of size (3 x 224 x 224) as needed by the
       //     mobilenetv2 model.
       lateinit var bitmapCropped: Bitmap
       if (bitmap.getWidth() >= bitmap.getHeight()) {
          // Since height is smaller than the width, we crop a square whose length is the height
          // So cropping happens along the width dimesion
          val width: Int = bitmap.getHeight()
          val height: Int = bitmap.getHeight()
    
          // left side of the cropped image must begin at (bitmap.getWidth() / 2 - bitmap.getHeight() / 2)
          // so that the cropped width contains equal portion of the width on either side of center
          // top side of the cropped image must begin at 0 since we are not cropping along the height
          // dimension
          val x: Int = bitmap.getWidth() / 2 - bitmap.getHeight() / 2
          val y: Int = 0
          bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height)
       } else {
          // Since width is smaller than the height, we crop a square whose length is the width
          // So cropping happens along the height dimesion
          val width: Int = bitmap.getWidth()
          val height: Int = bitmap.getWidth()
    
          // left side of the cropped image must begin at 0 since we are not cropping along the width
          // dimension
          // top side of the cropped image must begin at (bitmap.getHeight() / 2 - bitmap.getWidth() / 2)
          // so that the cropped height contains equal portion of the height on either side of center
          val x: Int = 0
          val y: Int = bitmap.getHeight() / 2 - bitmap.getWidth() / 2
          bitmapCropped = Bitmap.createBitmap(bitmap, x, y, width, height)
       }
    
       // Resize the image to be channels x width x height as needed by the mobilenetv2 model
       val width: Int = 224
       val height: Int = 224
       val bitmapResized: Bitmap = Bitmap.createScaledBitmap(bitmapCropped, width, height, false)
    
       return bitmapResized
    }
    

    c. 规范化图像。

    fun processImage(bitmap: Bitmap, buffer: FloatBuffer, offset: Int) {
       // This function iterates over the image and performs the following
       // on the image pixels
       //   - normalizes the pixel values to be between 0 and 1
       //   - substracts the mean (0.485, 0.456, 0.406) (derived from the mobilenetv2 model configuration)
       //     from the pixel values
       //   - divides by pixel values by the standard deviation (0.229, 0.224, 0.225) (derived from the
       //     mobilenetv2 model configuration)
       // Values are written to the given buffer starting at the provided offset.
       // Values are written as follows
       // |____|____________________|__________________| <--- buffer
       //      ʌ                                         <--- offset
       //                           ʌ                    <--- offset + width * height * channels
       // |____|rrrrrr|_____________|__________________| <--- red channel read in column major order
       // |____|______|gggggg|______|__________________| <--- green channel read in column major order
       // |____|______|______|bbbbbb|__________________| <--- blue channel read in column major order
    
       val width: Int = bitmap.getWidth()
       val height: Int = bitmap.getHeight()
       val stride: Int = width * height
    
       for (x in 0 until width) {
          for (y in 0 until height) {
                val color: Int = bitmap.getPixel(y, x)
                val index = offset + (x * height + y)
    
                // Subtract the mean and divide by the standard deviation
                // Values for mean and standard deviation used for
                // the movilenetv2 model.
                buffer.put(index + stride * 0, ((Color.red(color).toFloat() / 255f) - 0.485f) / 0.229f)
                buffer.put(index + stride * 1, ((Color.green(color).toFloat() / 255f) - 0.456f) / 0.224f)
                buffer.put(index + stride * 2, ((Color.blue(color).toFloat() / 255f) - 0.406f) / 0.225f)
          }
       }
    }
    

    d. 从 Uri 获取 Bitmap

    fun bitmapFromUri(uri: Uri, contentResolver: ContentResolver): Bitmap {
       // This function reads the image file at the given uri and decodes it to a bitmap
       val source: ImageDecoder.Source = ImageDecoder.createSource(contentResolver, uri)
       return ImageDecoder.decodeBitmap(source).copy(Bitmap.Config.ARGB_8888, true)
    }
    
  6. 应用程序前端

    a. 对于本教程,我们将使用以下用户界面元素

    • 训练和推理按钮
    • 类别按钮
    • 状态消息文本
    • 图像显示
    • 进度对话框

    b. 本教程不打算展示如何创建图形用户界面。因此,我们只需重复使用 GitHub 上提供的文件。

    c. 将 strings.xml 中的所有字符串定义复制到您的 Android Studio 本地的 strings.xml 文件中。

    d. 将 activity_main.xml 的内容复制到您的 Android Studio 本地的 activity_main.xml 文件中。

    e. 在 layout 文件夹下创建一个名为 dialog.xml 的新文件。将 dialog.xml 的内容复制到您 Android Studio 本地新建的 dialog.xml 中。

    f. 本节中剩余的更改需要在 MainActivity.kt 文件中进行。

    g. 启动应用程序

    当应用程序启动时,将调用 onCreate 函数。此函数负责设置会话缓存和用户界面处理程序。

    有关代码,请参阅 MainActivity.kt 文件中的 onCreate 函数。

    h. 自定义类别按钮处理程序 - 我们希望使用类别按钮让用户选择他们的自定义图像进行训练。我们需要为这些按钮添加监听器以实现此目的。这些监听器将精确地完成这项工作。

    请参考 MainActivity.kt 中的这些按钮处理程序

    • onClassAClickedListener
    • onClassBClickedListener
    • onClassXClickedListener
    • onClassYClickedListener

    i. 个性化自定义类别标签

    默认情况下,自定义类别标签是 [A, B, X, Y]。但是,让我们允许用户为清晰起见重命名这些标签。这通过长按监听器实现(在 MainActivity.kt 中定义)

    • onClassALongClickedListener
    • onClassBLongClickedListener
    • onClassXLongClickedListener
    • onClassYLongClickedListener

    j. 切换自定义类别。

    当自定义类别开关关闭时,运行预打包的动物数据集。当开关打开时,用户需要提供自己的数据集进行训练。为了处理这种转换,onCustomClassSettingChangedListener 开关处理程序在 MainActivity.kt 中实现。

    k. 训练处理程序

    当每个类别至少有 1 张图像时,可以启用 Train 按钮。单击 Train 按钮时,将针对选定的图像启动训练。训练处理程序负责

    • 将训练图像收集到一个容器中。
    • 打乱图像的顺序。
    • 裁剪和调整图像大小。
    • 归一化图像。
    • 对图像进行批处理。
    • 执行训练循环(循环调用 C++ performTraining 函数)。

    MainActivity.kt 中定义的 onTrainButtonClickedListener 函数完成上述操作。

    l. 推理处理程序

    训练完成后,用户可以点击 Infer 按钮来推理任何图像。推理处理程序负责

    • 收集推理图像。
    • 裁剪和调整图像大小。
    • 归一化图像。
    • 调用 C++ performInference 函数。
    • 向用户界面报告推理输出。

    这通过 MainActivity.kt 中的 onInferenceButtonClickedListener 函数实现。

    m. 上述所有活动的处理程序

    一旦为推理或自定义类别选择了图像,就需要对其进行处理。MainActivity.kt 中定义的 onActivityResult 函数执行此操作。

    n. 最后一件事。在 AndroidManifest.xml 文件中添加以下内容以使用摄像头

    <uses-permission android:name="android.permission.CAMERA" />
    <uses-feature android:name="android.hardware.camera" />
    

训练阶段 - 在设备上运行应用程序

  1. 在设备上运行应用程序

    a. 让我们将 Android 设备连接到机器并在设备上运行应用程序。

    b. 在设备上启动应用程序应该看起来像这样

    Barebones ORT Personalize app

  2. 使用预加载数据集训练 - 动物

    a. 让我们通过在设备上启动应用程序,开始使用预加载的动物数据集进行训练。

    b. 切换底部的 Custom classes 开关。

    c. 类别标签将更改为 DogCatElephantCow

    d. 运行 Training 并等待进度对话框消失(训练完成后)。

    e. 现在使用您图库中的任何动物图像进行推理。

    ORT Personalize app with an image of a cow

    从上图可以看出,模型正确预测为 Cow

  3. 使用自定义数据集训练 - 名人

    a. 从网上下载汤姆·克鲁斯 (Tom Cruise)、莱昂纳多·迪卡普里奥 (Leonardo DiCaprio)、瑞安·雷诺兹 (Ryan Reynolds) 和布拉德·皮特 (Brad Pitt) 的图像。

    b. 确保通过关闭并重新启动应用程序来启动一个新的应用程序会话。

    c. 应用程序启动后,通过长按将四个类别分别重命名为 TomLeoRyanBrad

    d. 点击每个类别的按钮,并选择与该名人相关的图像。每个类别我们可以使用大约 10~15 张图像。

    e. 点击 Train 按钮,让应用程序从提供的数据中学习。

    f. 训练完成后,我们可以点击 Infer 按钮并提供应用程序尚未见过的图像。

    g. 就这样!希望应用程序正确分类了图像。

    an image classification app with Tom Cruise in the middle.

结论

恭喜!您已成功构建了一个 Android 应用程序,该应用程序使用 ONNX Runtime 在设备上学习图像分类。该应用程序也已在 GitHub 上提供,地址为 onnxruntime-training-examples