当前位置: 首页 > news >正文

Tensorflow—第四讲网络八股扩展

本讲概述

一、自制数据集 

我们用六万张数字图片自制训练集,一万张数字图片制作测试集

 

代码(注释已经很清楚了,就不解释了):

def generateds(path, txt):f = open(txt, 'r')  # 以只读形式打开txt文件contents = f.readlines()  # 读取文件中所有行f.close()  # 关闭txt文件x, y_ = [], []  # 建立空列表for content in contents:  # 逐行取出value = content.split()  # 以空格分开,图片路径为value[0] , 标签为value[1] , 存入列img_path = path + value[0]  # 拼出图片路径和文件名img = Image.open(img_path)  # 读入图片img = np.array(img.convert('L'))  # 图片变为8位宽灰度值的np.array格式img = img / 255.  # 数据归一化 (实现预处理)x.append(img)  # 归一化后的数据,贴到列表xy_.append(value[1])  # 标签贴到列表y_print('loading : ' + content)  # 打印状态提示x = np.array(x)  # 变为np.array格式y_ = np.array(y_)  # 变为np.array格式y_ = y_.astype(np.int64)  # 变为64位整型return x, y_  # 返回输入特征x,返回标签y_

二、数据增强 

对图像数据的增强,就是对图像进行简单形变,用来应对因拍照角度不同引起的图片变形。

x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)  # 给数据增加一个维度,从(60000, 28, 28)reshape为(60000, 28,image_gen_train = ImageDataGenerator(rescale=1. / 1.,  # 如为图像,分母为255时,可归至0~1rotation_range=45,  # 随机45度旋转width_shift_range=.15,  # 宽度偏移height_shift_range=.15,  # 高度偏移horizontal_flip=False,  # 水平翻转zoom_range=0.5  # 将图像随机缩放阈量50%
)
image_gen_train.fit(x_train)

fit时需要4维,所以先给数据增加了一个维度 

跟之前比还有一处改变 :

flow方法通常用于生成批次(batch)数据

三、断点续训

 下次再训练时会加载上次保存的模型

 save_weights_only:是否只保留文件参数;save_best_only:是否只保留最优结果;在fit函数中加入回调选项callbacks返回到history中

 实现代码:

checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])

四、参数提取

np.set_printoptions(threshold=np.inf) 这个设置是全局的,会影响到之后所有NumPy数组的打印行为。如果你想恢复默认的打印选项,可以再次调用 np.set_printoptions() 而不传递任何参数。

v.name: 这是变量(权重或偏置)的名称。在模型中,每个变量通常都有一个唯一的名字,这个名字有助于你识别模型中的不同参数。

v.shape: 这是变量的形状。在神经网络中,权重和偏置通常具有特定的形状,这对应于它们在网络中的组织方式。记录形状有助于了解每个参数的维度结构。

v.numpy(): 这是将变量的值转换为NumPy数组。由于深度学习框架(如TensorFlow或PyTorch)中的变量可能是特殊类型的张量,使用.numpy()方法可以将它们的值以NumPy数组的形式提取出来。记录这些值有助于分析或保存模型的当前状态。

实现代码:

np.set_printoptions(threshold=np.inf)
print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n')
file.close()

五、 acc/loss可视化 

从history 中提取acc,val_acc,loss,val_loss,再用matplotlib画图

实现代码:

acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

六、应用—给图识物

使用predict应用预测 :

实现代码 :

from PIL import Image
import numpy as np
import tensorflow as tfmodel_save_path = './checkpoint/mnist.ckpt'model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')])model.load_weights(model_save_path)preNum = int(input("input the number of test pictures:"))for i in range(preNum):image_path = input("the path of test picture:")img = Image.open(image_path)img = img.resize((28, 28), Image.ANTIALIAS)img_arr = np.array(img.convert('L'))for i in range(28):for j in range(28):if img_arr[i][j] < 200:img_arr[i][j] = 255else:img_arr[i][j] = 0img_arr = img_arr / 255.0x_predict = img_arr[tf.newaxis, ...]result = model.predict(x_predict)pred = tf.argmax(result, axis=1)print('\n')tf.print(pred)

 img = img.resize((28, 28), Image.ANTIALIAS):将其大小调整为28x28像素,因为训练的数据输入的图片为28x28像素。Image.ANTIALIAS是一个高级滤波器,用于在缩放过程中平滑图像,减少锯齿效应。

