当前位置: 首页 > news >正文

深度学习优化算法之动量法[公式推导](MXNet)

我们在前面的文章中熟悉了梯度下降的各种形态,深度学习优化算法之(小批量)随机梯度下降(MXNet),也了解了梯度下降的原理,由每次的迭代,梯度下降都根据自变量的当前位置来更新自变量,做自我迭代。但是如果说自变量的迭代方向只是取决于自变量的当前位置的话,这可能会带来一些问题。比如我们来看下函数 y=x_{1}^{2}+2*x_{2}^{2} 的走势,现在我们来看下这个函数其系数为0.1的情况y=0.1*x_{1}^{2}+2*x_{2}^{2} 在学习率变化时,将会发生什么变化。

从eta=0.4一个比较合适的学习率开始:

import d2lzh as d2l
from mxnet import nd

eta=0.4

def f_2d(x1,x2):
    return 0.1*x1**2 + 2*x2**2

def gd_2d(x1,x2,s1,s2):
    return (x1-eta*0.2*x1,x2-eta*4*x2,0,0)

d2l.show_trace_2d(f_2d,d2l.train_2d(gd_2d))
    
#epoch 20, x1 -0.943467, x2 -0.000073

 

图中可以看出,同一个位置上,目标函数在竖直方向(x2轴方向)比在水平方向(x1轴方向)的斜率的绝对值更大,换句话说就是自变量的更新会使自变量在竖直方向比在水平方向移动幅度更大。
我们将学习率调大一点:eta=0.6

我们发现自变量在竖直方向不断越过最优解并逐渐发散了。

动量法

那上面这个问题,我们通过动量法来处理,在前面的文章也有介绍,这里算是一种新的学习与巩固,更重要的是了解为什么动量法能够处理这种上下方向的偏幅。
那很明显上面存在的问题就是自变量在竖直方向的更新不一致,时正时负,找到问题所在之后,那我们就只需要解决这个方向一致的问题就好办了。

对于动量法的推导,我们从指数加权移动平均(Exponentially Weighted  Moving Average)来理解它,还是画图来直观看下其推导过程:

 

然后我们通过代码来看下实际情况:

eta,gamma=0.4,0.5

def f_2d(x1,x2):
    return 0.1*x1**2 + 2*x2**2

#当gamma=0时,就是小批量随机梯度下降
def momentum_gd_2d(x1,x2,v1,v2):
    v1=gamma*v1 + eta*0.2*x1
    v2=gamma*v2 + eta*4*x2
    return x1-v1,x2-v2,v1,v2

d2l.show_trace_2d(f_2d,d2l.train_2d(momentum_gd_2d))

#epoch 20, x1 -0.062843, x2 0.001202

 

图中可以看出使用动量法之后在竖直方向上的移动更加平滑了,而且在水平方向也更快逼近最优解。
然后将学习率调大到0.6,也没有出现发散的情况。

飞机机翼噪音测试

import d2lzh as d2l
from mxnet import nd

#使用飞机噪音数据集来测试
#https://download.csdn.net/download/weixin_41896770/86513479
features,labels=d2l.get_data_ch7()#1503x5,1503

#速度变量用更广义的状态变量states表示
def init_momentum_states():
    v_w=nd.zeros((features.shape[1],1))
    v_b=nd.zeros(1)
    return (v_w,v_b)

def sgd_momentum(params,states,hyperparams):
    for p,v in zip(params,states):
        v[:]=hyperparams['momentum']*v +hyperparams['lr']*p.grad
        p[:]-=v

d2l.train_ch7(sgd_momentum,init_momentum_states(),{'lr':0.02,'momentum':0.5},features,labels)
#loss: 0.249161, 0.171031 sec per epoch

#看做特殊的小批量随机梯度下降
#最近2个时间步的2倍小批量梯度的加权平均
#d2l.train_ch7(sgd_momentum,init_momentum_states(),{'lr':0.02,'momentum':0.5},features,labels)
#最近10个时间步的10倍小批量梯度的加权平均,1/(1-0.9)
d2l.train_ch7(sgd_momentum,init_momentum_states(),{'lr':0.02,'momentum':0.9},features,labels)
loss: 0.259894, 0.177999 sec per epoch

