当前位置: 首页 > news >正文

神经网络模型---ResNet

一、ResNet

1.导入包

import tensorflow as tf
from tensorflow.keras import layers, models, datasets, optimizers

optimizers是用于更新模型参数以最小化损失函数的算法

2.加载数据集、归一化、转为独热编码的内容一致

3.增加颜色通道

train_images = train_images[..., tf.newaxis].astype("float32")
test_images = test_images[..., tf.newaxis].astype("float32")

在train_images和test_images最后一个维度增加一个新的维度
这两行代码还将图像数据转换为浮点数类型

4.定义一个用于图像预处理的模型

4.1创造模型

preprocessing = models.Sequential([

4.2添加一个卷积层,该层有3个1x1的卷积核,激活函数为relu,并且指定了输入形状为28x28像素的单通道图像

layers.Conv2D(3, (1, 1), activation='relu', input_shape=(28, 28, 1)),

4.3 将图像尺寸增加到56x56

    layers.UpSampling2D((2, 2)), 
])

5.应用预处理模型到训练和测试图像上

train_images = preprocessing(train_images)
test_images = preprocessing(test_images)

6.加载ResNet50模型并冻结所有层

base_model=tf.keras.applications.ResNet50(weights='imagenet',include_top=False,input_shape=(56, 56, 3))

ResNet50是一个预训练的卷积神经网络模型,
参数1:加载模型的权重
参数2:是否包括模型顶部的全连接层,设置False意味着不包括这些层,由此可以得到模型的特征提取部分
参数3:输入图像的尺寸
base_model.trainable = False
使ResNet50模型的所有层都不可训练

7.创建模型

model = models.Sequential([

7.1放在模型的第一层添加到序列中,用于提取图像特征

base_model,

7.2在Keras中添加的一个全局平均池化层

layers.GlobalAveragePooling2D(),

7.3在Keras中添加的一个全连接层,使用softmax为激活函数

    layers.Dense(10, activation='softmax')
])

8.编译模型

model.compile(optimizer=optimizers.Adam(),loss='categorical_crossentropy',metrics=['accuracy'])

  • 和上一个博客的模型的内容一样,此处省略

9.训练模型

model.fit(train_images, train_labels, epochs=10, batch_size=64, validation_data=(test_images, test_labels))

结果:
在这里插入图片描述

10.保存文件

model.save('ResNet.h5')

结果:
在这里插入图片描述

相关文章:

  • 一个基于MySQL的数据库课程设计的基本框架
  • 通过防抖动代码解决ResizeObserver loop completed with undelivered notifications.
  • Java基础学习-方法
  • ByteTrack跟踪理解
  • 存储、管理和展示多媒体文件
  • MySQL 连接的使用方法与技巧
  • linux 部署瑞数6实战(维普,药监局)第一部分
  • 前端网站(二)-- 菜单页面【附源码直接可用】
  • chrome 使用本地替换功能替换接口返回内容
  • 基础算法--双指针算法
  • 数据结构历年考研真题对应知识点(单链表、双链表、循环链表)
  • 【机器学习】第11章 神经网络与深度学习(重中之重)
  • 架构师篇-1、总体架构设计
  • 智慧之选:Vatee万腾平台,引领未来的创新引擎
  • hdfs源码解析之DFSClient
  • 分享一款快速APP功能测试工具
  • 【Leetcode】104. 二叉树的最大深度
  • 4. 路由到控制器 - Laravel从零开始教程
  • axios请求、和返回数据拦截,统一请求报错提示_012
  • Python - 闭包Closure
  • Python 使用 Tornado 框架实现 WebHook 自动部署 Git 项目
  • Python学习之路16-使用API
  • SAP云平台里Global Account和Sub Account的关系
  • scrapy学习之路4(itemloder的使用)
  • 初识 beanstalkd
  • 构建工具 - 收藏集 - 掘金
  • 观察者模式实现非直接耦合
  • 基于MaxCompute打造轻盈的人人车移动端数据平台
  • 如何使用 OAuth 2.0 将 LinkedIn 集成入 iOS 应用
  • 学习笔记DL002:AI、机器学习、表示学习、深度学习,第一次大衰退
  • 一个项目push到多个远程Git仓库
  • 一些关于Rust在2019年的思考
  • 摩拜创始人胡玮炜也彻底离开了,共享单车行业还有未来吗? ...
  • ​LeetCode解法汇总2670. 找出不同元素数目差数组
  • ​linux启动进程的方式
  • #鸿蒙生态创新中心#揭幕仪式在深圳湾科技生态园举行
  • (0)Nginx 功能特性
  • (07)Hive——窗口函数详解
  • (13):Silverlight 2 数据与通信之WebRequest
  • (3)医疗图像处理:MRI磁共振成像-快速采集--(杨正汉)
  • (Charles)如何抓取手机http的报文
  • (C语言)输入一个序列,判断是否为奇偶交叉数
  • (SpringBoot)第七章:SpringBoot日志文件
  • (二)linux使用docker容器运行mysql
  • (函数)颠倒字符串顺序(C语言)
  • (九)c52学习之旅-定时器
  • (没学懂,待填坑)【动态规划】数位动态规划
  • (面试必看!)锁策略
  • (企业 / 公司项目)前端使用pingyin-pro将汉字转成拼音
  • (详细文档!)javaswing图书管理系统+mysql数据库
  • (一)WLAN定义和基本架构转
  • (转)MVC3 类型“System.Web.Mvc.ModelClientValidationRule”同时存在
  • (转)视频码率,帧率和分辨率的联系与区别
  • ***检测工具之RKHunter AIDE
  • .net core webapi 部署iis_一键部署VS插件:让.NET开发者更幸福