当前位置: 首页 > news >正文

Spring Boot集成tensorflow实现图片检测服务

1.什么是tensorflow?

TensorFlow名字的由来就是张量(Tensor)在计算图(Computational Graph)里的流动(Flow),如图。它的基础就是前面介绍的基于计算图的自动微分,除了自动帮你求梯度之外,它也提供了各种常见的操作(op,也就是计算图的节点),常见的损失函数,优化算法。

tensorflow

  • TensorFlow 是一个开放源代码软件库,用于进行高性能数值计算。借助其灵活的架构,用户可以轻松地将计算工作部署到多种平台(CPU、GPU、TPU)和设备(桌面设备、服务器集群、移动设备、边缘设备等)。https://www.tensorflow.org/tutorials/?hl=zh-cnwww.tensorflow.org/tutorials/?hl=zh-cn(opens new window)
  • TensorFlow 是一个用于研究和生产的开放源代码机器学习库。TensorFlow 提供了各种 API,可供初学者和专家在桌面、移动、网络和云端环境下进行开发。
  • TensorFlow是采用数据流图(data flow graphs)来计算,所以首先我们得创建一个数据流流图,然后再将我们的数据(数据以张量(tensor)的形式存在)放在数据流图中计算. 节点(Nodes)在图中表示数学操作,图中的边(edges)则表示在节点间相互联系的多维数据数组, 即张量(tensor)。训练模型时tensor会不断的从数据流图中的一个节点flow到另一节点, 这就是TensorFlow名字的由来。 张量(Tensor):张量有多种. 零阶张量为 纯量或标量 (scalar) 也就是一个数值. 比如 [1],一阶张量为 向量 (vector), 比如 一维的 [1, 2, 3],二阶张量为 矩阵 (matrix), 比如 二维的 [[1, 2, 3],[4, 5, 6],[7, 8, 9]],以此类推, 还有 三阶 三维的 … 张量从流图的一端流动到另一端的计算过程。它生动形象地描述了复杂数据结构在人工神经网中的流动、传输、分析和处理模式。

在机器学习中,数值通常由4种类型构成: (1)标量(scalar):即一个数值,它是计算的最小单元,如“1”或“3.2”等。 (2)向量(vector):由一些标量构成的一维数组,如[1, 3.2, 4.6]等。 (3)矩阵(matrix):是由标量构成的二维数组。 (4)张量(tensor):由多维(通常)数组构成的数据集合,可理解为高维矩阵。

tensorflow的基本概念

  • 图:描述了计算过程,Tensorflow用图来表示计算过程
  • 张量:Tensorflow 使用tensor表示数据,每一个tensor是一个多维化的数组
  • 操作:图中的节点为op,一个op获得/输入0个或者多个Tensor,执行并计算,产生0个或多个Tensor
  • 会话:session tensorflow的运行需要再绘话里面运行

tensorflow写代码流程

  • 定义变量占位符
  • 根据数学原理写方程
  • 定义损失函数cost
  • 定义优化梯度下降 GradientDescentOptimizer
  • session 进行训练,for循环
  • 保存saver

2.环境准备

整合步骤

  1. 模型构建:首先,我们需要在TensorFlow中定义并训练深度学习模型。这可能涉及选择合适的网络结构、优化器和损失函数等。
  2. 训练数据准备:接下来,我们需要准备用于训练和验证模型的数据。这可能包括数据清洗、标注和预处理等步骤。
  3. REST API设计:为了与TensorFlow模型进行交互,我们需要在SpringBoot中创建一个REST API。这可以使用SpringBoot的内置功能来实现,例如使用Spring MVC或Spring WebFlux。
  4. 模型部署:在模型训练完成后,我们需要将其部署到SpringBoot应用中。为此,我们可以使用TensorFlow的Java API将模型导出为ONNX或SavedModel格式,然后在SpringBoot应用中加载并使用。

在整合过程中,有几个关键点需要注意。首先,防火墙设置可能会影响TensorFlow训练过程中的网络通信。确保你的防火墙允许TensorFlow访问其所需的网络资源,以免出现训练中断或模型性能下降的问题。其次,要关注版本兼容性。SpringBoot和TensorFlow都有各自的版本更新周期,确保在整合时使用兼容的版本可以避免很多不必要的麻烦。

模型下载

模型构建和模型训练这块设计到python代码,这里跳过,感兴趣的可以下载源代码自己训练模型,咱们直接下载训练好的模型

  • https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz

