当前位置: 首页 > news >正文

神经网络优化算法如何选择Adam,SGD

之前在tensorflow上和caffe上都折腾过CNN用来做视频处理,在学习tensorflow例子的时候代码里面给的优化方案默认很多情况下都是直接用的AdamOptimizer优化算法,如下:

optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(cost)
  • 1

但是在使用caffe时solver里面一般都用的SGD+momentum,如下:

base_lr: 0.0001
momentum: 0.9
weight_decay: 0.0005 lr_policy: "step"
  • 1
  • 2
  • 3
  • 4

加上最近看了一篇文章:The Marginal Value of Adaptive Gradient Methods 
in Machine Learning文章链接,文中也探讨了在自适应优化算法:AdaGrad, RMSProp, and Adam和SGD算法性能之间的比较和选择,因此在此搬一下结论和感想。

Abstract

经过本文的实验,得出最重要的结论是:

We observe that the solutions found by adaptive methods generalize worse (often significantly worse) than SGD, even when these solutions have better training performance. These
results suggest that practitioners should reconsider the use of adaptive methods to train neural networks
  • 1
  • 2
  • 3

翻译一下就是自适应优化算法通常都会得到比SGD算法性能更差(经常是差很多)的结果,尽管自适应优化算法在训练时会表现的比较好,因此使用者在使用自适应优化算法时需要慎重考虑!(终于知道为啥CVPR的paper全都用的SGD了,而不是用理论上最diao的Adam)

Introduction

作者继续给了干货结论: 
Our experiments reveal three primary findings.

First,
with the same amount of hyperparameter tuning, SGD and SGD with momentum outperform adaptive methods on the development/test set across all evaluated models and tasks. This is true even when the adaptive methods achieve the same training loss or lower than non-adaptive methods. Second, adaptive methods often display faster initial progress on the training set, but their performance quickly plateaus on the development/test set. Third, the same amount of tuning was required for all methods, including adaptive methods. This challenges the conventional wisdom that adaptive methods require less tuning. 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

翻译: 
1:用相同数量的超参数来调参,SGD和SGD +momentum 方法性能在测试集上的额误差好于所有的自适应优化算法,尽管有时自适应优化算法在训练集上的loss更小,但是他们在测试集上的loss却依然比SGD方法高, 
2:自适应优化算法 在训练前期阶段在训练集上收敛的更快,但是在测试集上这种有点遇到了瓶颈。 
3:所有方法需要的迭代次数相同,这就和约定俗成的默认自适应优化算法 需要更少的迭代次数的结论相悖!

Conclusion

贴几张作者做的实验结果图: 
这里写图片描述

可以看到SGD在训练前期loss下降并不是最快的,但是在test set上的Perplexity 困惑度(这里写链接内容)是最小的。

在tensorflow中使用SGD算法:(参考)

    # global_step
    training_iters=len(data_config['train_label'])
    global_step=training_iters*model_config['n_epoch']
    decay_steps=training_iters*1
    #global_step = tf.Variable(0, name = 'global_step', trainable=False) lr=tf.train.exponential_decay(learning_rate=model_config['learning_rate'], global_step=global_step, decay_steps=decay_steps, decay_rate=0.1, staircase=False, name=None) optimizer=tf.train.GradientDescentOptimizer(lr).minimize(cost,var_list=network.all_params)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u014381600/article/details/72867109

转载于:https://www.cnblogs.com/Ph-one/p/9400333.html

相关文章:

  • tf.nn.relu
  • tf.nn.max_pool
  • 【TensorFlow】tf.nn.max_pool实现池化操作
  • git博客好的例子
  • 桌面版Ubuntu系统固定IP设置和Network-manager设置
  • ubuntu----VMware 鼠标自由切换问题及主机虚拟机共享剪切板问题
  • markdownpad2-注册码-2017-02-23
  • zynq基础--linux下软件应用
  • tftp 传输文件
  • TensorFlow模型保存和加载方法
  • OpenCV常用库函数[典]
  • https://blog.csdn.net/dayancn/article/details/54692111
  • 用C++调用tensorflow在python下训练好的模型(centos7)
  • 如何用Tensorflow训练模型成pb文件和和如何加载已经训练好的模型文件
  • TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式
  • axios请求、和返回数据拦截,统一请求报错提示_012
  • CSS 专业技巧
  • Java的Interrupt与线程中断
  • Laravel Mix运行时关于es2015报错解决方案
  • PHP 的 SAPI 是个什么东西
  • SpiderData 2019年2月16日 DApp数据排行榜
  • Webpack入门之遇到的那些坑,系列示例Demo
  • 阿里中间件开源组件:Sentinel 0.2.0正式发布
  • 初识 beanstalkd
  • 基于 Babel 的 npm 包最小化设置
  • 基于Android乐音识别(2)
  • 写给高年级小学生看的《Bash 指南》
  • 一道面试题引发的“血案”
  • raise 与 raise ... from 的区别
  • 阿里云移动端播放器高级功能介绍
  • 哈罗单车融资几十亿元,蚂蚁金服与春华资本加持 ...
  • 扩展资源服务器解决oauth2 性能瓶颈
  • ​VRRP 虚拟路由冗余协议(华为)
  • ​插件化DPI在商用WIFI中的价值
  • (2.2w字)前端单元测试之Jest详解篇
  • (9)YOLO-Pose:使用对象关键点相似性损失增强多人姿态估计的增强版YOLO
  • (delphi11最新学习资料) Object Pascal 学习笔记---第8章第5节(封闭类和Final方法)
  • (echarts)echarts使用时重新加载数据之前的数据存留在图上的问题
  • (超详细)2-YOLOV5改进-添加SimAM注意力机制
  • (二)构建dubbo分布式平台-平台功能导图
  • (附源码)ssm考试题库管理系统 毕业设计 069043
  • (附源码)基于SSM多源异构数据关联技术构建智能校园-计算机毕设 64366
  • (附源码)计算机毕业设计ssm本地美食推荐平台
  • (机器学习的矩阵)(向量、矩阵与多元线性回归)
  • (简单有案例)前端实现主题切换、动态换肤的两种简单方式
  • (转)Scala的“=”符号简介
  • (转载)hibernate缓存
  • .bat批处理(一):@echo off
  • .NET WebClient 类下载部分文件会错误?可能是解压缩的锅
  • .NET 依赖注入和配置系统
  • .w文件怎么转成html文件,使用pandoc进行Word与Markdown文件转化
  • @JSONField或@JsonProperty注解使用
  • @select 怎么写存储过程_你知道select语句和update语句分别是怎么执行的吗?
  • @TableLogic注解说明,以及对增删改查的影响
  • [ Algorithm ] N次方算法 N Square 动态规划解决