使用ONNX模型(java)
一、目标
本次探索的目标是探索一种将ONNX模型集成到Java中的方法,以便后期可以在联合仿真环境中加载和执行ONNX模型。
二、研究
什么是ONNX
在进行技术探索之前,我们需要了解ONNX的相关知识。
ONNX(Open Neural Network Exchange)是一种用于表示机器学习模型的开放式格式,可以将模型从一个框架转移到另一个框架。ONNX模型可以使用不同的工具和库进行加载和执行,例如TensorFlow、PyTorch、Caffe2等。在机器学习和人工智能领域,ONNX已成为一个流行的标准格式。由于其开放式和跨平台的特性,ONNX模型可以在不同的环境和设备上使用,例如移动设备、嵌入式系统、云计算平台等。
如下的图来自官方,可以看到有提供了Java的API:
由于ONNX Runtime是跨平台的高性能推理引擎,可以使用ONNX Runtime Java库可以方便地加载和执行ONNX模型。
下面是一个简单的代码示例,展示如何在Java系统中使用ONNX Runtime Java库加载和执行ONNX模型:
import ai.onnxruntime.*; // Load the model and create InferenceSession String modelPath = "path/to/your/onnx/model"; OrtEnvironment env = OrtEnvironment.getEnvironment(); OrtSession session = env.createSession(modelPath); // Load and preprocess the input image inputTensor ... // Run inference OrtSession.Result outputs = session.run(inputTensor); System.out.println(outputs.get(0).getTensor().getFloatBuffer().get(0));
实现
实现步骤如下:
- 配置ONNX Runtime Java库
- 将ONNX模型加载到系统中
- 设置输入,并验证输出
配置POM
com.microsoft.onnxruntime onnxruntime 1.15.1
相关代码
public class LoadACloopOnnx { private static final String DEFAULT_MODEL= "onnx/simple/simple_model.onnx"; public static void main(String[] args) { try (OrtEnvironment env = OrtEnvironment.getEnvironment(); OrtSession session = env.createSession(getResource(DEFAULT_MODEL).toString(),new OrtSession.SessionOptions())){ for (String name: session.getInputNames()) { System.out.println("输入: " + session.getInputInfo().get(name)); } for (String name: session.getOutputNames()) { System.out.println("输出: " + session.getOutputInfo().get(name)); } Optional.ofNullable(session.getInputInfo().keySet()) .orElse(Collections.emptySet()) .stream() .findFirst() .ifPresent(key->{ try { NodeInfo nodeInfo = session.getInputInfo().get(key); if (nodeInfo.getInfo() instanceof TensorInfo) { Map stringOnnxTensorMap = new HashMap(); stringOnnxTensorMap.put("input1",OnnxTensor.createTensor(env,new float[]{1})); stringOnnxTensorMap.put("input2",OnnxTensor.createTensor(env,new float[]{2})); try (OrtSession.Result result = session.run(stringOnnxTensorMap)){ for (Map.Entry entry : result) { System.out.println(String.format("结果项[%s]",entry.getKey())); System.out.println("信息:"+entry.getValue().getInfo()); System.out.println("类型:"+entry.getValue().getType()); printMultiArrayHelper(entry.getValue().getValue(),""); } } } } catch (OrtException e) { e.printStackTrace(); } }); } catch (OrtException e) { e.printStackTrace(); } } private static void printMultiArrayHelper(Object array, String indent) { if (array == null) { System.out.println("null"); return; } Class componentType = array.getClass().getComponentType(); if (!componentType.isArray()) { System.out.print(indent); System.out.print("[ "); for (int i = 0; i 0) { System.out.print(", "); } System.out.print(Array.get(array, i)); } System.out.println(" ]"); } else { System.out.println(indent + "["); for (int i = 0; i < Array.getLength(array); i++) { Object subArray = Array.get(array, i); printMultiArrayHelper(subArray, indent + " "); } System.out.println(indent + "]"); } } private static Path getResource(String name) { return Paths.get("src/main/resources").toAbsolutePath().resolve(name); } }
输出结果
输入: NodeInfo(name=input1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1])) 输入: NodeInfo(name=input2,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1])) 输出: NodeInfo(name=output,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1])) 结果项[output] 信息:TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1]) 类型:ONNX_TYPE_TENSOR [ 3.0 ]
三、结论
使用ONNX Runtime Java库可以方便地将ONNX模型集成到Java环境中,并与其他子系统进行交互,使用ONNX模型可以方便地在不同的环境和设备上共享和使用。Java系统可以尝试使用该库来加载和执行ONNX模型,并进行集成。
最后,异构模型的支持是大势所趋,联合仿真系统应该积极探索和尝试新的技术和方法,以不断提升系统的性能和功能。ONNX作为一种先进的机器学习模型表示格式,将为联合仿真系统的发展带来新的机遇和挑战,可拓展联合仿真系统的适用范围。
四、参考文档
Get Started with ORT for Java