ORT for Java 入门
ONNX Runtime 提供了一个 Java 绑定,用于在 JVM 上对 ONNX 模型运行推理。
目录
支持的版本
Java 8 或更高版本
构建
发布工件发布到 Maven Central,以便在大多数 Java 构建工具中用作依赖项。 这些工件的构建支持一些流行的平台。
工件 | 描述 | 支持的平台 |
---|---|---|
com.microsoft.onnxruntime:onnxruntime | CPU | Windows x64、Linux x64、macOS x64 |
com.microsoft.onnxruntime:onnxruntime_gpu | GPU (CUDA) | Windows x64、Linux x64 |
对于本地构建,请参阅 Java API 开发文档 以了解更多详细信息。
有关自定义共享库加载机制的信息,请参阅高级加载说明。
API 参考
Javadoc 可在此处获取:here。
示例
示例实现位于 src/test/java/sample/ScoreMNIST.java 中。
编译后,示例代码需要以下参数 ScoreMNIST [path-to-mnist-model] [path-to-mnist] [scikit-learn-flag]
。 MNIST 预计为 libsvm 格式。 如果提供了可选的 scikit-learn 标志,则模型预计由 skl2onnx 生成(因此需要平面特征向量,并生成结构化输出),否则模型预计是由 pytorch 生成的 CNN(需要 [1][1][28][28]
输入,并生成概率向量)。 testdata 中提供了两个示例模型,cnn_mnist_pytorch.onnx
和 lr_mnist_scikit.onnx
。 第一个是使用 PyTorch 训练的 LeNet5 风格 CNN,第二个是使用 scikit-learn 训练的逻辑回归。
单元测试包含加载模型、检查输入/输出节点形状和类型以及构建张量以进行评分的多个示例。
开始使用
这是一个简单的教程,用于开始在给定输入数据的现有 ONNX 模型上运行推理。 该模型通常使用任何知名的训练框架进行训练,并导出到 ONNX 格式。
请注意,下面提供的代码使用了 Java 10 及更高版本中可用的语法。 Java 8 语法类似,但更冗长。
要启动评分会话,首先创建 OrtEnvironment
,然后使用 OrtSession
类打开会话,并将模型的文件路径作为参数传入。
var env = OrtEnvironment.getEnvironment();
var session = env.createSession("model.onnx",new OrtSession.SessionOptions());
创建会话后,您可以使用 OrtSession
对象的 run
方法执行查询。 目前,我们支持 OnnxTensor
输入,模型可以生成 OnnxTensor
、OnnxSequence
或 OnnxMap
输出。 后两者在对 scikit-learn 等框架生成的模型进行评分时更常见。
run 调用需要一个 Map<String,OnnxTensor>
,其中键与模型中存储的输入节点名称匹配。 可以通过在实例化的会话上调用 session.getInputNames()
或 session.getInputInfo()
来查看这些名称。 run 调用生成一个 Result
对象,其中包含一个表示输出的 Map<String,OnnxValue>
。Result
对象是 AutoCloseable
,可以在 try-with-resources 语句中使用,以防止引用泄漏。 一旦 Result
对象关闭,它的所有子 OnnxValue
也都会关闭。
OnnxTensor t1,t2;
var inputs = Map.of("name1",t1,"name2",t2);
try (var results = session.run(inputs)) {
// manipulate the results
}
您可以通过多种方式将输入数据加载到 OnnxTensor 对象中。 最有效的方法是使用 java.nio.Buffer
,但也可以使用多维数组。 如果使用数组构造,则数组不得是不规则的。
FloatBuffer sourceData; // assume your data is loaded into a FloatBuffer
long[] dimensions; // and the dimensions of the input are stored here
var tensorFromBuffer = OnnxTensor.createTensor(env,sourceData,dimensions);
float[][] sourceArray = new float[28][28]; // assume your data is loaded into a float array
var tensorFromArray = OnnxTensor.createTensor(env,sourceArray);
这是一个完整的示例程序,它在预训练的 MNIST 模型上运行推理。
在 GPU 上或使用其他提供程序运行 (可选)
要启用其他执行提供程序(如 GPU),只需在创建 OrtSession 时打开 SessionOptions 上的相应标志即可。
int gpuDeviceId = 0; // The GPU device ID to execute on
var sessionOptions = new OrtSession.SessionOptions();
sessionOptions.addCUDA(gpuDeviceId);
var session = environment.createSession("model.onnx", sessionOptions);
执行提供程序按照它们启用的顺序进行优先级排序。