img_arr = np.array(img.convert('L')):将PIL图像对象转换为NumPy数组,并使用convert('L')方法将图像转换为灰度(即单通道)。

for循环:对图像进行阈值处理,将所有像素值小于200的设置为255(白色),大于等于200的设置为0(黑色)。这是一种简单的二值化方法。二值化处理特别适用于处理灰度图像,尤其是当图像是手写数字识别时,这种方法可以帮助模型更容易地区分数字的笔画和背景。

img_arr = img_arr / 255.0:将图像数组的像素值归一化到0到1的范围内,这是许多神经网络模型所期望的输入格式。

x_predict = img_arr[tf.newaxis, ...]:将归一化后的图像数组增加一个维度,从(28, 28)变为(1, 28, 28),以匹配模型的输入要求。

pred = tf.argmax(result, axis=1):使用tf.argmax函数从预测结果中获取最大概率对应的索引,这代表了模型预测的类别。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • [C++]: std::move
  • Vue引入使用iconfont字体图标
  • Java 基础(从编写到运行)详细说明
  • springboot使用WebSocket
  • LeetCode257 二叉树的所有路径
  • 高可用集群KEEPALIVED
  • opencv色彩空间类型转换
  • LLM微调(精讲)-以高考选择题生成模型为例(DataWhale AI夏令营)
  • 前端创作纪念日
  • go语言协程之间的同步
  • 第十章、 异常Exception
  • 东土科技车规级网络芯片获批量应用
  • leetcode300. 最长递增子序列,动态规划附状态转移方程
  • Android 让程序随系统自动启动并允许后台运行(白名单)
  • arch linux 安装Budgie桌面
  • 【跃迁之路】【463天】刻意练习系列222(2018.05.14)
  • centos安装java运行环境jdk+tomcat
  • Laravel核心解读--Facades
  • LintCode 31. partitionArray 数组划分
  • October CMS - 快速入门 9 Images And Galleries
  • RedisSerializer之JdkSerializationRedisSerializer分析
  • SpringCloud集成分布式事务LCN (一)
  • Sublime text 3 3103 注册码
  • Terraform入门 - 1. 安装Terraform
  • Vue实战(四)登录/注册页的实现
  • 阿里云Kubernetes容器服务上体验Knative
  • 安卓应用性能调试和优化经验分享
  • 湖南卫视:中国白领因网络偷菜成当代最寂寞的人?
  • 技术攻略】php设计模式(一):简介及创建型模式
  • 将回调地狱按在地上摩擦的Promise
  • 前嗅ForeSpider中数据浏览界面介绍
  • 让你的分享飞起来——极光推出社会化分享组件
  • 提升用户体验的利器——使用Vue-Occupy实现占位效果
  • 文本多行溢出显示...之最后一行不到行尾的解决
  • 移动互联网+智能运营体系搭建=你家有金矿啊!
  • 用Python写一份独特的元宵节祝福
  • 自定义函数
  • ​LeetCode解法汇总1276. 不浪费原料的汉堡制作方案
  • ‌前端列表展示1000条大量数据时,后端通常需要进行一定的处理。‌
  • #define
  • #微信小程序:微信小程序常见的配置传旨
  • #我与Java虚拟机的故事#连载09:面试大厂逃不过的JVM
  • (1) caustics\
  • (1)安装hadoop之虚拟机准备(配置IP与主机名)
  • (2/2) 为了理解 UWP 的启动流程,我从零开始创建了一个 UWP 程序
  • (二)pulsar安装在独立的docker中,python测试
  • (二)构建dubbo分布式平台-平台功能导图
  • (二)什么是Vite——Vite 和 Webpack 区别(冷启动)
  • (附源码)springboot猪场管理系统 毕业设计 160901
  • (回溯) LeetCode 78. 子集
  • (力扣)循环队列的实现与详解(C语言)
  • (十)【Jmeter】线程(Threads(Users))之jp@gc - Stepping Thread Group (deprecated)
  • (转)如何上传第三方jar包至Maven私服让maven项目可以使用第三方jar包
  • (自适应手机端)行业协会机构网站模板
  • .gitignore不生效的解决方案