当前位置: 首页 > news >正文

模型优化学习笔记—动量梯度下降

一、mini-batch

梯度下降(gradient descent):

SGD(stochastic GD)随机梯度下降:对一个样本做梯度下降

batch梯度下降:使用所有样本做梯度下降(做一次又叫epoch)

mini-batch梯度下降:用子训练集做梯度下降

epoch:对整个训练集做了一次梯度下降

iteration:做了一次梯度下降

batch梯度下降、随机梯度下降、mini-batch梯度下降:这3个梯度下降的区别仅仅在于它们每次学习的样本数量不同。 无论是哪种梯度下降,学习率都是必须要精心调的。 通常来说,如果数据集很大,那么mini-batch梯度下降会比另外2种要高效。

mini-batch生成步骤(X,Y同步进行):
1、洗牌:随机调换样本顺序
2、分割:根据mini-batch-size切割

其中一列关于numpy分割的示例:

def func_test():# 4个样本,两个特征: 两行4列arr = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])print(arr)print("-----样本顺序调整-----")print(arr[:, [1, 0, 2]])print("-----生成随机组合-----")permutation = list(np.random.permutation(3))print(permutation)print("-----样本顺序根据随机调整-----")print(arr[:, permutation])print("-----样本批量抽取-----")print(arr[:, 1:3])
[[1 2 3 4][5 6 7 8]]
-----样本顺序调整-----
[[2 1 3][6 5 7]]
-----生成随机组合-----
[1, 0, 2]
-----样本顺序根据随机调整-----
[[2 1 3][6 5 7]]
-----样本批量抽取-----
[[2 3][6 7]]
def random_mini_batches(X, Y, mini_batch_size=64, seed=0):np.random.seed(seed)# 样本数量m = X.shape[1]mini_batches = []# 第一步:洗牌permutation = list(np.random.permutation(m))  # 随机生成m内的整数,例如m=5,则生成 [2,4,1,3,0]shuffled_X = X[:,permutation]shuffled_Y = Y[:,permutation]# 第二步:分割num_complete_mini_batches = math.floor(m / mini_batch_size)  # 向下取整for k in range(0, num_complete_mini_batches):start_index = mini_batch_size * kend_index = mini_batch_size * (k + 1)mini_batch_X = shuffled_X[:, start_index:end_index]mini_batch_Y = shuffled_Y[:, start_index:end_index]mini_batch = (mini_batch_X, mini_batch_Y)mini_batches.append(mini_batch)if m % mini_batch_size != 0:# 最后剩余的不足mini_batch_size的样本mini_batch_X = shuffled_X[:, num_complete_mini_batches * mini_batch_size:]mini_batch_Y = shuffled_Y[:, num_complete_mini_batches * mini_batch_size:]mini_batches.append((mini_batch_X, mini_batch_Y))return mini_batches
二、如何为mini-batch选择合理的batch size

batch size 对网络的影响:

1、没有batch size(全训练集),梯度准确,只适用于小样本的数据

2、batch size = 1,随机梯度下降,梯度变来变去,非常不准确,网络很难收敛

3、batch size增大,梯度变准确(mini-batch)

4、batch size增大,梯度已经非常准确,再增大也没用。

随机梯度下降、batch梯度下降会使得梯度的准确度处于两个极端,而mini-batch处于两个极端之间。

batch size也是一个超参数,需要根据成本变化来调整。一般来说batch size选择为2的n次方,2、4、8…1024…,这样会使得计算机运算的快些。常见的batch size有:64、512。

mini-batch的不足:

batch梯度下降因为梯度准确,则成本变化较准确,成本下降曲线平滑。而mini-batch的梯度下降,会不断趋于准确,但整个过程中,会因为批次的变化(更换了样本),有抬升的地方,即成本曲线震荡下行。而优化的方式,则是动量梯度下降、RMSprop、Adam优化算法。

三、指数加权平均

又名指数加权移动平均,是一种常用的序列数据处理方式,本质是通过计算局部的平均值,来描述数值的变化趋势。可以用来绘制趋势曲线。

核心公式:Vt = k* V[t-1] + (1-k) * Wt,k是一个超参数,决定了v值应该受前面多少个(1 / (1-k) )数据的影响。k越大,则说明受影响前面数据的个数越多。 而计算结果vt则可以理解为前多少个的近似平均值(非真实平均值)

示例1:

当天人民币汇率趋势 = 0.9 * 前一天人民币汇率 + 0.1 * 当天人民币汇率。此时k = 0.9,表示受前面10天的影响。

示例2:当k=0.9时,求的结果为前100天的温度趋势:

v100 = 0.9*v99 + 0.1*w100v99 = 0.9*v98 + 0.1*w99v98 = 0.9*v97 + 0.1*w98...v1 = 0.9*v0 + 0.1*w1

把v99代入v100,则:

v100 = 0.9*(0.9*v98 + 0.1*w99) + 0.1*w100
=0.1*w100 + 0.1*0.9*w99 + 0.9*0.9*v98 
=0.1*w100 + 0.1*0.9*w99 + 0.9*0.9*(0.9*v97 + 0.1*w98) 
=0.1*w100 + 0.1*0.9*w99 + (0.9^3)*v97 + 0.1* (0.9^2)*w98 

