当前位置: 首页 > news >正文

Tensorflow2.0:CNN、ResNet实现MNIST分类识别

以下仅是个人的学习笔记 ,内容可能是错误

CNN: 

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers# 导入数据
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()# 数据预处理
x_train = x_train.reshape(-1, 28, 28, 1) / 255.0
x_test = x_test.reshape(-1, 28, 28, 1) / 255.0# 构建模型
model = keras.Sequential([layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D(pool_size=(2, 2)),layers.Flatten(),layers.Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)

ResNet18: 

import tensorflow as tf
from keras import layers, models, datasets
import os# 定义gpu
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # 指定GPU编号
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:try:tf.config.experimental.set_memory_growth(gpus[0], True)  # 动态申请显存except RuntimeError as e:print(e)# 加载数据集
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()# 数据预处理
train_images, test_images = train_images / 255.0, test_images / 255.0# 搭建残差模块
def resnet_block(inputs, num_filters=16, kernel_size=3, strides=1, activation='relu'):x = layers.Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same')(inputs)x = layers.BatchNormalization()(x)if activation:x = layers.Activation(activation)(x)return x# 定义resnet
def resnet18():inputs = layers.Input(shape=(32, 32, 3))num_filters = 64t = layers.BatchNormalization()(inputs)t = resnet_block(t, num_filters=num_filters)for i in range(2):t = resnet_block(t, num_filters=num_filters, activation=None)t = layers.Add()([t, layers.Activation('relu')(t)])t = resnet_block(t, num_filters=num_filters * 2, strides=2, activation=None)t = layers.Add()([t, resnet_block(t, num_filters=num_filters * 2)])num_filters *= 2for i in range(2):t = resnet_block(t, num_filters=num_filters, activation=None)t = layers.Add()([t, layers.Activation('relu')(t)])t = resnet_block(t, num_filters=num_filters * 2, strides=2, activation=None)t = layers.Add()([t, resnet_block(t, num_filters=num_filters * 2)])num_filters *= 2for i in range(2):t = resnet_block(t, num_filters=num_filters, activation=None)t = layers.Add()([t, layers.Activation('relu')(t)])t = layers.AveragePooling2D()(t)outputs = layers.Dense(10, activation='softmax')(layers.Flatten()(t))model = models.Model(inputs, outputs)return model# 定义模型
model = resnet18()
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练 CPU
# history = model.fit(train_images, train_labels, epochs=10,
#                     validation_data=(test_images, test_labels))with tf.device('GPU:0'):  # 指定使用GPUhistory = model.fit(train_images, train_labels, epochs=10,validation_data=(test_images, test_labels))

 

相关文章:

  • 宝塔https403默认串站问题解决
  • 【数据结构】树与二叉树(十八):树的存储结构——Father链接结构、儿子链表链接结构
  • C++ 编写动态二维double型数据类Matrix
  • IDEA导入jar包
  • ​软考-高级-系统架构设计师教程(清华第2版)【第9章 软件可靠性基础知识(P320~344)-思维导图】​
  • modbusRTU通信简单实现(使用NModbus4通信库)
  • 【喵叔闲扯】--迪米特法则
  • 23111708[含文档+PPT+源码等]计算机毕业设计基于javaweb的旅游网站前台与后台旅景点
  • 元宇宙3D云展厅应用到汽车销售的方案及特点
  • DAO和增删改查通用方法-BasicDao
  • PON网络应用场景
  • Jupyter Notebook的下载安装与使用教程_Python数据分析与可视化
  • 一文看分布式锁
  • Node.js中的Buffer和Stream
  • CTF-PWN-堆- 【off-by-one】
  • 【译】JS基础算法脚本:字符串结尾
  • [iOS]Core Data浅析一 -- 启用Core Data
  • 【干货分享】SpringCloud微服务架构分布式组件如何共享session对象
  • 10个确保微服务与容器安全的最佳实践
  • CAP 一致性协议及应用解析
  • Consul Config 使用Git做版本控制的实现
  • Date型的使用
  • Docker容器管理
  • go append函数以及写入
  • HTTP中GET与POST的区别 99%的错误认识
  • Java-详解HashMap
  • js学习笔记
  • Laravel 菜鸟晋级之路
  • learning koa2.x
  • python大佬养成计划----difflib模块
  • React组件设计模式(一)
  • Redis 懒删除(lazy free)简史
  • Vue组件定义
  • 今年的LC3大会没了?
  • 可能是历史上最全的CC0版权可以免费商用的图片网站
  • 深入浅出Node.js
  • 使用Maven插件构建SpringBoot项目,生成Docker镜像push到DockerHub上
  • 一、python与pycharm的安装
  • [Shell 脚本] 备份网站文件至OSS服务(纯shell脚本无sdk) ...
  • ​直流电和交流电有什么区别为什么这个时候又要变成直流电呢?交流转换到直流(整流器)直流变交流(逆变器)​
  • # MySQL server 层和存储引擎层是怎么交互数据的?
  • ### Error querying database. Cause: com.mysql.jdbc.exceptions.jdbc4.CommunicationsException
  • #mysql 8.0 踩坑日记
  • #我与Java虚拟机的故事#连载07:我放弃了对JVM的进一步学习
  • (9)YOLO-Pose:使用对象关键点相似性损失增强多人姿态估计的增强版YOLO
  • (Arcgis)Python编程批量将HDF5文件转换为TIFF格式并应用地理转换和投影信息
  • (Matlab)使用竞争神经网络实现数据聚类
  • (Python) SOAP Web Service (HTTP POST)
  • (办公)springboot配置aop处理请求.
  • (附源码)springboot家庭财务分析系统 毕业设计641323
  • (附源码)计算机毕业设计SSM基于java的云顶博客系统
  • (十六)Flask之蓝图
  • (转)Spring4.2.5+Hibernate4.3.11+Struts1.3.8集成方案一
  • (转)程序员疫苗:代码注入
  • .NET Compact Framework 多线程环境下的UI异步刷新