当前位置: 首页 > news >正文

TensorFlow介绍二-线性回归案例

一.案例步骤

1.准备数据集:y=0.8x+0.7  100个样本

2.建立线性模型,初始化w和b变量

3.确定损失函数(预测值与真实值之间的误差),均方误差

4.梯度下降优化损失

二.完整功能代码:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tfdef linear_regression():"""自实现线性回归:return: None"""# 构造数据X为一百行一列X = tf.random_normal(shape=(100, 1), mean=2, stddev=2)# 真实值,y=x*0.8+0.7,这里X为tf.tensor数据在乘的时候要使用二维数据y_true = tf.matmul(X, [[0.8]]) + 0.7# 使用Variable初始化w,b,因为w和b要参与更新所有要使用变量。trainable是设置这个变量是否参与训练weights = tf.Variable(initial_value=tf.random_normal(shape=(1, 1)),trainable=True)bias = tf.Variable(initial_value=tf.random_normal(shape=(1, 1)),trainable=True)# 构造预测值,使用X乘上更新后的变量w加上by_predict = tf.matmul(X, weights) + bias# 计算均方误差,用真实值减去预测值的平方,因为这是一百个数据,使用要求它的平均值error = tf.reduce_mean(tf.square(y_predict - y_true))# 构建优化器,这里使用的是梯度下降优化误差来更新w和b,0.01是学习率optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(error)# 初始化变量init = tf.global_variables_initializer()with tf.Session() as sess:  # 会话# 运行初始化变量opsess.run(init)# 打印一下初始化的权重和偏置print("随机初始化的权重为%f, 偏置为%f" % (weights.eval(), bias.eval()))# 开始训练,训练的次数越多越接近真实值for i in range(100):sess.run(optimizer)# 打印每一次更新后的权重,偏置,误差print("第%d步的误差为%f,权重为%f, 偏置为%f" % (i, error.eval(), weights.eval(), bias.eval()))return Noneif __name__ == '__main__':linear_regression()

三.增加其他功能

1.增加命名空间

使代码结构更加清晰,Tensorboard图结构更加清楚,

使用tf.variable_scope方法,里面的名字自己定义

with tf.variable_scope("lr_model"):

2.收集变量

这样更容易观察参数的更新情况 

3.写入事件

使用tensorboard观察,在命令行中切换到事件所在文件目录,使用命令:

tensorboard --logdir="事件所在的文件目录"

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tfdef linear_regression():"""自实现线性回归:return: None"""# 构造数据X为一百行一列with tf.variable_scope("original_data"):  # 表示正在创建数据X = tf.random_normal(shape=(100, 1), mean=2, stddev=2)# 真实值,y=x*0.8+0.7,这里X为tf.tensor数据在乘的时候要使用二维数据y_true = tf.matmul(X, [[0.8]]) + 0.7with tf.variable_scope("linear_model"): # 初始化变量# 使用Variable初始化w,b,因为w和b要参与更新所有要使用变量。trainable是设置这个变量是否参与训练weights = tf.Variable(initial_value=tf.random_normal(shape=(1, 1)),trainable=True)bias = tf.Variable(initial_value=tf.random_normal(shape=(1, 1)),trainable=True)# 构造预测值,使用X乘上更新后的变量w加上by_predict = tf.matmul(X, weights) + biaswith tf.variable_scope("loss"):  # 确定误差# 计算均方误差,用真实值减去预测值的平方,因为这是一百个数据,使用要求它的平均值error = tf.reduce_mean(tf.square(y_predict - y_true))with tf.variable_scope("gd_optimizer"):  # 构建优化器# 构建优化器,这里使用的是梯度下降优化误差来更新w和b,0.01是学习率optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(error)# 收集变量tf.summary.scalar("error", error)tf.summary.histogram("weights", weights)tf.summary.histogram("bias", bias)# 合并变量merge=tf.summary.merge_all()# 初始化变量init = tf.global_variables_initializer()with tf.Session() as sess:  # 会话# 运行初始化变量opsess.run(init)# 打印一下初始化的权重和偏置print("随机初始化的权重为%f, 偏置为%f" % (weights.eval(), bias.eval()))# 创建事件文件,将事件写入到ligdir中的目录中file_writer=tf.summary.FileWriter(logdir="./summary",graph=sess.graph)# 开始训练,训练的次数越多越接近真实值for i in range(100):sess.run(optimizer)# 打印每一次更新后的权重,偏置,误差print("第%d步的误差为%f,权重为%f, 偏置为%f" % (i, error.eval(), weights.eval(), bias.eval()))# 运行合并变量opsummary=sess.run(merge)file_writer.add_summary(summary,i)return Noneif __name__ == '__main__':linear_regression()

 四.模型的保存和加载

