Java ORT 入门

ONNX Runtime 提供 Java 绑定,用于在 JVM 上对 ONNX 模型进行推理。

目录

支持的版本

Java 8 或更高版本

构建版本

发布工件已发布到 Maven Central,可在大多数 Java 构建工具中用作依赖项。这些工件构建时支持一些流行平台。

Version Shield

工件 描述 支持的平台
com.microsoft.onnxruntime:onnxruntime CPU Windows x64, Linux x64, macOS x64
com.microsoft.onnxruntime:onnxruntime_gpu GPU (CUDA) Windows x64, Linux x64

有关本地构建的更多详细信息,请参阅 Java API 开发文档

有关共享库加载机制的定制,请参阅高级加载说明

API 参考

Javadoc 可在此处获取。

示例

示例实现在 src/test/java/sample/ScoreMNIST.java 中。

编译后,示例代码需要以下参数:ScoreMNIST [mnist模型路径] [mnist路径] [scikit-learn标志]。MNIST 预计采用 libsvm 格式。如果提供了可选的 scikit-learn 标志,模型预计由 skl2onnx 生成(因此需要一个扁平的特征向量,并产生结构化输出),否则模型预计是来自 PyTorch 的 CNN(需要一个 [1][1][28][28] 输入,并产生一个概率向量)。testdata 中提供了两个示例模型:cnn_mnist_pytorch.onnxlr_mnist_scikit.onnx。第一个是使用 PyTorch 训练的 LeNet5 风格 CNN,第二个是使用 scikit-learn 训练的逻辑回归模型。

单元测试包含多个加载模型、检查输入/输出节点形状和类型以及构造用于评分的张量的示例。

入门

这是一个关于如何使用现有 ONNX 模型对给定输入数据运行推理的简单入门教程。该模型通常使用任何知名的训练框架进行训练,并导出为 ONNX 格式。

请注意,下面提供的代码使用了 Java 10 及更高版本支持的语法。Java 8 的语法类似但更冗长。

要开始评分会话,首先创建 OrtEnvironment,然后使用 OrtSession 类打开一个会话,将模型的文件路径作为参数传入。

    var env = OrtEnvironment.getEnvironment();
    var session = env.createSession("model.onnx",new OrtSession.SessionOptions());

创建会话后,可以使用 OrtSession 对象的 run 方法执行查询。目前我们支持 OnnxTensor 输入,模型可以生成 OnnxTensor, OnnxSequenceOnnxMap 输出。后两者在使用 scikit-learn 等框架生成的模型进行评分时更常见。

run 调用需要一个 Map<String,OnnxTensor>,其中键与模型中存储的输入节点名称匹配。可以通过在实例化会话上调用 session.getInputNames()session.getInputInfo() 来查看这些名称。run 调用会生成一个 Result 对象,其中包含一个代表输出的 Map<String,OnnxValue>Result 对象实现了 AutoCloseable 接口,可以在 try-with-resources 语句中使用,以防止引用泄露。一旦 Result 对象关闭,其所有子 OnnxValue 对象也会关闭。

    OnnxTensor t1,t2;
    var inputs = Map.of("name1",t1,"name2",t2);
    try (var results = session.run(inputs)) {
        // manipulate the results
    }

你可以通过几种方式将输入数据加载到 OnnxTensor 对象中。最有效的方法是使用 java.nio.Buffer,但也可以使用多维数组。如果使用数组构造,数组必须不是不规则的。

    FloatBuffer sourceData;  // assume your data is loaded into a FloatBuffer
    long[] dimensions;       // and the dimensions of the input are stored here
    var tensorFromBuffer = OnnxTensor.createTensor(env,sourceData,dimensions);

    float[][] sourceArray = new float[28][28];  // assume your data is loaded into a float array 
    var tensorFromArray = OnnxTensor.createTensor(env,sourceArray);

这是一个完整的示例程序,它对一个预训练的 MNIST 模型运行推理。

在 GPU 或其他提供程序上运行 (可选)

要启用其他执行提供程序(如 GPU),只需在创建 OrtSession 时在 SessionOptions 上打开相应的标志即可。

    int gpuDeviceId = 0; // The GPU device ID to execute on
    var sessionOptions = new OrtSession.SessionOptions();
    sessionOptions.addCUDA(gpuDeviceId);
    var session = environment.createSession("model.onnx", sessionOptions);

执行提供程序按照启用的顺序进行优先级排序。