当前位置: 首页 > news >正文

【tensorflow框架神经网络实现鸢尾花分类】

文章目录

  • 1、数据获取
  • 2、数据集构建
  • 3、模型的训练验证
  • 可视化训练过程

1、数据获取

  • 从sklearn中获取鸢尾花数据,并合并处理
from sklearn.datasets import load_iris
import pandas as pdx_data = load_iris().data
y_data = load_iris().targetx_data = pd.DataFrame(x_data, columns=['花萼长度','花萼宽度','花瓣长度','花瓣宽度'])
pd.set_option('display.unicode.east_asian_width', True)x_data['类别'] = y_data
x_data

在这里插入图片描述

2、数据集构建

  • 数据集构建包括:
    • 数据读取
    • 数据打乱
    • 数据划分
    • 小批量迭代器生成
import tensorflow as tf
import numpy as np
from sklearn.datasets import load_iris# 1、从sklearn包中datasets读取数据集
x_data = load_iris().data
y_data = load_iris().target# 2、数据打乱
np.random.seed(1)   # 使用相同的seed,使输入特征/标签一一对应
np.random.shuffle(x_data)
np.random.seed(1)
np.random.shuffle(y_data)
tf.random.set_seed(1)# 3、训练集、测试集划分
x_train, x_test = x_data[:-30], x_data[-30:]
y_train, y_test = y_data[:-30], y_data[-30:]# 4、小批量数据
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
train_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

3、模型的训练验证

# 定义超参数,预设变量
lr = 0.1
loss_all = 0
Epoch = 500
train_loss_list = []
test_acc = []# 定义神经网络的可训练参数
w = tf.Variable(tf.random.truncated_normal([4,3], stddev=0.1, seed=1))
b = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))# 循环迭代,训练参数
for epoch in range(Epoch):for step, (x_, y_) in enumerate(train_db):with tf.GradientTape() as tape:x_ = tf.cast(x_, tf.float32)y_pre = tf.matmul(x_, w) + by_pre = tf.nn.softmax(y_pre)y_lab = tf.one_hot(y_, depth=3)loss = tf.reduce_mean(tf.square(y_lab - y_pre))loss_all += loss.numpy()grads = tape.gradient(loss, [w,b])w.assign_sub(lr * grads[0])b.assign_sub(lr * grads[1])print(f'Epoch: {epoch}, loss: {loss_all/4}')train_loss_list.append(loss_all/4)loss_all = 0# 测试部分total_correct, total_number = 0, 0for x_,y_ in test_db:x_ = tf.cast(x_, tf.float32)y_pre = tf.matmul(x_, w) + by_pre = tf.nn.softmax(y_pre)y_p = tf.argmax(y_pre, axis=1)y_p = tf.cast(y_p, dtype=y_.dtype)correct = tf.cast(tf.equal(y_p, y_), dtype=tf.int32)correct = tf.reduce_sum(correct)total_correct += int(correct) total_number += x_.shape[0]acc = total_correct / total_numbertest_acc.append(acc)print("Test_acc:", acc)print("-"*30)

在这里插入图片描述

可视化训练过程

# 绘制测试Acc曲线和训练loss曲线
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(train_loss_list,'b-')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')ax1 = ax.twinx()
ax1.plot(test_acc,'r-')
ax1.set_ylabel('Acc')ax1.spines['left'].set_color('blue')
ax1.spines['right'].set_color('red')

在这里插入图片描述

相关文章:

  • LeetCode6. Z 字形变换(Java)
  • 基于Echarts的超市销售可视化分析系统(数据+程序+论文)
  • fastadmin学习01-windows下安装部署
  • Flink基于Hudi维表Join缺陷解析及解决方案
  • JimuReport积木报表 v1.7.4 公测版本发布,免费的JAVA报表工具
  • Vivado Lab Edition
  • LabVIEW电动汽车直流充电桩监控系统
  • 全方位保障企业远控安全,贝锐向日葵首发远程办公安全白皮书
  • day69实现MyBatis 的Mapper接口 封装SqlSession对象 mapper接口形参怎么给占位符赋值
  • Knative 助力 XTransfer 加速应用云原生 Serverless 化
  • OpenCV的图像颜色空间转换、缩放、裁剪与旋转
  • 葵花卫星影像应用场景及数据获取
  • 机器学习优化算法(深度学习)
  • AI短视频制作一本通:文本生成视频、图片生成视频、视频生成视频
  • 十一、Spring源码学习之registerListeners方法
  • @angular/forms 源码解析之双向绑定
  • [译]如何构建服务器端web组件,为何要构建?
  • 《深入 React 技术栈》
  • Apache Spark Streaming 使用实例
  • JavaScript中的对象个人分享
  • Java程序员幽默爆笑锦集
  • js面向对象
  • Map集合、散列表、红黑树介绍
  • nginx 负载服务器优化
  • node-glob通配符
  • Redis 懒删除(lazy free)简史
  • redis学习笔记(三):列表、集合、有序集合
  • WebSocket使用
  • 番外篇1:在Windows环境下安装JDK
  • 给Prometheus造假数据的方法
  • 关于Android中设置闹钟的相对比较完善的解决方案
  • 简单易用的leetcode开发测试工具(npm)
  • 如何用vue打造一个移动端音乐播放器
  • 使用 QuickBI 搭建酷炫可视化分析
  • 数据库写操作弃用“SELECT ... FOR UPDATE”解决方案
  • 小程序测试方案初探
  • 怎样选择前端框架
  • 正则与JS中的正则
  • MyCAT水平分库
  • 你学不懂C语言,是因为不懂编写C程序的7个步骤 ...
  • ​io --- 处理流的核心工具​
  • ​TypeScript都不会用,也敢说会前端?
  • $.each()与$(selector).each()
  • (2/2) 为了理解 UWP 的启动流程,我从零开始创建了一个 UWP 程序
  • (2022 CVPR) Unbiased Teacher v2
  • (6)添加vue-cookie
  • (C语言)共用体union的用法举例
  • (Demo分享)利用原生JavaScript-随机数-实现做一个烟花案例
  • (黑马C++)L06 重载与继承
  • (转)mysql使用Navicat 导出和导入数据库
  • .【机器学习】隐马尔可夫模型(Hidden Markov Model,HMM)
  • .htaccess配置常用技巧
  • .mat 文件的加载与创建 矩阵变图像? ∈ Matlab 使用笔记
  • .NET 设计模式—适配器模式(Adapter Pattern)
  • .NET命令行(CLI)常用命令