使用 ONNX Runtime 在 Java 中进行 MNIST 字符识别
这是一个简单的教程,介绍如何在给定输入数据上使用现有 ONNX 模型运行推理。该模型通常使用任何知名的训练框架进行训练,并导出为 ONNX 格式。
请注意,下面提供的代码使用了 Java 10 及更高版本可用的语法。Java 8 语法类似但更冗长。要开始评分会话,首先创建 OrtEnvironment
,然后使用 OrtSession
类打开一个会话,将模型的文件路径作为参数传入。
var env = OrtEnvironment.getEnvironment();
var session = env.createSession("model.onnx",new OrtSession.SessionOptions());
创建会话后,您可以使用 OrtSession
对象的 run
方法执行查询。目前我们支持 OnnxTensor
输入,模型可以产生 OnnxTensor
、OnnxSequence
或 OnnxMap
输出。后两者在使用 scikit-learn 等框架生成的模型进行评分时更常见。run
调用需要一个 Map<String,OnnxTensor>
,其中键与模型中存储的输入节点名称匹配。可以通过在实例化的会话上调用 session.getInputNames()
或 session.getInputInfo()
来查看这些名称。run
调用会产生一个 Result
对象,其中包含一个表示输出的 Map<String,OnnxValue>
。Result
对象是 AutoCloseable
的,可以在 try-with-resources 语句中使用,以防止引用泄漏。Result
对象关闭后,其所有子 OnnxValue
也会关闭。
OnnxTensor t1,t2;
var inputs = Map.of("name1",t1,"name2",t2);
try (var results = session.run(inputs)) {
// manipulate the results
}
您可以通过几种方式将输入数据加载到 OnnxTensor 对象中。最有效的方法是使用 java.nio.Buffer
,但也可以使用多维数组。如果使用数组构造,则数组必须不是锯齿状的。
FloatBuffer sourceData; // assume your data is loaded into a FloatBuffer
long[] dimensions; // and the dimensions of the input are stored here
var tensorFromBuffer = OnnxTensor.createTensor(env,sourceData,dimensions);
float[][] sourceArray = new float[28][28]; // assume your data is loaded into a float array
var tensorFromArray = OnnxTensor.createTensor(env,sourceArray);
这里有一个完整的示例程序,它在预训练的 MNIST 模型上运行推理。
在 GPU 或其他提供者上运行 (可选)
要启用其他执行提供者(如 GPU),只需在创建 OrtSession 时打开 SessionOptions 上相应的标志即可。
int gpuDeviceId = 0; // The GPU device ID to execute on
var sessionOptions = new OrtSession.SessionOptions();
sessionOptions.addCUDA(gpuDeviceId);
var session = environment.createSession("model.onnx", sessionOptions);
执行提供者会按照启用的顺序优先使用。