tf.train.Saver(var_list=None,max_to_keep=5)

保存和加载模型(保存文件格式:checkpoint文件)
var_list:指定将要保存和还原的变量。它可以作为一个dict或一个列表传递.
max_to_keep:指示要保留的最近检查点文件的最大数量。创建新文件时,会删除较旧的文件。如果无或0,则保留所有检查点文件。默认为5(即保留最新的5个检查点文件。)

例如

# 指定目录+模型名字
# 保存
saver.save(sess, '/tmp/ckpt/test/myregression.ckpt')
# 加载
saver.restore(sess, '/tmp/ckpt/test/myregression.ckpt')

如果判断模型是否存在,直接指定目录

checkpoint = tf.train.latest_checkpoint("./tmp/model/")saver.restore(sess, checkpoint)

五.命令行参数使用

1.tf.app.flags,它支持应用从命令行接收参数,可以用来指定集训配置等,在tf.app.flags下面各种定义参数的类型

2、 tf.app.flags.,在flags有一个FLAGS标志,它在程序中可以调用到我们

前面具体定义的flag_name

3.通过tf.app.run()启动main(argv)函数

# 定义一些常用的命令行参数
# 训练步数
tf.app.flags.DEFINE_integer("max_step", 0, "训练模型的步数")
# 定义模型的路径
tf.app.flags.DEFINE_string("model_dir", " ", "模型保存的路径+模型名字")# 定义获取命令行参数
FLAGS = tf.app.flags.FLAGS# 开启训练
# 训练的步数(依据模型大小而定)
for i in range(FLAGS.max_step):sess.run(train_op)

六.完整代码

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf# 模型保存
tf.app.flags.DEFINE_string("model_path", "./linear_regression/", "模型保存的路径和文件名")
FLAGS = tf.app.flags.FLAGSdef linear_regression():"""自实现线性回归:return: None"""# 构造数据X为一百行一列with tf.variable_scope("original_data"):  # 表示正在创建数据X = tf.random_normal(shape=(100, 1), mean=2, stddev=2)# 真实值,y=x*0.8+0.7,这里X为tf.tensor数据在乘的时候要使用二维数据y_true = tf.matmul(X, [[0.8]]) + 0.7with tf.variable_scope("linear_model"): # 初始化变量# 使用Variable初始化w,b,因为w和b要参与更新所有要使用变量。trainable是设置这个变量是否参与训练weights = tf.Variable(initial_value=tf.random_normal(shape=(1, 1)),trainable=True)bias = tf.Variable(initial_value=tf.random_normal(shape=(1, 1)),trainable=True)# 构造预测值,使用X乘上更新后的变量w加上by_predict = tf.matmul(X, weights) + biaswith tf.variable_scope("loss"):  # 确定误差# 计算均方误差,用真实值减去预测值的平方,因为这是一百个数据,使用要求它的平均值error = tf.reduce_mean(tf.square(y_predict - y_true))with tf.variable_scope("gd_optimizer"):  # 构建优化器# 构建优化器,这里使用的是梯度下降优化误差来更新w和b,0.01是学习率optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(error)# 收集变量tf.summary.scalar("error", error)tf.summary.histogram("weights", weights)tf.summary.histogram("bias", bias)# 合并变量merge=tf.summary.merge_all()# 初始化变量init = tf.global_variables_initializer()with tf.Session() as sess:  # 会话# 运行初始化变量opsess.run(init)# 打印一下初始化的权重和偏置print("随机初始化的权重为%f, 偏置为%f" % (weights.eval(), bias.eval()))# 创建事件文件,将事件写入到ligdir中的目录中file_writer=tf.summary.FileWriter(logdir="./summary",graph=sess.graph)# 开始训练,训练的次数越多越接近真实值for i in range(100):sess.run(optimizer)# 打印每一次更新后的权重,偏置,误差print("第%d步的误差为%f,权重为%f, 偏置为%f" % (i, error.eval(), weights.eval(), bias.eval()))# 运行合并变量opsummary=sess.run(merge)file_writer.add_summary(summary,i)return Nonedef main(argv):print("这是main函数")print(argv)print(FLAGS.model_path)linear_regression()if __name__ == '__main__':tf.app.run()