下载好了,解压放在/resources/inception_v3目录下

3.代码工程

实验目的

实现图片检测

pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"><parent><artifactId>springboot-demo</artifactId><groupId>com.et</groupId><version>1.0-SNAPSHOT</version></parent><modelVersion>4.0.0</modelVersion><artifactId>Tensorflow</artifactId><properties><maven.compiler.source>11</maven.compiler.source><maven.compiler.target>11</maven.compiler.target></properties><dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-autoconfigure</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-test</artifactId><scope>test</scope></dependency><dependency><groupId>org.tensorflow</groupId><artifactId>tensorflow-core-platform</artifactId><version>0.5.0</version></dependency><dependency><groupId>org.projectlombok</groupId><artifactId>lombok</artifactId></dependency><dependency><groupId>jmimemagic</groupId><artifactId>jmimemagic</artifactId><version>0.1.2</version></dependency><dependency><groupId>jakarta.platform</groupId><artifactId>jakarta.jakartaee-api</artifactId><version>9.0.0</version></dependency><dependency><groupId>commons-io</groupId><artifactId>commons-io</artifactId><version>2.16.1</version></dependency><dependency><groupId>org.springframework.restdocs</groupId><artifactId>spring-restdocs-mockmvc</artifactId><scope>test</scope></dependency></dependencies>
</project>

controller

package com.et.tf.api;import java.io.IOException;import com.et.tf.service.ClassifyImageService;
import net.sf.jmimemagic.Magic;
import net.sf.jmimemagic.MagicMatch;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;@RestController
@RequestMapping("/api")
public class AppController {@AutowiredClassifyImageService classifyImageService;@PostMapping(value = "/classify")@CrossOrigin(origins = "*")public ClassifyImageService.LabelWithProbability classifyImage(@RequestParam MultipartFile file) throws IOException {checkImageContents(file);return classifyImageService.classifyImage(file.getBytes());}@RequestMapping(value = "/")public String index() {return "index";}private void checkImageContents(MultipartFile file) {MagicMatch match;try {match = Magic.getMagicMatch(file.getBytes());} catch (Exception e) {throw new RuntimeException(e);}String mimeType = match.getMimeType();if (!mimeType.startsWith("image")) {throw new IllegalArgumentException("Not an image type: " + mimeType);}}}

service

