最近ai比较火,有些业务会涉及到调用模型来得到结果,这就会涉及到模型的部署,最常见的是python环境,使用TensorFlow的docker镜像来得到一个模型部署服务。 我们这里讨论的是怎么用java来实现一个最简化的模型部署服务,甚至是用java自建一个多元化的模型部署服务集群(have a joke)。
PMML
这里讨论常见的PMML模型文件和ONNX模型文件,训练完模型后,会使用特定的工具将模型导出为PMML格式或者ONNX格式。 PMML(Predictive Model Markup Language)是一种XML-based的语言,用于表示和共享数据挖掘和统计模型。 PMML文件包含了模型的所有信息,包括模型参数、输入输出变量、数据处理转换等,因此可以在不同的软件和平台上使用。 PMML文件具有较好的可读性和可解释性,也支持多种类型的模型,包括分类、回归、聚类、时间序列等。
ONNX
ONNX(Open Neural Network Exchange)是一种用于表示深度学习模型的开放格式,旨在使不同的深度学习框架能够共享和交换模型。ONNX文件包含了模型的结构、参数和元数据等信息,可以用于在不同的软件和硬件平台上进行推理和训练。 ONNX主要关注于深度学习模型,支持多种主流的深度学习框架,如TensorFlow、PyTorch等。
综上所述,PMML和ONNX文件都是用于表示机器学习模型的文件格式,但它们在具体的应用场景和支持的模型类型上有所不同。也就是说算法同学进行模型训练后只需要要求他们把模型导出为pmml或者onnx来给到java的工程侧,由工程侧来提供模型部署。
java加载调用PMML模型文件:
依赖maven包
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.5.16</version>
</dependency>
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.LoadingModelEvaluatorBuilder;
@Test
public void test2PMML() throws Exception {
File file = readPmmlFile();
Evaluator evaluator;
evaluator = new LoadingModelEvaluatorBuilder()
.load(file)
.build();
evaluator.verify();
List<InputField> inputFields = evaluator.getInputFields();
String json1 = JsonUtils.toString(inputFields);
System.out.println("inputFields: " + json1);
Map<String, Object> featureMap = new HashMap<>();
featureMap.put("Sepal.Length", 5.1);
featureMap.put("Sepal.Width", 3.5);
featureMap.put("Petal.Length", 1.4);
featureMap.put("Petal.Width", 0.2);
featureMap.put("Species", 0);
featureMap.put("level", 1);
featureMap.put("level1", 1);
HashMap<FieldName, FieldValue> predictFeature = new HashMap<>();
for (InputField inputField : inputFields) {
String fieldName = inputField.getName().toString();
System.out.println("inputField: " + fieldName);
Object value = featureMap.get(fieldName);
FieldValue inputValue = inputField.prepare(value);
predictFeature.put(inputField.getName(), inputValue);
}
Map<FieldName, ?> results = evaluator.evaluate(predictFeature);
Map<String, ?> resultRecord = EvaluatorUtil.decodeAll(results);
// json工具,可以用gson替换
String json2 = JsonUtils.toString(resultRecord);
System.out.println("resultRecord: " + json2);
}
private File readPmmlFile() throws IOException {
File file;
// pmmlFilePath 文件路径按照自己的来,pmml文件 邮件给我可以提供,cdn费用太昂贵,放这里很容易被打爆。
if (StringUtils.isBlank(pmmlFilePath)) {
file = File.createTempFile("test2", ".pmml");
try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream("test2.pmml")) {
FileUtils.copyInputStreamToFile(inputStream, file);
}
pmmlFilePath = file.getAbsolutePath();
} else {
file = new File(pmmlFilePath);
if (!file.exists()) {
file = File.createTempFile("test2", ".pmml");
try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream("test2.pmml")) {
FileUtils.copyInputStreamToFile(inputStream, file);
}
pmmlFilePath = file.getAbsolutePath();
}
}
return file;
}
java加载调用ONNX模型文件
依赖maven包
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.15.1</version>
</dependency>
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
@Test
public void test3TorchRun() throws Exception {
String modelPath = readOnnxFile().getAbsolutePath();
OrtEnvironment env = OrtEnvironment.getEnvironment();
try (OrtSession.SessionOptions options = new OrtSession.SessionOptions(); OrtSession session = env.createSession(modelPath, options)) {
Map<String, OnnxTensor> inputMap = new HashMap<>();
OnnxTensor a = OnnxTensor.createTensor(env, new float[]{2.0f});
OnnxTensor b = OnnxTensor.createTensor(env, new float[]{3.0f});
OnnxTensor c = OnnxTensor.createTensor(env, new float[]{5.0f});
inputMap.put("a:0", a);
inputMap.put("b:0", b);
inputMap.put("c:0", c);
OrtSession.Result result = session.run(inputMap);
String json = JsonUtils.toString(result);
System.out.println(json);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private File readOnnxFile() throws IOException {
File file;
// onnxFilePath 文件路径按照自己的来,onnx文件 邮件给我可以提供,cdn费用太昂贵,放这里很容易被打爆。
if (StringUtils.isBlank(onnxFilePath)) {
file = File.createTempFile("test3", ".onnx");
try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream("test3.onnx")) {
FileUtils.copyInputStreamToFile(inputStream, file);
}
onnxFilePath = file.getAbsolutePath();
} else {
file = new File(onnxFilePath);
if (!file.exists()) {
file = File.createTempFile("test3", ".onnx");
try (InputStream inputStream = getClass().getClassLoader().getResourceAsStream("test3.onnx")) {
FileUtils.copyInputStreamToFile(inputStream, file);
}
onnxFilePath = file.getAbsolutePath();
}
}
return file;
}
读取文件的工具类和字符处理工具类可以自己实现,我这里用的是apache的commons-lang3和commons-io包来做的,同理json工具也一样。
以上PMML模型文件和ONNX模型文件的测试文件例子可以邮件call我。
PMML和ONNX的模型文件加载已经可以覆盖大多数模型部署调用场景,对于有一些特殊的模型可以考虑自己实现,由java通过JNI(Java Native Interface)去调用,效率比想象中的要快很多,并不会很慢,正常配置下,单机几十个的qps并发没有问题。
而如果要做到多元化的模型部署服务集群,麻烦的地方在于服务发现(可能初期会通过手动的配置),目前有一个初步的方案可供参考: