当前位置: 首页 > news >正文

【Tensorflow1.0+】记录常用函数

tf.train.ExponentialMovingAverage(decay, steps)

tf.train.ExponentialMovingAverage这个函数用于更新参数,就是采用滑动平均的方法更新参数。这个函数初始化需要提供一个衰减速率(decay),用于控制模型的更新速度。这个函数还会维护一个影子变量(也就是更新参数后的参数值),这个影子变量的初始值就是这个变量的初始值,影子变量值的更新方式如下:

shadow_variable = decay * shadow_variable + (1-decay) * variable

shadow_variable是影子变量,variable表示待更新的变量,也就是变量被赋予的值,decay为衰减速率。decay一般设为接近于1的数(0.99,0.999)。decay越大模型越稳定,因为decay越大,参数更新的速度就越慢,趋于稳定。

tf.train.ExponentialMovingAverage这个函数还提供了自己动更新decay的计算方式:

decay= min(decay,(1+steps)/(10+steps))

steps是迭代的次数,可以自己设定。

比如:

[python] view plain copy
  1. import tensorflow as tf;  
  2. import numpy as np;  
  3. import matplotlib.pyplot as plt;  
  4.   
  5. v1 = tf.Variable(0, dtype=tf.float32)  
  6. step = tf.Variable(tf.constant(0))  
  7.   
  8. ema = tf.train.ExponentialMovingAverage(0.99, step)  
  9. maintain_average = ema.apply([v1])  
  10.   
  11. with tf.Session() as sess:  
  12.     init = tf.initialize_all_variables()  
  13.     sess.run(init)  
  14.   
  15.     print sess.run([v1, ema.average(v1)]) #初始的值都为0  
  16.   
  17.     sess.run(tf.assign(v1, 5)) #把v1变为5  
  18.     sess.run(maintain_average)  
  19.     print sess.run([v1, ema.average(v1)]) # decay=min(0.99, 1/10)=0.1, v1=0.1*0+0.9*5=4.5  
  20.   
  21.     sess.run(tf.assign(step, 10000)) # steps=10000  
  22.     sess.run(tf.assign(v1, 10)) # v1=10  
  23.     sess.run(maintain_average)  
  24.     print sess.run([v1, ema.average(v1)]) # decay=min(0.99,(1+10000)/(10+10000))=0.99, v1=0.99*4.5+0.01*10=4.555  
  25.   
  26.     sess.run(maintain_average)  
  27.     print sess.run([v1, ema.average(v1)]) #decay=min(0.99,<span style="font-family:Arial, Helvetica, sans-serif;">(1+10000)/(10+10000)</span><span style="font-family:Arial, Helvetica, sans-serif;">)=0.99, v1=0.99*4.555+0.01*10=4.6</span>  
输出:

[0.0, 0.0]
[5.0, 4.5]
[10.0, 4.5549998]
[10.0, 4.6094499]

解释:每次更新完以后,影子变量的值更新,varible的值就是你设定的值。如果在下一次运行这个函数的时候你不在指定新的值,那就不变,影子变量更新。如果指定,那就variable改变,影子变量也改变。

tf.trainable_variables:返回的是需要训练的变量列表

tf.all_variables:返回的是所有变量的列表

例如:

[python] view plain copy
  1. import tensorflow as tf;    
  2. import numpy as np;    
  3. import matplotlib.pyplot as plt;    
  4.   
  5. v = tf.Variable(tf.constant(0.0, shape=[1], dtype=tf.float32), name='v')  
  6. v1 = tf.Variable(tf.constant(5, shape=[1], dtype=tf.float32), name='v1')  
  7.   
  8. global_step = tf.Variable(tf.constant(5, shape=[1], dtype=tf.float32), name='global_step', trainable=False)  
  9. ema = tf.train.ExponentialMovingAverage(0.99, global_step)  
  10.   
  11. for ele1 in tf.trainable_variables():  
  12.     print ele1.name  
  13. for ele2 in tf.all_variables():  
  14.     print ele2.name  
输出:

v:0
v1:0

v:0
v1:0
global_step:0

分析:上面得到两个变量,后面的一个得到上三个变量,因为 global_step在声明的时候说明不是训练变量,用来关键字 trainable=False。

tf.control_dependencies()函数用法:

在有些机器学习程序中我们想要指定某些操作执行的依赖关系,这时我们可以使用tf.control_dependencies()来实现。
control_dependencies(control_inputs)返回一个控制依赖的上下文管理器,使用with关键字可以让在这个上下文环境中的操作都在control_inputs 执行。

with g.control_dependencies([a, b, c]):
  # `d` and `e` will only run after `a`, `b`, and `c` have executed.
  d = ...
  e = ...
  • 1
  • 2
  • 3
  • 4

可以嵌套control_dependencies 使用

with g.control_dependencies([a, b]):
  # Ops constructed here run after `a` and `b`.
  with g.control_dependencies([c, d]):
    # Ops constructed here run after `a`, `b`, `c`, and `d`.
  • 1
  • 2
  • 3
  • 4

可以传入None 来消除依赖:

with g.control_dependencies([a, b]):
  # Ops constructed here run after `a` and `b`.
  with g.control_dependencies(None):
    # Ops constructed here run normally, not waiting for either `a` or `b`.
    with g.control_dependencies([c, d]):
      # Ops constructed here run after `c` and `d`, also not waiting
      # for either `a` or `b`.
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

注意
控制依赖只对那些在上下文环境中建立的操作有效,仅仅在context中使用一个操作或张量是没用的

# WRONG
def my_func(pred, tensor):
  t = tf.matmul(tensor, tensor)
  with tf.control_dependencies([pred]):
    # The matmul op is created outside the context, so no control
    # dependency will be added.
    return t

# RIGHT
def my_func(pred, tensor):
  with tf.control_dependencies([pred]):
    # The matmul op is created in the context, so a control dependency
    # will be added.
    return tf.matmul(tensor, tensor)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

例子:
在训练模型时我们每步训练可能要执行两种操作,op a, b 这时我们就可以使用如下代码:

with tf.control_dependencies([a, b]):
    c= tf.no_op(name='train')#tf.no_op;什么也不做
sess.run(c)
  • 1
  • 2
  • 3

在这样简单的要求下,可以将上面代码替换为:

c= tf.group([a, b])
sess.run(c)

相关文章:

  • 【零基础入门学习Python笔记017】GUI的最终选择:Tkinter
  • 用Yolov2模型训练VOC数据集的各文件理解
  • python 中easydict的简单使用
  • Numpy np.array 相关常用操作学习笔记
  • Matconvnet关于simplenn 转dagnn的一些小的总结
  • 海思AI芯片(Hi3519A/3559A)方案学习(一)资料以及术语介绍
  • 海思AI芯片(Hi3519A/3559A)方案学习(二)RuyiStudio安装
  • 海思AI芯片(Hi3519A/3559A)方案学习(三)Ubuntu18.0.4上编译Hi3519AV100 uboot和kernel
  • 海思AI芯片(Hi3519A/3559A)方案学习(四)如何在3519A板子上运行sample code
  • 海思AI芯片(Hi3519A/3559A)方案学习(五)SDK平台文档梳理
  • Atlas 200 DK 系列 -- 快速搭建开发环境
  • Atlas 200 DK 系列--初级篇--MindStudio常见操作
  • Linux/Ubuntu下解压命令
  • ubuntu终端命令的几个常用重要命令
  • 海思AI芯片(35xx):window仿真代码需修改部分
  • CentOS 7 修改主机名
  • create-react-app做的留言板
  • Fastjson的基本使用方法大全
  • React-生命周期杂记
  • Vue 2.3、2.4 知识点小结
  • vue 配置sass、scss全局变量
  • Webpack4 学习笔记 - 01:webpack的安装和简单配置
  • win10下安装mysql5.7
  • 高程读书笔记 第六章 面向对象程序设计
  • 记一次删除Git记录中的大文件的过程
  • 前端存储 - localStorage
  • 学习Vue.js的五个小例子
  • 移动端唤起键盘时取消position:fixed定位
  • 《码出高效》学习笔记与书中错误记录
  • ​业务双活的数据切换思路设计(下)
  • (k8s中)docker netty OOM问题记录
  • (pojstep1.1.1)poj 1298(直叙式模拟)
  • (经验分享)作为一名普通本科计算机专业学生,我大学四年到底走了多少弯路
  • .NET Core实战项目之CMS 第十二章 开发篇-Dapper封装CURD及仓储代码生成器实现
  • .NET Core中Emit的使用
  • .Net Framework 4.x 程序到底运行在哪个 CLR 版本之上
  • .NET 除了用 Task 之外,如何自己写一个可以 await 的对象?
  • .NET 使用 XPath 来读写 XML 文件
  • .NET/C# 利用 Walterlv.WeakEvents 高性能地中转一个自定义的弱事件(可让任意 CLR 事件成为弱事件)
  • .net和php怎么连接,php和apache之间如何连接
  • .NET框架
  • /proc/vmstat 详解
  • @PreAuthorize注解
  • @开发者,一文搞懂什么是 C# 计时器!
  • [ C++ ] STL_vector -- 迭代器失效问题
  • [20150321]索引空块的问题.txt
  • [AIGC codze] Kafka 的 rebalance 机制
  • [android] 手机卫士黑名单功能(ListView优化)
  • [Bada开发]初步入口函数介绍
  • [JavaWeb]—Spring入门
  • [Python从零到壹] 六十三.图像识别及经典案例篇之图像漫水填充分割应用
  • [pytorch] 2. tensorboard
  • [raspberry pi3] zram设置
  • [Redis]——数据一致性,先操作数据库,还是先更新缓存?
  • [RK-Linux] 移植Linux-5.10到RK3399(四)| 检查HDMI配置与打开内核LOGO显示