使用ONNX模型(java)

2023年 8月 14日 37.0k 0

一、目标

本次探索的目标是探索一种将ONNX模型集成到Java中的方法,以便后期可以在联合仿真环境中加载和执行ONNX模型。

二、研究

什么是ONNX

在进行技术探索之前,我们需要了解ONNX的相关知识。

ONNX(Open Neural Network Exchange)是一种用于表示机器学习模型的开放式格式,可以将模型从一个框架转移到另一个框架。ONNX模型可以使用不同的工具和库进行加载和执行,例如TensorFlow、PyTorch、Caffe2等。在机器学习和人工智能领域,ONNX已成为一个流行的标准格式。由于其开放式和跨平台的特性,ONNX模型可以在不同的环境和设备上使用,例如移动设备、嵌入式系统、云计算平台等。

如下的图来自官方,可以看到有提供了Java的API:
image.png

由于ONNX Runtime是跨平台的高性能推理引擎,可以使用ONNX Runtime Java库可以方便地加载和执行ONNX模型。

下面是一个简单的代码示例,展示如何在Java系统中使用ONNX Runtime Java库加载和执行ONNX模型:

import ai.onnxruntime.*;

// Load the model and create InferenceSession
String modelPath = "path/to/your/onnx/model";
OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession session = env.createSession(modelPath);

// Load and preprocess the input image inputTensor
...

// Run inference
OrtSession.Result outputs = session.run(inputTensor);
System.out.println(outputs.get(0).getTensor().getFloatBuffer().get(0));

实现

实现步骤如下:

  • 配置ONNX Runtime Java库
  • 将ONNX模型加载到系统中
  • 设置输入,并验证输出

配置POM


    com.microsoft.onnxruntime
    onnxruntime
    1.15.1

相关代码

public class LoadACloopOnnx {
    private static final String DEFAULT_MODEL= "onnx/simple/simple_model.onnx";

    public static void main(String[] args) {
        try (OrtEnvironment env = OrtEnvironment.getEnvironment();
             OrtSession session = env.createSession(getResource(DEFAULT_MODEL).toString(),new OrtSession.SessionOptions())){
            for (String name: session.getInputNames()) {
                System.out.println("输入: " + session.getInputInfo().get(name));
            }
            for (String name: session.getOutputNames()) {
                System.out.println("输出: " + session.getOutputInfo().get(name));
            }
        
            Optional.ofNullable(session.getInputInfo().keySet())
                    .orElse(Collections.emptySet())
                    .stream()
                    .findFirst()
                    .ifPresent(key->{
                        try {
                            NodeInfo nodeInfo = session.getInputInfo().get(key);
                            if (nodeInfo.getInfo() instanceof  TensorInfo) {
                                Map stringOnnxTensorMap = new HashMap();
                                stringOnnxTensorMap.put("input1",OnnxTensor.createTensor(env,new float[]{1}));
                                stringOnnxTensorMap.put("input2",OnnxTensor.createTensor(env,new float[]{2}));
                                try (OrtSession.Result result = session.run(stringOnnxTensorMap)){
                                    for (Map.Entry entry : result) {
                                        System.out.println(String.format("结果项[%s]",entry.getKey()));
                                        System.out.println("信息:"+entry.getValue().getInfo());
                                        System.out.println("类型:"+entry.getValue().getType());
                                        printMultiArrayHelper(entry.getValue().getValue(),"");
                                    }
                                }
                            }
                        } catch (OrtException e) {
                            e.printStackTrace();
                        }
                    });
        } catch (OrtException e) {
            e.printStackTrace();
        }
    }



    private static void printMultiArrayHelper(Object array, String indent) {
        if (array == null) {
            System.out.println("null");
            return;
        }

        Class componentType = array.getClass().getComponentType();
        if (!componentType.isArray()) {
            System.out.print(indent);
            System.out.print("[ ");
            for (int i = 0; i  0) {
                    System.out.print(", ");
                }
                System.out.print(Array.get(array, i));
            }
            System.out.println(" ]");
        } else {
            System.out.println(indent + "[");
            for (int i = 0; i < Array.getLength(array); i++) {
                Object subArray = Array.get(array, i);
                printMultiArrayHelper(subArray, indent + "  ");
            }
            System.out.println(indent + "]");
        }
    }


    private static Path getResource(String name) {
        return Paths.get("src/main/resources").toAbsolutePath().resolve(name);
    }
}

输出结果

输入: NodeInfo(name=input1,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1]))
输入: NodeInfo(name=input2,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1]))
输出: NodeInfo(name=output,info=TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1]))
结果项[output]
信息:TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1])
类型:ONNX_TYPE_TENSOR
[ 3.0 ]

三、结论

使用ONNX Runtime Java库可以方便地将ONNX模型集成到Java环境中,并与其他子系统进行交互,使用ONNX模型可以方便地在不同的环境和设备上共享和使用。Java系统可以尝试使用该库来加载和执行ONNX模型,并进行集成。

最后,异构模型的支持是大势所趋,联合仿真系统应该积极探索和尝试新的技术和方法,以不断提升系统的性能和功能。ONNX作为一种先进的机器学习模型表示格式,将为联合仿真系统的发展带来新的机遇和挑战,可拓展联合仿真系统的适用范围。

四、参考文档

Get Started with ORT for Java

相关文章

JavaScript2024新功能:Object.groupBy、正则表达式v标志
PHP trim 函数对多字节字符的使用和限制
新函数 json_validate() 、randomizer 类扩展…20 个PHP 8.3 新特性全面解析
使用HTMX为WordPress增效:如何在不使用复杂框架的情况下增强平台功能
为React 19做准备:WordPress 6.6用户指南
如何删除WordPress中的所有评论

发布评论