v100 = 0.1w100 + 0.1*0.9*w99 + 0.1*(0.9)^2*w98 + 0.1*(0.9)^3*w97 +...
  • 可以看出,前100天温度由一小部分拼凑而成,越往前权重越小,也就是说越来越不受前面数据的影响。
  • 0.1 约等于 0.1乘0.9 约等于 0.1乘0.9平方…,而10个约等于加起来=1。所以v值相当于前10天的平均值。
  • 如果k = 0.98,那么要50个0.02才等于1,也就是说vt相当于前面50天平均值。

计算指数加权平均:

for i in range(t)v0 = 0v1 = 0.98v0 + 0.02w1v2 = 0.98v1 + 0.02w2...
  • 修正算法:

在计算指数加权平均时,假设w1为40度,w2为40度,那么 v1 = 0.8,v2 = 0.98*0.8 + 0.8 = 1.584,说明前面的数值与实际值会相差很远。此时就需要修正。用公式 vt = vt / (1-k^t),此时v1 = 0.8 / (1-0.98^1) = 40。 后面随着t越来越大,分母越来越接近1,故vt就不需要修正了。 另外因为只是前面的会偏离一部分,故一般情况下也不会去修正。

四、动量梯度下降

标准梯度下降:

w = w - r*dw
b = b - r*db

因为加权指数移动平均,可以反应趋势,平均项越多,绘制的指数加权平均曲线变化更为缓慢。故我们可以用它来做梯度下降,从而减轻标准梯度下降跳来跳去找,找到损失最低点的性能浪费。(比如跳动过大,错过最低点。 以及在最小值左右来回跳动)故动量梯度下降时优于标准梯度下降的算法。

vdw = k *vdw + (1-k)*dw

vdb = k *vdb + (1-k)*db

等式左边的vdw、vdb为当前值,等式右边的为前一个值。故引出动量梯度下降:

w = w - r*vdw
b = b - r*vdb

其优点:

1、动量移动的更快

2、动量有机会逃脱局部极小值和高原区。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 微软蓝屏事件揭示的网络安全深层问题与未来应对策略
  • 【Unity】web gl inputFied 中文输入,同时支持TextMeshInputFied,支持全屏
  • Redis过期键的删除策略
  • 【数据结构】栈和队列(c语言实现)(附源码)
  • 学python的第一天:PyCharm创建项目
  • kickstart自动安装脚本
  • 通信原理实验——PCM编译码
  • 什么是V2X?
  • Vue+live2d实现虚拟人物互动(一次体验叙述)
  • RocketMQ 的消息存储机制
  • 3.4数组和特殊矩阵
  • Java开发:文件上传和下载
  • 按摩虎口穴位的作用
  • Laravel php框架与Yii php 框架的优缺点
  • 上线前端系统
  • [deviceone开发]-do_Webview的基本示例
  • “Material Design”设计规范在 ComponentOne For WinForm 的全新尝试!
  • 【159天】尚学堂高琪Java300集视频精华笔记(128)
  • 【跃迁之路】【641天】程序员高效学习方法论探索系列(实验阶段398-2018.11.14)...
  • 2018以太坊智能合约编程语言solidity的最佳IDEs
  • Brief introduction of how to 'Call, Apply and Bind'
  • docker python 配置
  • git 常用命令
  • Hibernate最全面试题
  • java8 Stream Pipelines 浅析
  • Java反射-动态类加载和重新加载
  • java小心机(3)| 浅析finalize()
  • Java新版本的开发已正式进入轨道,版本号18.3
  • MySQL用户中的%到底包不包括localhost?
  • orm2 中文文档 3.1 模型属性
  • Terraform入门 - 3. 变更基础设施
  • 翻译:Hystrix - How To Use
  • 设计模式走一遍---观察者模式
  • 适配iPhoneX、iPhoneXs、iPhoneXs Max、iPhoneXr 屏幕尺寸及安全区域
  • 数据科学 第 3 章 11 字符串处理
  • 字符串匹配基础上
  • 3月27日云栖精选夜读 | 从 “城市大脑”实践,瞭望未来城市源起 ...
  • 格斗健身潮牌24KiCK获近千万Pre-A轮融资,用户留存高达9个月 ...
  • 你学不懂C语言,是因为不懂编写C程序的7个步骤 ...
  • ​探讨元宇宙和VR虚拟现实之间的区别​
  • ​直流电和交流电有什么区别为什么这个时候又要变成直流电呢?交流转换到直流(整流器)直流变交流(逆变器)​
  • # Apache SeaTunnel 究竟是什么?
  • # Spring Cloud Alibaba Nacos_配置中心与服务发现(四)
  • # 深度解析 Socket 与 WebSocket:原理、区别与应用
  • #HarmonyOS:Web组件的使用
  • #NOIP 2014#Day.2 T3 解方程
  • #stm32驱动外设模块总结w5500模块
  • $nextTick的使用场景介绍
  • $var=htmlencode(“‘);alert(‘2“); 的个人理解
  • ( 用例图)定义了系统的功能需求,它是从系统的外部看系统功能,并不描述系统内部对功能的具体实现
  • (3)(3.5) 遥测无线电区域条例
  • (done) 两个矩阵 “相似” 是什么意思?
  • (k8s)Kubernetes本地存储接入
  • (附源码)计算机毕业设计ssm本地美食推荐平台
  • (接口封装)