都看到这里了,点个赞呗!!!!!

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 【鸿蒙HarmonyOS NEXT】List组件的使用
  • Spring 源码解读:实现Spring容器的启动流程
  • SAP B1 三大基本表单标准功能介绍-物料主数据(下)
  • 嵌入式软件开发学习三:中断
  • 【从问题中去学习k8s】k8s中的常见面试题(夯实理论基础)(二十一)
  • VMware安装windows虚拟机详细过程
  • HTTP 之 Web Sockets 安全策略(十)
  • 大数据-114 Flink DataStreamAPI 程序输入源 自定义输入源 Rich并行源 RichParallelSourceFunction
  • 国际化产品经理的挑战与机遇:跨文化产品管理的探索
  • 大数据新视界--大数据大厂之MySQL 数据库课程设计:数据安全深度剖析与未来展望
  • CentOS全面停服,国产化提速,央国企信创即时通讯/协同门户如何选型?
  • 开源模型应用落地-LangChain高阶-记忆组件-ConversationTokenBufferMemory正确使用(七)
  • 深度学习-OpenCv的运用(4)
  • 群论 (笔记)
  • uniapp常用标签
  • ES6指北【2】—— 箭头函数
  • JavaScript-如何实现克隆(clone)函数
  • CentOS从零开始部署Nodejs项目
  • Java,console输出实时的转向GUI textbox
  • JavaScript-Array类型
  • java第三方包学习之lombok
  • JS实现简单的MVC模式开发小游戏
  • oschina
  • Python进阶细节
  • Three.js 再探 - 写一个跳一跳极简版游戏
  • Vue.js 移动端适配之 vw 解决方案
  • Vue官网教程学习过程中值得记录的一些事情
  • windows下mongoDB的环境配置
  • 笨办法学C 练习34:动态数组
  • 等保2.0 | 几维安全发布等保检测、等保加固专版 加速企业等保合规
  • 第三十一到第三十三天:我是精明的小卖家(一)
  • 记一次用 NodeJs 实现模拟登录的思路
  • 前端js -- this指向总结。
  • 通过几道题目学习二叉搜索树
  • 【干货分享】dos命令大全
  • MyCAT水平分库
  • raise 与 raise ... from 的区别
  • 进程与线程(三)——进程/线程间通信
  • ​​​​​​​sokit v1.3抓手机应用socket数据包: Socket是传输控制层协议,WebSocket是应用层协议。
  • ​Kaggle X光肺炎检测比赛第二名方案解析 | CVPR 2020 Workshop
  • # wps必须要登录激活才能使用吗?
  • #mysql 8.0 踩坑日记
  • #在 README.md 中生成项目目录结构
  • (9)STL算法之逆转旋转
  • (SpringBoot)第七章:SpringBoot日志文件
  • (webRTC、RecordRTC):navigator.mediaDevices undefined
  • (层次遍历)104. 二叉树的最大深度
  • (二)Eureka服务搭建,服务注册,服务发现
  • (官网安装) 基于CentOS 7安装MangoDB和MangoDB Shell
  • (转)程序员技术练级攻略
  • *(长期更新)软考网络工程师学习笔记——Section 22 无线局域网
  • .NET/C# 在代码中测量代码执行耗时的建议(比较系统性能计数器和系统时间)...
  • .NET与 java通用的3DES加密解密方法
  • .NET正则基础之——正则委托
  • .sh 的运行