package com.et.tf.service;import jakarta.annotation.PreDestroy;
import java.util.Arrays;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.op.OpScope;
import org.tensorflow.op.Scope;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;//Inspired from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
@Service
@Slf4j
public class ClassifyImageService {private final Session session;private final List<String> labels;private final String outputLayer;private final int W;private final int H;private final float mean;private final float scale;public ClassifyImageService(Graph inceptionGraph, List<String> labels, @Value("${tf.outputLayer}") String outputLayer,@Value("${tf.image.width}") int imageW, @Value("${tf.image.height}") int imageH,@Value("${tf.image.mean}") float mean, @Value("${tf.image.scale}") float scale) {this.labels = labels;this.outputLayer = outputLayer;this.H = imageH;this.W = imageW;this.mean = mean;this.scale = scale;this.session = new Session(inceptionGraph);}public LabelWithProbability classifyImage(byte[] imageBytes) {long start = System.currentTimeMillis();try (Tensor image = normalizedImageToTensor(imageBytes)) {float[] labelProbabilities = classifyImageProbabilities(image);int bestLabelIdx = maxIndex(labelProbabilities);LabelWithProbability labelWithProbability =new LabelWithProbability(labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f, System.currentTimeMillis() - start);log.debug(String.format("Image classification [%s %.2f%%] took %d ms",labelWithProbability.getLabel(),labelWithProbability.getProbability(),labelWithProbability.getElapsed()));return labelWithProbability;}}private float[] classifyImageProbabilities(Tensor image) {try (Tensor result = session.runner().feed("input", image).fetch(outputLayer).run().get(0)) {final Shape resultShape = result.shape();final long[] rShape = resultShape.asArray();if (resultShape.numDimensions() != 2 || rShape[0] != 1) {throw new RuntimeException(String.format("Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",Arrays.toString(rShape)));}int nlabels = (int) rShape[1];FloatDataBuffer resultFloatBuffer = result.asRawTensor().data().asFloats();float[] dst = new float[nlabels];resultFloatBuffer.read(dst);return dst;}}private int maxIndex(float[] probabilities) {int best = 0;for (int i = 1; i < probabilities.length; ++i) {if (probabilities[i] > probabilities[best]) {best = i;}}return best;}private Tensor normalizedImageToTensor(byte[] imageBytes) {try (Graph g = new Graph();TInt32 batchTensor = TInt32.scalarOf(0);TInt32 sizeTensor = TInt32.vectorOf(H, W);TFloat32 meanTensor = TFloat32.scalarOf(mean);TFloat32 scaleTensor = TFloat32.scalarOf(scale);) {GraphBuilder b = new GraphBuilder(g);//Tutorial python here: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/label_image// Some constants specific to the pre-trained model at:// https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz//// - The model was trained with images scaled to 299x299 pixels.// - The colors, represented as R, G, B in 1-byte each were converted to//   float using (value - Mean)/Scale.// Since the graph is being constructed once per execution here, we can use a constant for the// input image. If the graph were to be re-used for multiple input images, a placeholder would// have been more appropriate.final Output input = b.constant("input", TString.tensorOfBytes(NdArrays.scalarOfObject(imageBytes)));final Output output =b.div(b.sub(b.resizeBilinear(b.expandDims(b.cast(b.decodeJpeg(input, 3), DataType.DT_FLOAT),b.constant("make_batch", batchTensor)),b.constant("size", sizeTensor)),b.constant("mean", meanTensor)),b.constant("scale", scaleTensor));try (Session s = new Session(g)) {return s.runner().fetch(output.op().name()).run().get(0);}}}static class GraphBuilder {final Scope scope;GraphBuilder(Graph g) {this.g = g;this.scope = new OpScope(g);}Output div(Output x, Output y) {return binaryOp("Div", x, y);}Output sub(Output x, Output y) {return binaryOp("Sub", x, y);}Output resizeBilinear(Output images, Output size) {return binaryOp("ResizeBilinear", images, size);}Output expandDims(Output input, Output dim) {return binaryOp("ExpandDims", input, dim);}Output cast(Output value, DataType dtype) {return g.opBuilder("Cast", "Cast", scope).addInput(value).setAttr("DstT", dtype).build().output(0);}Output decodeJpeg(Output contents, long channels) {return g.opBuilder("DecodeJpeg", "DecodeJpeg", scope).addInput(contents).setAttr("channels", channels).build().output(0);}Output<? extends TType> constant(String name, Tensor t) {return g.opBuilder("Const", name, scope).setAttr("dtype", t.dataType()).setAttr("value", t).build().output(0);}private Output binaryOp(String type, Output in1, Output in2) {return g.opBuilder(type, type, scope).addInput(in1).addInput(in2).build().output(0);}private final Graph g;}@PreDestroypublic void close() {session.close();}@Data@NoArgsConstructor@AllArgsConstructorpublic static class LabelWithProbability {private String label;private float probability;private long elapsed;}
}

application.yaml

tf:frozenModelPath: inception-v3/inception_v3_2016_08_28_frozen.pblabelsPath: inception-v3/imagenet_slim_labels.txtoutputLayer: InceptionV3/Predictions/Reshape_1image:width: 299height: 299mean: 0scale: 255logging.level.net.sf.jmimemagic: WARN
spring:servlet:multipart:max-file-size: 5MB

Application.java

package com.et.tf;import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.tensorflow.Graph;
import org.tensorflow.proto.framework.GraphDef;@SpringBootApplication
@Slf4j
public class Application {public static void main(String[] args) {SpringApplication.run(Application.class, args);}@Beanpublic Graph tfModelGraph(@Value("${tf.frozenModelPath}") String tfFrozenModelPath) throws IOException {Resource graphResource = getResource(tfFrozenModelPath);Graph graph = new Graph();graph.importGraphDef(GraphDef.parseFrom(graphResource.getInputStream()));log.info("Loaded Tensorflow model");return graph;}private Resource getResource(@Value("${tf.frozenModelPath}") String tfFrozenModelPath) {Resource graphResource = new FileSystemResource(tfFrozenModelPath);if (!graphResource.exists()) {graphResource = new ClassPathResource(tfFrozenModelPath);}if (!graphResource.exists()) {throw new IllegalArgumentException(String.format("File %s does not exist", tfFrozenModelPath));}return graphResource;}@Beanpublic List<String> tfModelLabels(@Value("${tf.labelsPath}") String labelsPath) throws IOException {Resource labelsRes = getResource(labelsPath);log.info("Loaded model labels");return IOUtils.readLines(labelsRes.getInputStream(), StandardCharsets.UTF_8).stream().map(label -> label.substring(label.contains(":") ? label.indexOf(":") + 1 : 0)).collect(Collectors.toList());}
}

以上只是一些关键代码,所有代码请参见下面代码仓库

代码仓库

