使用java部署模型训练文件

2023-10-24 23:20

最近ai比较火,有些业务会涉及到调用模型来得到结果,这就会涉及到模型的部署,最常见的是python环境,使用TensorFlow的docker镜像来得到一个模型部署服务。 我们这里讨论的是怎么用java来实现一个最简化的模型部署服务,甚至是用java自建一个多元化的模型部署服务集群(have a joke)。

PMML

这里讨论常见的PMML模型文件和ONNX模型文件,训练完模型后,会使用特定的工具将模型导出为PMML格式或者ONNX格式。 PMML(Predictive Model Markup Language)是一种XML-based的语言,用于表示和共享数据挖掘和统计模型。 PMML文件包含了模型的所有信息,包括模型参数、输入输出变量、数据处理转换等,因此可以在不同的软件和平台上使用。 PMML文件具有较好的可读性和可解释性,也支持多种类型的模型,包括分类、回归、聚类、时间序列等。

ONNX

ONNX(Open Neural Network Exchange)是一种用于表示深度学习模型的开放格式,旨在使不同的深度学习框架能够共享和交换模型。ONNX文件包含了模型的结构、参数和元数据等信息,可以用于在不同的软件和硬件平台上进行推理和训练。 ONNX主要关注于深度学习模型,支持多种主流的深度学习框架,如TensorFlow、PyTorch等。

综上所述,PMML和ONNX文件都是用于表示机器学习模型的文件格式,但它们在具体的应用场景和支持的模型类型上有所不同。也就是说算法同学进行模型训练后只需要要求他们把模型导出为pmml或者onnx来给到java的工程侧,由工程侧来提供模型部署。

java加载调用PMML模型文件:

依赖maven包
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.5.16</version>
</dependency>
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.LoadingModelEvaluatorBuilder;

    @Test
    public void test2PMML() throws Exception {
        File file = readPmmlFile();

        Evaluator evaluator;
        evaluator = new LoadingModelEvaluatorBuilder()
                .load(file)
                .build();
        evaluator.verify();

        List<InputField> inputFields = evaluator.getInputFields();
        String json1 = JsonUtils.toString(inputFields);
        System.out.println("inputFields: " + json1);

        Map<String, Object> featureMap = new HashMap<>();
        featureMap.put("Sepal.Length", 5.1);
        featureMap.put("Sepal.Width", 3.5);
        featureMap.put("Petal.Length", 1.4);
        featureMap.put("Petal.Width", 0.2);
        featureMap.put("Species", 0);
        featureMap.put("level", 1);
        featureMap.put("level1", 1);
        HashMap<FieldName, FieldValue> predictFeature = new HashMap<>();
        for (InputField inputField : inputFields) {
            String fieldName = inputField.getName().toString();
            System.out.println("inputField: " + fieldName);
            Object value = featureMap.get(fieldName);
            FieldValue inputValue = inputField.prepare(value);
            predictFeature.put(inputField.getName(), inputValue);
        }
        Map<FieldName, ?> results = evaluator.evaluate(predictFeature);
        Map<String, ?> resultRecord = EvaluatorUtil.decodeAll(results);
        // json工具,可以用gson替换
        String json2 = JsonUtils.toString(resultRecord);
        System.out.println("resultRecord: " + json2);
    }

    private File readPmmlFile() throws IOException {
        File file;
        // pmmlFilePath 文件路径按照自己的来,pmml文件 邮件给我可以提供,cdn费用太昂贵,放这里很容易被打爆。
        if (StringUtils.isBlank(pmmlFilePath)) {
            file = File.createTempFile("test2", ".pmml");
            try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream("test2.pmml")) {
                FileUtils.copyInputStreamToFile(inputStream, file);
            }
            pmmlFilePath = file.getAbsolutePath();
        } else {
            file = new File(pmmlFilePath);
            if (!file.exists()) {
                file = File.createTempFile("test2", ".pmml");
                try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream("test2.pmml")) {
                    FileUtils.copyInputStreamToFile(inputStream, file);
                }
                pmmlFilePath = file.getAbsolutePath();
            }
        }
        return file;
    }

java加载调用ONNX模型文件

依赖maven包
<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.15.1</version>
</dependency>
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;

    @Test
    public void test3TorchRun() throws Exception {
        String modelPath = readOnnxFile().getAbsolutePath();
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession.SessionOptions options = new OrtSession.SessionOptions(); OrtSession session = env.createSession(modelPath, options)) {
            Map<String, OnnxTensor> inputMap = new HashMap<>();
            OnnxTensor a = OnnxTensor.createTensor(env, new float[]{2.0f});
            OnnxTensor b = OnnxTensor.createTensor(env, new float[]{3.0f});
            OnnxTensor c = OnnxTensor.createTensor(env, new float[]{5.0f});
            inputMap.put("a:0", a);
            inputMap.put("b:0", b);
            inputMap.put("c:0", c);

            OrtSession.Result result = session.run(inputMap);
            String json = JsonUtils.toString(result);
            System.out.println(json);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }


    private File readOnnxFile() throws IOException {
        File file;
        // onnxFilePath 文件路径按照自己的来,onnx文件 邮件给我可以提供,cdn费用太昂贵,放这里很容易被打爆。
        if (StringUtils.isBlank(onnxFilePath)) {
            file = File.createTempFile("test3", ".onnx");
            try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream("test3.onnx")) {
                FileUtils.copyInputStreamToFile(inputStream, file);
            }
            onnxFilePath = file.getAbsolutePath();
        } else {
            file = new File(onnxFilePath);
            if (!file.exists()) {
                file = File.createTempFile("test3", ".onnx");
                try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream("test3.onnx")) {
                    FileUtils.copyInputStreamToFile(inputStream, file);
                }
                onnxFilePath = file.getAbsolutePath();
            }
        }
        return file;
    }

读取文件的工具类和字符处理工具类可以自己实现,我这里用的是apache的commons-lang3和commons-io包来做的,同理json工具也一样。

以上PMML模型文件和ONNX模型文件的测试文件例子可以邮件call我。

PMML和ONNX的模型文件加载已经可以覆盖大多数模型部署调用场景,对于有一些特殊的模型可以考虑自己实现,由java通过JNI(Java Native Interface)去调用,效率比想象中的要快很多,并不会很慢,正常配置下,单机几十个的qps并发没有问题。

而如果要做到多元化的模型部署服务集群,麻烦的地方在于服务发现(可能初期会通过手动的配置),目前有一个初步的方案可供参考:

类型:工作 标签:PMML,ONNX,AI,JAVA,TensorFlow

我与我周旋久 独孤影 开源实验室