在 Java 中使用 MNIST 进行字符识别

这是一个简单的教程,介绍如何开始使用 ONNX 模型对给定输入数据运行推理。模型通常使用任何知名的训练框架进行训练并导出为 ONNX 格式。

请注意,下面提供的代码使用 Java 10 及更高版本可用的语法。Java 8 语法类似但更冗长。要开始评分会话,首先创建 OrtEnvironment,然后使用 OrtSession 类打开一个会话,将模型的文件路径作为参数传入。

    var env = OrtEnvironment.getEnvironment();
    var session = env.createSession("model.onnx",new OrtSession.SessionOptions());

创建会话后,您可以使用 OrtSession 对象的 run 方法执行查询。目前我们支持 OnnxTensor 输入,模型可以生成 OnnxTensorOnnxSequenceOnnxMap 输出。后两种情况在对 Scikit-learn 等框架生成的模型进行评分时更常见。run 调用需要一个 Map<String,OnnxTensor>,其中键与模型中存储的输入节点名称匹配。这些可以通过在实例化会话上调用 session.getInputNames()session.getInputInfo() 来查看。run 调用会生成一个 Result 对象,其中包含一个表示输出的 Map<String,OnnxValue>Result 对象是 AutoCloseable 的,可以在 try-with-resources 语句中使用,以防止引用泄漏。一旦 Result 对象关闭,它的所有子 OnnxValue 也会关闭。

    OnnxTensor t1,t2;
    var inputs = Map.of("name1",t1,"name2",t2);
    try (var results = session.run(inputs)) {
    // manipulate the results
   }

您可以通过多种方式将输入数据加载到 OnnxTensor 对象中。最有效的方法是使用 java.nio.Buffer,但也可以使用多维数组。如果使用数组构造,数组不得是不规则的。

    FloatBuffer sourceData;  // assume your data is loaded into a FloatBuffer
    long[] dimensions;       // and the dimensions of the input are stored here
    var tensorFromBuffer = OnnxTensor.createTensor(env,sourceData,dimensions);

    float[][] sourceArray = new float[28][28];  // assume your data is loaded into a float array 
    var tensorFromArray = OnnxTensor.createTensor(env,sourceArray);

这是一个完整的示例程序,它在预训练的 MNIST 模型上运行推理。

在 GPU 或使用其他提供者上运行(可选)

要启用 GPU 等其他执行提供者,只需在创建 OrtSession 时打开 SessionOptions 上的相应标志即可。

    int gpuDeviceId = 0; // The GPU device ID to execute on
    var sessionOptions = new OrtSession.SessionOptions();
    sessionOptions.addCUDA(gpuDeviceId);
    var session = environment.createSession("model.onnx", sessionOptions);

执行提供者将按其启用的顺序优先选择。