  • GitHub - Harries/springboot-demo: a simple springboot demo with some components for example: redis,solr,rockmq and so on.

4.测试

启动 Spring Boot应用程序

测试图片分类

访问http://127.0.0.1:8080/,上传一张图片,点击分类

 

5.引用

  • https://www.tensorflow.org/
  • Spring Boot集成tensorflow实现图片检测服务 | Harries Blog™

 

相关文章:

  • 2024广东省职业技能大赛云计算赛项实战——编排部署ERP管理系统
  • 如何设计一个点赞系统
  • Linux系统安装Dify结合内网穿透实现远程访问本地LLM开发平台
  • Redis 数据恢复及持久化策略分析
  • windows系统配置linux环境wsl
  • 深入探索Llama 2:下一代开源语言模型的革新与影响
  • Vue66-vue-默认插槽
  • 01 Shell 编程规范与变量
  • Sklearn之朴素贝叶斯应用
  • IDEA GIt 提交提示 “Contents are identica“
  • ORA-25153 错误处理
  • 乡村振兴的科技创新引领:加强农业科技研发,推广先进适用技术,提高农业生产效率,助力美丽乡村建设
  • CLIP-guided Prototype Modulating for Few-shot Action Recognition
  • Java序列化进阶:Java内置序列化的三种方式
  • python3获取显示器信息并计算出各个显示器是多少寸
  • Android路由框架AnnoRouter:使用Java接口来定义路由跳转
  • Angular 响应式表单 基础例子
  • C++回声服务器_9-epoll边缘触发模式版本服务器
  • Docker容器管理
  • github指令
  • Gradle 5.0 正式版发布
  • iOS 颜色设置看我就够了
  • iOS仿今日头条、壁纸应用、筛选分类、三方微博、颜色填充等源码
  • java 多线程基础, 我觉得还是有必要看看的
  • java8-模拟hadoop
  • JavaScript函数式编程(一)
  • python_bomb----数据类型总结
  • Quartz实现数据同步 | 从0开始构建SpringCloud微服务(3)
  • Spring核心 Bean的高级装配
  • ViewService——一种保证客户端与服务端同步的方法
  • VirtualBox 安装过程中出现 Running VMs found 错误的解决过程
  • vue.js框架原理浅析
  • 前端技术周刊 2019-01-14:客户端存储
  • 山寨一个 Promise
  • 昨天1024程序员节,我故意写了个死循环~
  • ​2020 年大前端技术趋势解读
  • #快捷键# 大学四年我常用的软件快捷键大全,教你成为电脑高手!!
  • (2022版)一套教程搞定k8s安装到实战 | RBAC
  • (6)【Python/机器学习/深度学习】Machine-Learning模型与算法应用—使用Adaboost建模及工作环境下的数据分析整理
  • (八)c52学习之旅-中断实验
  • (二)延时任务篇——通过redis的key监听,实现延迟任务实战
  • (黑马出品_高级篇_01)SpringCloud+RabbitMQ+Docker+Redis+搜索+分布式
  • (回溯) LeetCode 78. 子集
  • (转)IIS6 ASP 0251超过响应缓冲区限制错误的解决方法
  • (转)setTimeout 和 setInterval 的区别
  • (转)负载均衡,回话保持,cookie
  • .NET 4.0网络开发入门之旅-- 我在“网” 中央(下)
  • .NET Core SkiaSharp 替代 System.Drawing.Common 的一些用法
  • .net 获取某一天 在当月是 第几周 函数
  • .Net 中的反射(动态创建类型实例) - Part.4(转自http://www.tracefact.net/CLR-and-Framework/Reflection-Part4.aspx)...
  • .NetCore+vue3上传图片 Multipart body length limit 16384 exceeded.
  • .NET开源、简单、实用的数据库文档生成工具
  • .NET开源全面方便的第三方登录组件集合 - MrHuo.OAuth
  • .NET中使用Protobuffer 实现序列化和反序列化
  • /etc/sudoers (root权限管理)