图中可以看出后期的迭代不够平滑,因为10倍小批量梯度比2倍小批量梯度大了5倍,我们将学习率调小5倍试下:

#学习率调下5倍,从0.02到0.004
d2l.train_ch7(sgd_momentum,init_momentum_states(),{'lr':0.004,'momentum':0.9},features,labels)
#loss: 0.243785, 0.181000 sec per epoch

#简洁实现
d2l.train_gluon_ch7('sgd',{'learning_rate':0.004,'momentum':0.9},features,labels)

动量法的出现主要是解决相邻时间步的自变量的在更新方向上的问题,使得它们更加趋向一致,因为它将过去时间步的梯度做了加权平均,而不仅仅是关注当前变量梯度的位置。

相关文章:

  • tomcat面试和Spring的面试题
  • 网课查题公众号接口
  • 基于Hive的搜狗搜索日志与结果Python可视化设计
  • vue+echarts项目四:地区销量趋势(堆叠折线图)
  • SpringSecurity实战-第6-8章
  • 前端 基础知识
  • 【极客时间2】左耳听风
  • 六级高频词汇——Group07
  • C++类和对象详解(中篇)
  • java五位随机验证码的实现。要求前四位是随机大小写的字母,最后一位是数字的组合。例如qWrY4
  • 《关于我摸鱼一天后搞定PyCharm这档事》Python环境配置
  • 公众号网课搜题系统
  • 【C++ Primer Plus】第13章 类继承
  • 中国网络安全专家专业技能水平
  • tomcat-8.5.55 cluster配置session共享实现不停机部署
  • AngularJS指令开发(1)——参数详解
  • CentOS 7 修改主机名
  • Docker 笔记(2):Dockerfile
  • iOS高仿微信项目、阴影圆角渐变色效果、卡片动画、波浪动画、路由框架等源码...
  • java小心机(3)| 浅析finalize()
  • Laravel Telescope:优雅的应用调试工具
  • Node + FFmpeg 实现Canvas动画导出视频
  • vue-router 实现分析
  • Vue--数据传输
  • 阿里研究院入选中国企业智库系统影响力榜
  • 第2章 网络文档
  • 简单实现一个textarea自适应高度
  • 宾利慕尚创始人典藏版国内首秀,2025年前实现全系车型电动化 | 2019上海车展 ...
  • ​Java并发新构件之Exchanger
  • ​水经微图Web1.5.0版即将上线
  • #vue3 实现前端下载excel文件模板功能
  • ( 用例图)定义了系统的功能需求,它是从系统的外部看系统功能,并不描述系统内部对功能的具体实现
  • (4)事件处理——(7)简单事件(Simple events)
  • (附源码)spring boot北京冬奥会志愿者报名系统 毕业设计 150947
  • (六)c52学习之旅-独立按键
  • (三)uboot源码分析
  • (算法设计与分析)第一章算法概述-习题
  • (转)3D模板阴影原理
  • (转)EXC_BREAKPOINT僵尸错误
  • (转载)OpenStack Hacker养成指南
  • (总结)Linux下的暴力密码在线破解工具Hydra详解
  • *p++,*(p++),*++p,(*p)++区别?
  • . ./ bash dash source 这五种执行shell脚本方式 区别
  • .cn根服务器被攻击之后
  • .net CHARTING图表控件下载地址
  • .pings勒索病毒的威胁:如何应对.pings勒索病毒的突袭?
  • /etc/skel 目录作用
  • [Android]使用Git将项目提交到GitHub
  • [APIO2012] 派遣 dispatching
  • [AutoSAR系列] 1.3 AutoSar 架构
  • [BUUCTF 2018]Online Tool
  • [C]编译和预处理详解
  • [codeforces]Levko and Permutation
  • [flume$2]记录一个写自定义Flume拦截器遇到的错误
  • [Flutter]WindowsPlatform上运行遇到的问题总结