当前位置: 首页 > news >正文

8.1 mnist_soft,TensorFlow构建回归模型

背景

之前已经写了很多TensorFlow的基本知识,现在利用TensorFlow实现一些简单的功能,对原来的知识进行串联,并初步入门,该部分共包括三篇,分别实现的是回归模型,浅层神经网络,KNN。

TensorFlow构建回归模型

本代码的构建步骤

  1. 建立公式的计算图
  2. 损失函数与优化器
  3. 加载数据
  4. 启动会话,训练与测试

建立计算图

在TensorFlow中构建模型,我们首先需要实现的一个计算图,然后再在Session中运行图,并加载数据,那么首先计算图。

公式到计算图的转化

首先假如,我们有公式 e = (a+b) * (b +1)那么我们就可以将其拆解成五个节点

1. a节点,输入节点
2. b节点,输入节点
3. a+b 节点,计算节点,命名为c
4. b+1 节点,计算节点,命名为d
5. e = c * d 计算节点,输出节点,节点e

如图表示就是
计算图

回归模型

同理logits :y = wx+b可以转化为


1.x 输入节点
2.w 权重
3.b 偏执
4. y0 = xw 计算节点
5. y = y0 + b 计算节点,输出节点

回归模型的计算图

如图,这就是我们要实现的计算图,但是在实际的使用过程中却还有两点不同,
1. 第一我们实现模型实际上已经向量化好了的,这是机器学习的基础,这里不再重复,你可以去网易云课堂学习吴恩达教授的深度学习课程,里面有不错的介绍。
2. 在TensorFlow中实现该计算图时,是直接一行代码实现的,并没有再构建y0,实际无影响。详细情况请看代码:

# 定义变量
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# 公式
y = tf.matmul(x, W) + b

注意:后面我们使用的时交叉熵回归分类器

损失函数与优化算法

损失函数,我们这里使用的是平均值(平均的是交叉熵分类器的损失)
学习率,设定的为0.5
优化算法,使用的随机梯度下降

# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10])

cross_entropy = tf.reduce_mean(
  tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

因为TensorFlow已经实现了自动求导,所以我们就不需要想之前用其他Python类库写机器学习代码那样在自己编写反向求导了(老铁,这波稳如狗)

数据的加载

我们这次使用的时mnist的手写数字数据,你也可以使用其他数据来测试这个回归模型,但是注意修改之前的Tensor 的shape

mnist = input_data.read_data_sets("/home/fonttian/CODE/TensoFlow/Data/MNIST_data", one_hot=True)
# 训练数据
batch_xs, batch_ys = mnist.train.next_batch(100)

最后初始化定义session,初始化所有节点,开始训练和测试吧

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# Test trained model
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    if (i+1) % 100 == 0:
        print(i+1,":",sess.run(accuracy, feed_dict={x: mnist.test.images,y_: mnist.test.labels}))

全部代码

import tensorflow as tf
import os

from tensorflow.examples.tutorials.mnist import input_data

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Import data
mnist = input_data.read_data_sets("/home/fonttian/CODE/TensoFlow/Data/MNIST_data", one_hot=True)

# Create the model
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b

# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10])

cross_entropy = tf.reduce_mean(
  tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
# Train
for i in range(10000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Test trained model
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    if (i+1) % 100 == 0:
        print(i+1,":",sess.run(accuracy, feed_dict={x: mnist.test.images,y_: mnist.test.labels}))

参考

【1】https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist_softmax.py

转载于:https://www.cnblogs.com/fonttian/p/9162784.html

相关文章:

  • MQTT服务器搭建--Mosquitto用户名密码配置
  • Kyligence Analytics Platform Enterprise
  • 【转】VS2010/MFC编程入门之二十(常用控件:静态文本框)
  • Shiro:ajax的session超时处理
  • cogs2223 [SDOI2016 Round1] 生成魔咒
  • Sql 时间做条件
  • SQL Server 数据库中的几个常见的临界值
  • A Research Problem UVA - 10837 欧拉函数逆应用
  • 洛谷P2344 奶牛抗议
  • python归档:笔记转化
  • 理解JS中的call、apply、bind方法
  • Number Math
  • 初学JAVA的变量作用域
  • Inno Setup自定义安装界面脚本
  • Spring AOP简单的配置(注解和xml配置)
  • JS中 map, filter, some, every, forEach, for in, for of 用法总结
  • Centos6.8 使用rpm安装mysql5.7
  • java8 Stream Pipelines 浅析
  • JavaScript标准库系列——Math对象和Date对象(二)
  • JAVA多线程机制解析-volatilesynchronized
  • JDK9: 集成 Jshell 和 Maven 项目.
  • JS学习笔记——闭包
  • k个最大的数及变种小结
  • orm2 中文文档 3.1 模型属性
  • PHP 7 修改了什么呢 -- 2
  • Terraform入门 - 1. 安装Terraform
  • Webpack入门之遇到的那些坑,系列示例Demo
  • 关于字符编码你应该知道的事情
  • 前端js -- this指向总结。
  • 前嗅ForeSpider教程:创建模板
  • 微信开源mars源码分析1—上层samples分析
  • ​LeetCode解法汇总2182. 构造限制重复的字符串
  • ​LeetCode解法汇总2583. 二叉树中的第 K 大层和
  • # .NET Framework中使用命名管道进行进程间通信
  • #LLM入门|Prompt#2.3_对查询任务进行分类|意图分析_Classification
  • #周末课堂# 【Linux + JVM + Mysql高级性能优化班】(火热报名中~~~)
  • (1)(1.13) SiK无线电高级配置(五)
  • (10)工业界推荐系统-小红书推荐场景及内部实践【排序模型的特征】
  • (2)(2.10) LTM telemetry
  • (2015)JS ES6 必知的十个 特性
  • (三维重建学习)已有位姿放入colmap和3D Gaussian Splatting训练
  • (四)模仿学习-完成后台管理页面查询
  • (一)使用Mybatis实现在student数据库中插入一个学生信息
  • (转)菜鸟学数据库(三)——存储过程
  • *ST京蓝入股力合节能 着力绿色智慧城市服务
  • .net 简单实现MD5
  • /var/log/cvslog 太大
  • @基于大模型的旅游路线推荐方案
  • [2017][note]基于空间交叉相位调制的两个连续波在few layer铋Bi中的全光switch——
  • [C#小技巧]如何捕捉上升沿和下降沿
  • [ERROR]-Error: failure: repodata/filelists.xml.gz from addons: [Errno 256] No more mirrors to try.
  • [Fri 26 Jun 2015 ~ Thu 2 Jul 2015] Deep Learning in arxiv
  • [IE6 only]关于Flash/Flex,返回数据产生流错误Error #2032的解决方式
  • [JS]Math.random()随机数的二三事
  • [LeetCode]—Anagrams 回文构词法