当前位置: 首页 > news >正文

Tensorflow入门实战 T09进行猫狗识别2

目录

1、前言

2、代码

3、运行结果

4、反思


  • 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

1、前言

本周学习内容为,采用自己设置的vgg-16网络进行猫狗识别,使用的模型是YOLOv5相关的模块搭建的网络模型。

本周的代码运行,设置了新的运行展示进度条,更换不同风格;其他的模块和之前的很类似,没有很大的改动。主要是学会使用tensorflow完成先关实现。

2、代码

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用
print(gpus)import numpy as np
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号import os,PIL,pathlib#隐藏警告
import warnings
warnings.filterwarnings('ignore')data_dir = "./365-9-data"
data_dir = pathlib.Path(data_dir)image_count = len(list(data_dir.glob('*/*')))print("图片总数为:",image_count)batch_size = 64
img_height = 224
img_width  = 224"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)class_names = train_ds.class_names
print(class_names)for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)breakAUTOTUNE = tf.data.AUTOTUNEdef preprocess_image(image,label):return (image/255.0,label)# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds   = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)plt.figure(figsize=(15, 10))  # 图形的宽为15高为10for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(5, 8, i + 1)plt.imshow(images[i])plt.title(class_names[labels[i]])plt.axis("off")from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropoutdef VGG16(nb_classes, input_shape):input_tensor = Input(shape=input_shape)# 1st blockx = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv1')(input_tensor)x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv2')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block1_pool')(x)# 2nd blockx = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv1')(x)x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv2')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block2_pool')(x)# 3rd blockx = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv1')(x)x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv2')(x)x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block3_pool')(x)# 4th blockx = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv1')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv2')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block4_pool')(x)# 5th blockx = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv1')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv2')(x)x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv3')(x)x = MaxPooling2D((2,2), strides=(2,2), name = 'block5_pool')(x)# full connectionx = Flatten()(x)x = Dense(4096, activation='relu',  name='fc1')(x)x = Dense(4096, activation='relu', name='fc2')(x)output_tensor = Dense(nb_classes, activation='softmax', name='predictions')(x)model = Model(input_tensor, output_tensor)return modelmodel=VGG16(1000, (img_width, img_height, 3))
# model.summary()model.compile(optimizer="adam",loss     ='sparse_categorical_crossentropy',metrics  =['accuracy'])from tqdm import tqdm
import tensorflow.keras.backend as Kepochs = 10
lr = 1e-4# 记录训练数据,方便后面的分析
history_train_loss = []
history_train_accuracy = []
history_val_loss = []
history_val_accuracy = []for epoch in range(epochs):train_total = len(train_ds)val_total = len(val_ds)"""total:预期的迭代数目ncols:控制进度条宽度mininterval:进度更新最小间隔,以秒为单位(默认值:0.1)"""with tqdm(total=train_total, desc=f'Epoch {epoch + 1}/{epochs}', mininterval=1, ncols=100) as pbar:lr = lr * 0.92K.set_value(model.optimizer.lr, lr)train_loss = []train_accuracy = []for image, label in train_ds:"""训练模型,简单理解train_on_batch就是:它是比model.fit()更高级的一个用法想详细了解 train_on_batch 的同学,可以看看我的这篇文章:https://www.yuque.com/mingtian-fkmxf/hv4lcq/ztt4gy"""# 这里生成的是每一个batch的acc与losshistory = model.train_on_batch(image, label)train_loss.append(history[0])train_accuracy.append(history[1])pbar.set_postfix({"train_loss": "%.4f" % history[0],"train_acc": "%.4f" % history[1],"lr": K.get_value(model.optimizer.lr)})pbar.update(1)history_train_loss.append(np.mean(train_loss))history_train_accuracy.append(np.mean(train_accuracy))print('开始验证!')with tqdm(total=val_total, desc=f'Epoch {epoch + 1}/{epochs}', mininterval=0.3, ncols=100) as pbar:val_loss = []val_accuracy = []for image, label in val_ds:# 这里生成的是每一个batch的acc与losshistory = model.test_on_batch(image, label)val_loss.append(history[0])val_accuracy.append(history[1])pbar.set_postfix({"val_loss": "%.4f" % history[0],"val_acc": "%.4f" % history[1]})pbar.update(1)history_val_loss.append(np.mean(val_loss))history_val_accuracy.append(np.mean(val_accuracy))print('结束验证!')print("验证loss为:%.4f" % np.mean(val_loss))print("验证准确率为:%.4f" % np.mean(val_accuracy))epochs_range = range(epochs)plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, history_train_accuracy, label='Training Accuracy')
plt.plot(epochs_range, history_val_accuracy, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, history_train_loss, label='Training Loss')
plt.plot(epochs_range, history_val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()import numpy as np# 采用加载的模型(new_model)来看预测结果
plt.figure(figsize=(18, 3))  # 图形的宽为18高为5
plt.suptitle("预测结果展示")for images, labels in val_ds.take(1):for i in range(8):ax = plt.subplot(1, 8, i + 1)# 显示图片plt.imshow(images[i].numpy())# 需要给图片增加一个维度img_array = tf.expand_dims(images[i], 0)# 使用模型预测图片中的人物predictions = model.predict(img_array)plt.title(class_names[np.argmax(predictions)])plt.axis("off")

3、运行结果

预测结果展示:

4、反思

本周学习内容,让我更加了解了YOLOv5相关的网络模型结构;加深对于YOLOv5模型里面,各类模块的使用,了解如何去搭建和修改YOLO系列的网络模型等;可以将此搭建逻辑应用到自己的网络模型里面,但是一定要确保shape是相互匹配的。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • maven 私服搭建(tar+docker)
  • conda 复现论文部署环境常用操作
  • 缓存弊处的体验:异常
  • NEEP-EN2-2019-Text4
  • 敲详细的springframework-amqp-rabbit源码解析
  • 通信流程:https【SSL/TLS】,git仓库【https/SSH】,蓝牙【面对面快传/AirDrop】
  • 【BUG】已解决:To update, run: python.exe -m pip install --upgrade pip
  • 【学习css3】使用flex和grid实现等高元素布局
  • 插入排序和希尔排序
  • 【后端开发】身份和访问管理IAM(MFA,OTP,JWT,OAuth,SSO)
  • python—爬虫的初步了解
  • 核函数支持向量机(Kernel SVM)
  • IDEA中常用的快捷键
  • 【医学影像】RK3588+FPGA:满足远程诊疗系统8K音视频编解码及高效传输需求
  • SpreadsheetLLM:微软对Excel编码的“摊膀伏”
  • 【剑指offer】让抽象问题具体化
  • 2018天猫双11|这就是阿里云!不止有新技术,更有温暖的社会力量
  • Cookie 在前端中的实践
  • golang 发送GET和POST示例
  • IDEA常用插件整理
  • PHP 程序员也能做的 Java 开发 30分钟使用 netty 轻松打造一个高性能 websocket 服务...
  • react-core-image-upload 一款轻量级图片上传裁剪插件
  • WebSocket使用
  • 百度地图API标注+时间轴组件
  • 关于Flux,Vuex,Redux的思考
  • 爬虫模拟登陆 SegmentFault
  • 前端技术周刊 2019-02-11 Serverless
  • 一起来学SpringBoot | 第十篇:使用Spring Cache集成Redis
  • MPAndroidChart 教程:Y轴 YAxis
  • ​数据结构之初始二叉树(3)
  • #if等命令的学习
  • #LLM入门|Prompt#1.7_文本拓展_Expanding
  • #快捷键# 大学四年我常用的软件快捷键大全,教你成为电脑高手!!
  • $GOPATH/go.mod exists but should not goland
  • (1)SpringCloud 整合Python
  • (iPhone/iPad开发)在UIWebView中自定义菜单栏
  • (ZT)北大教授朱青生给学生的一封信:大学,更是一个科学的保证
  • (二开)Flink 修改源码拓展 SQL 语法
  • (理论篇)httpmoudle和httphandler一览
  • (切换多语言)vantUI+vue-i18n进行国际化配置及新增没有的语言包
  • (三十)Flask之wtforms库【剖析源码上篇】
  • (转)机器学习的数学基础(1)--Dirichlet分布
  • (转载)从 Java 代码到 Java 堆
  • ***监测系统的构建(chkrootkit )
  • .gitignore文件---让git自动忽略指定文件
  • .NET Micro Framework初体验(二)
  • .NET 反射的使用
  • .net操作Excel出错解决
  • .Net面试题4
  • .Net下C#针对Excel开发控件汇总(ClosedXML,EPPlus,NPOI)
  • .pings勒索病毒的威胁:如何应对.pings勒索病毒的突袭?
  • 。Net下Windows服务程序开发疑惑
  • /var/spool/postfix/maildrop 下有大量文件
  • @SuppressLint(NewApi)和@TargetApi()的区别
  • [ CTF ] WriteUp- 2022年第三届“网鼎杯”网络安全大赛(朱雀组)