当前位置: 首页 > news >正文

第T11周:优化器对比实验

>- **🍨 本文为[🔗365天深度学习训练营](小团体~第八波) 中的学习记录博客**
>- **🍖 原作者:[K同学啊](K同学啊-CSDN博客)**

一、前期准备工作

1. 设置GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")from tensorflow          import keras
import matplotlib.pyplot as plt
import pandas            as pd
import numpy             as np
import warnings,os,PIL,pathlibwarnings.filterwarnings("ignore")             #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False    # 用来正常显示负号

二、导入数据

1. 导入数据 

data_dir    = "./11-data"
data_dir    = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)batch_size = 16
img_height = 336
img_width  = 336"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)class_names = train_ds.class_names
print(class_names)

 

2. 检查数据 

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

 

3. 配置数据集 

AUTOTUNE = tf.data.AUTOTUNEdef train_preprocessing(image,label):return (image/255.0,label)train_ds = (train_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)val_ds = (val_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)

4. 数据可视化 

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)# 显示图片plt.imshow(images[i])# 显示标签plt.xlabel(class_names[labels[i]-1])plt.show()

 

三、构建模型

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Modeldef create_model(optimizer='adam'):# 加载预训练模型vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',include_top=False,input_shape=(img_width, img_height, 3),pooling='avg')for layer in vgg16_base_model.layers:layer.trainable = FalseX = vgg16_base_model.outputX = Dense(170, activation='relu')(X)X = BatchNormalization()(X)X = Dropout(0.5)(X)output = Dense(len(class_names), activation='softmax')(X)vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)vgg16_model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['accuracy'])return vgg16_modelmodel1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())
model2.summary()

 

四、训练模型

NO_EPOCHS = 50history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)

 

五、评估模型

1. Accuracy与Loss图

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']epochs_range = range(len(acc1))plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.show()

 

2. 模型评估

def test_accuracy_report(model):score = model.evaluate(val_ds, verbose=0)print('Loss function: %s, accuracy:' % score[0], score[1])test_accuracy_report(model2)

 

总结

1. 准确率方面
   - 使用Adam优化器的训练准确率高于使用SGD优化器的训练准确率
   - 使用SGD优化器的验证准确率高于使用Adam优化器的验证准确率
   所以Adam优化器在训练集上拟合得更好,而SGD优化器在未见数据上的泛化能力更强。

2. 损失趋势:
   - 使用Adam优化器的训练损失低于使用SGD优化器的训练损失,这表明Adam在训练过程中更快地收敛。
   - 使用SGD优化器的验证损失低于使用Adam优化器的验证损失,这意味着SGD在防止过拟合方面表现更好。

3. 优化器特性:
   - Adam优化器是一种自适应学习率优化算法,它结合了RMSProp和Momentum两种优化算法的优点,通常在训练初期能够更快地收敛。
   - SGD(随机梯度下降)是一种更传统的优化算法,它在每次迭代中使用整个数据集(或大数据批次)来更新权重,通常在训练后期能够获得更好的泛化性能。

4. 解释:
   - Adam优化器可能在训练集上表现更好,因为它能够更快地调整学习率,从而在训练初期迅速减少损失。
   - SGD可能在验证集上表现更好,因为它的学习率更新不是非常激进,这有助于模型学习到更一般化的特征,而不是过度拟合训练数据。

5. 建议:
   - 如果模型在训练集上的准确率很高,但在验证集上准确率较低,可能需要考虑正则化技术,如Dropout或权重衰减,以减少过拟合。
   - 可以尝试使用学习率衰减策略,随着训练的进行逐渐减小学习率,以提高模型的泛化能力。
   - 还可以尝试不同的优化器参数设置,如学习率、动量等,以找到最佳的训练配置。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 架构设计:负责网络、定时、坐下、站起、重连等,支持多类游戏的无锁房间
  • 通过python提取PDF文件指定页的图片
  • k8s笔记——kubebuilder实战
  • wifiip地址可以随便改吗?wifi的ip地址怎么改变
  • 【计算机网络 - 基础问题】每日 3 题(二)
  • linux: nvidia-smi用法详解
  • 二.Unity中使用虚拟摇杆来控制角色移动
  • Unity 第一人称游戏的武器被其他物体覆盖解决方案
  • 供应RM500UCNAB-D10-SNADA模块
  • leetcode 108.将有序数组转换为二叉搜索树
  • word文档无损原样转pdf在windows平台使用python调用win32com使用pip安装pywin32
  • 嵌入式epoll面试题面试题及参考答案
  • Maven私服Nexus安装及使用
  • 第7篇:【系统分析师】计算机网络
  • openCV的python频率域滤波
  • 【Linux系统编程】快速查找errno错误码信息
  • 【每日笔记】【Go学习笔记】2019-01-10 codis proxy处理流程
  • 4. 路由到控制器 - Laravel从零开始教程
  • electron原来这么简单----打包你的react、VUE桌面应用程序
  • HomeBrew常规使用教程
  • Java IO学习笔记一
  • js如何打印object对象
  • Linux中的硬链接与软链接
  • MySQL Access denied for user 'root'@'localhost' 解决方法
  • php ci框架整合银盛支付
  • Promise面试题2实现异步串行执行
  • Vue实战(四)登录/注册页的实现
  • 仿天猫超市收藏抛物线动画工具库
  • 给自己的博客网站加上酷炫的初音未来音乐游戏?
  • 构建工具 - 收藏集 - 掘金
  • 微信小程序:实现悬浮返回和分享按钮
  • 无服务器化是企业 IT 架构的未来吗?
  • 线性表及其算法(java实现)
  • 小程序滚动组件,左边导航栏与右边内容联动效果实现
  • 移动端解决方案学习记录
  • 运行时添加log4j2的appender
  • 进程与线程(三)——进程/线程间通信
  • ​LeetCode解法汇总1276. 不浪费原料的汉堡制作方案
  • #laravel部署安装报错loadFactoriesFrom是undefined method #
  • #职场发展#其他
  • (6)【Python/机器学习/深度学习】Machine-Learning模型与算法应用—使用Adaboost建模及工作环境下的数据分析整理
  • (void) (_x == _y)的作用
  • (二)丶RabbitMQ的六大核心
  • (二十五)admin-boot项目之集成消息队列Rabbitmq
  • (一)kafka实战——kafka源码编译启动
  • (一)十分简易快速 自己训练样本 opencv级联haar分类器 车牌识别
  • (转)Java socket中关闭IO流后,发生什么事?(以关闭输出流为例) .
  • ***微信公众号支付+微信H5支付+微信扫码支付+小程序支付+APP微信支付解决方案总结...
  • .NET MVC 验证码
  • .NET/C# 编译期间能确定的相同字符串,在运行期间是相同的实例
  • .NET/C# 利用 Walterlv.WeakEvents 高性能地定义和使用弱事件
  • .NetCore发布到IIS
  • .NET版Word处理控件Aspose.words功能演示:在ASP.NET MVC中创建MS Word编辑器
  • .NET单元测试使用AutoFixture按需填充的方法总结
  • .pop ----remove 删除