当前位置: 首页 > news >正文

TensorFlow系列:第四讲:MobileNetV2实战

一. 加载数据集

编写工具类,实现数据集的加载

![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/808da38d6ad74628b869c28e937b02d9.png


import keras"""
加载数据集工具类
"""class DatasetLoader:def __init__(self, path_url, image_size=(224, 224), batch_size=32, class_mode='categorical'):self.path_url = path_urlself.image_size = image_sizeself.batch_size = batch_sizeself.class_mode = class_mode# 不使用图像增强def load_data(self):# 加载训练数据集train_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/train',  # 训练数据集的目录路径image_size=self.image_size,  # 调整图像大小batch_size=self.batch_size,  # 每批次的样本数量label_mode=self.class_mode,  # 类别模式:返回one-hot编码的标签)# 加载验证数据集val_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/validation',  # 验证数据集的目录路径image_size=self.image_size,  # 调整图像大小batch_size=self.batch_size,  # 每批次的样本数量label_mode=self.class_mode  # 类别模式:返回one-hot编码的标签)# 加载测试数据集test_data = keras.preprocessing.image_dataset_from_directory(self.path_url + '/test',  # 验证数据集的目录路径image_size=self.image_size,  # 调整图像大小batch_size=self.batch_size,  # 每批次的样本数量label_mode=self.class_mode  # 类别模式:返回one-hot编码的标签)class_names = train_data.class_namesreturn train_data, val_data, test_data, class_names

二. 训练模型完整代码

import keras
from keras import layersfrom utils.dataset_loader import DatasetLoader"""
使用MobileNetV2,实现图像多分类
"""# 模型训练地址
PATH_URL = '../data/fruits'
# 训练曲线图
RESULT_URL = '../results/fruits'
# 模型保存地址
SAVED_MODEL_DIR = '../saved_model/fruits'#  图片大小
IMG_SIZE = (224, 224)
# 定义图像的输入形状
IMG_SHAPE = IMG_SIZE + (3,)
# 数据加载批次,训练轮数
BATCH_SIZE, EPOCH = 32, 16# 训练模型
def train():# 实例化数据集加载工具类dataset_loader = DatasetLoader(PATH_URL, IMG_SIZE, BATCH_SIZE)train_ds, val_ds, test_ds, class_total = dataset_loader.load_data()# 构建 MobileNet 模型base_model = keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False)# 将模型的主干参数进行冻结base_model.trainable = Falsemodel = keras.Sequential([layers.Rescaling(1. / 127.5, offset=-1, input_shape=IMG_SHAPE),# 设置主干模型base_model,# 对主干模型的输出进行全局平均池化layers.GlobalAveragePooling2D(),# 通过全连接层映射到最后的分类数目上layers.Dense(len(class_total), activation='softmax')])# 编译模型model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 模型结构model.summary()# 指明训练的轮数epoch,开始训练model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 测试loss, accuracy = model.evaluate(test_ds)# 输出结果print('Mobilenet test accuracy :', accuracy, ',loss :', loss)# 保存模型 savedModel格式model.export(filepath=SAVED_MODEL_DIR)if __name__ == '__main__':train()

训练模型输出如下:

模型结构:

在这里插入图片描述
训练进度:主要看最下边一行输出,一轮训练完成会显示训练集和验证集的正确率。
在这里插入图片描述
验证正确率:

在这里插入图片描述
保存的模型:

在这里插入图片描述

三. 函数式调用方式

以后的所有讲解,都基于函数式方式进行,因为函数式调用比较灵活。

# 函数式调用方式
def train1():# 实例化数据集加载工具类dataset_loader = DatasetLoader(PATH_URL, IMG_SIZE, BATCH_SIZE)train_ds, val_ds, test_ds, class_total = dataset_loader.load_data()inputs = keras.Input(shape=IMG_SHAPE)# 加载预训练的 MobileNetV2 模型,不包括顶层分类器,并在 Rescaling 层之后连接base_model = keras.applications.MobileNetV3Large(weights='imagenet', include_top=False, input_tensor=inputs)# 冻结 MobileNetV2 的所有层,以防止在初始阶段进行权重更新for layer in base_model.layers:layer.trainable = False# 在 MobileNetV2 之后添加自定义的顶层分类器x = layers.GlobalAveragePooling2D()(base_model.output)predictions = layers.Dense(len(class_total), activation='softmax')(x)# 构建最终模型model = keras.Model(inputs=base_model.input, outputs=predictions)# 编译模型model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 查看模型结构model.summary()model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 测试loss, accuracy = model.evaluate(test_ds)# 输出结果print('Mobilenet test accuracy :', accuracy, ',loss :', loss)# 保存模型 savedModel格式model.export(filepath=SAVED_MODEL_DIR)

四. 保存训练过程曲线图

在训练模型时,我们不可能时时盯着训练数据结果,如果把训练过程曲线保存成图片,这样就比较方便查看。

在项目中编写一个工具类如下:
在这里插入图片描述
上边代码简单改造:

    # 训练模型history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)# 保存曲线图Utils.trainResult(history, RESULT_URL)

曲线图如下:训练集和验证集准确率上升,损失率下降,这是完美的表现。

在这里插入图片描述

五. 模型可视化批量测试

在这里插入图片描述
编写可视化批量测试工具类:

import keras
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import FancyBboxPatchfrom utils.dataset_loader import DatasetLoader"""
模型工具类
"""class ModelUtil:def __init__(self, saved_model_dir, path_url):self.save_model_dir = saved_model_dir  # savedModel 模型保存地址self.path_url = path_url  # 模型训练数据地址# 批量识别 进行可视化显示def batch_evaluation(self, class_mode='categorical', image_size=(224, 224), num_images=25):dataset_loader = DatasetLoader(self.path_url, image_size=image_size, class_mode=class_mode)train_ds, val_ds, test_ds, class_names = dataset_loader.load_data()# 加载savedModel模型tfs_layer = keras.layers.TFSMLayer(self.save_model_dir)# 创建一个新的 Keras 模型,包含 TFSMLayermodel = keras.Sequential([keras.Input(shape=image_size + (3,)),  # 根据你的模型的输入形状tfs_layer])plt.figure(figsize=(10, 10))for images, labels in test_ds.take(1):# 使用模型进行预测outputs = model.predict(images)for i in range(num_images):plt.subplot(5, 5, i + 1)image = np.array(images[i]).astype("uint8")plt.imshow(image)index = int(np.argmax(outputs[i]))prediction = outputs[i][index]percentage_str = "{:.2f}%".format(prediction * 100)plt.title(f"{class_names[index]}: {percentage_str}")plt.axis("off")plt.subplots_adjust(hspace=0.5, wspace=0.5)plt.show()

使用工具类:

if __name__ == '__main__':# train()model_util = ModelUtil(SAVED_MODEL_DIR, PATH_URL)model_util.batch_evaluation()

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • SpringBoot:SpringBoot中如何实现对Http接口进行监控
  • Python爬虫:基础爬虫架构及爬取证券之星全站行情数据!
  • spring管理bean源码解析
  • [GICv3] 3. 物理中断处理(Physical Interrupt Handling)
  • 【 香橙派 AIpro评测】烧系统到运行并使用Jupyter Lab 界面体验 AI 应用样例(新手福音)
  • Excel中用VBA实现Outlook发送当前工作簿
  • Qt项目:基于Qt实现的网络聊天室---TCP服务器和token验证
  • 从数据仓库到数据湖(上):数据湖导论
  • 如何通过文件分发系统,实现能源电力企业文件的安全分发流转?
  • 展开说说:Android之View基础知识解析
  • 【Qt 基础】绘图
  • 如何判断服务器是否被攻击
  • 微信小程序如何实现登陆和注册功能?
  • ShardingSphere-JDBC —— 整合 mybatis-plus,调用批量方法执行更新操作扫所有分表问题
  • 【cocos creator】2.4.x实现简单3d功能,点击选中,旋转,材质修改,透明材质
  • (十五)java多线程之并发集合ArrayBlockingQueue
  • 【399天】跃迁之路——程序员高效学习方法论探索系列(实验阶段156-2018.03.11)...
  • flutter的key在widget list的作用以及必要性
  • Git的一些常用操作
  • Python 使用 Tornado 框架实现 WebHook 自动部署 Git 项目
  • 从tcpdump抓包看TCP/IP协议
  • 对象引论
  • 海量大数据大屏分析展示一步到位:DataWorks数据服务+MaxCompute Lightning对接DataV最佳实践...
  • 回顾 Swift 多平台移植进度 #2
  • 技术攻略】php设计模式(一):简介及创建型模式
  • 京东美团研发面经
  • 深入浅出webpack学习(1)--核心概念
  • 腾讯优测优分享 | Android碎片化问题小结——关于闪光灯的那些事儿
  • 体验javascript之美-第五课 匿名函数自执行和闭包是一回事儿吗?
  • 小试R空间处理新库sf
  • 国内唯一,阿里云入选全球区块链云服务报告,领先AWS、Google ...
  • 交换综合实验一
  • ​卜东波研究员:高观点下的少儿计算思维
  • #Datawhale AI夏令营第4期#AIGC文生图方向复盘
  • #数据结构 笔记三
  • (DFS + 剪枝)【洛谷P1731】 [NOI1999] 生日蛋糕
  • (el-Date-Picker)操作(不使用 ts):Element-plus 中 DatePicker 组件的使用及输出想要日期格式需求的解决过程
  • (JS基础)String 类型
  • (poj1.2.1)1970(筛选法模拟)
  • (vue)el-cascader级联选择器按勾选的顺序传值,摆脱层级约束
  • (八)五种元启发算法(DBO、LO、SWO、COA、LSO、KOA、GRO)求解无人机路径规划MATLAB
  • (附源码)python旅游推荐系统 毕业设计 250623
  • (转) Android中ViewStub组件使用
  • (转)h264中avc和flv数据的解析
  • (转)拼包函数及网络封包的异常处理(含代码)
  • * CIL library *(* CIL module *) : error LNK2005: _DllMain@12 already defined in mfcs120u.lib(dllmodu
  • .bat批处理(十):从路径字符串中截取盘符、文件名、后缀名等信息
  • .equal()和==的区别 怎样判断字符串为空问题: Illegal invoke-super to void nio.file.AccessDeniedException
  • .Net Core中Quartz的使用方法
  • .NET 实现 NTFS 文件系统的硬链接 mklink /J(Junction)
  • .NET8.0 AOT 经验分享 FreeSql/FreeRedis/FreeScheduler 均已通过测试
  • .NET命名规范和开发约定
  • .NET正则基础之——正则委托
  • @angular/cli项目构建--Dynamic.Form
  • @Bean有哪些属性