当前位置: 首页 > news >正文

优化算法 - Adam算法

文章目录

  • Adam算法
    • 1 - 算法
    • 2 - 实现
    • 3 - Yogi
    • 4 - 小结

Adam算法

在本章中,我们已经学习了许多有效优化的技术。在本节讨论之前,我们先详细回顾以下这些技术:

  • 随机梯度下降:在解决优化问题时比梯度下降更有效
  • 小批量随机梯度下降:在一个小批量中使用更大的观测值集,可以通过向量化提供额外效率。这是高效的多机、多GPU和整体并行处理的关键
  • 动量法:添加了一种机制,用于汇总过去梯度的历史以加速收敛
  • AdaGrad算法:对每个坐标缩放来实现高效计算的预处理器
  • RMSProp算法:通过学习率的调整来分离每个坐标的缩放

Adam算法将所有这些技术汇总到一个高效的学习算法中。不出预料,作为深度学习中使用的更强大和有效的优化算法之一,它非常受欢迎。但是它并非没有问题,尤其是 [Reddi et al., 2019]表明,有时Adam算法可能由于⽅差控制不良⽽发散。在完善⼯作中,[Zaheer et al., 2018]给Adam算法提供了⼀个称为Yogi的热补丁来解决这些问题。下⾯我们了解⼀下Adam算法

1 - 算法

2 - 实现

从头开始实现Adam算法并不难,为了方便起见,我们将时间步t存储在hyperparams字典中。除此之外,一切都很简单

%matplotlib inline
import torch
from d2l import torch as d2l

def init_adam_states(feature_dim):
    v_w,v_b = torch.zeros((feature_dim,1)),torch.zeros(1)
    s_w,s_b = torch.zeros((feature_dim,1)),torch.zeros(1)
    return ((v_w,s_w),(v_b,s_b))

def adam(params,states,hyperparams):
    beta1,beta2,eps = 0.9,0.999,1e-6
    for p,(v,s) in zip(params,states):
        with torch.no_grad():
            v[:] = beta1 * v + (1 - beta1) * p.grad
            s[:] = beta2 * s + (1 - beta2) * torch.square(p.grad)
            v_bias_corr = v / (1 - beta1 ** hyperparams['t'])
            s_bias_corr = s / (1 - beta2 ** hyperparams['t'])
            p[:] -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr + eps))
        p.grad.data.zero_()
    hyperparams['t'] += 1

现在,我们用以上Adam算法来训练模型,这里我们使用η=0.01的学习率

data_iter,feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(adam,init_adam_states(feature_dim),{'lr':0.01,'t':1},data_iter,feature_dim);
loss: 0.246, 0.014 sec/epoch

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NF6jVSDf-1663328054526)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209161925007.svg)]

此外,我们可以用深度学习框架自带算法应用Adam算法,这里我们只需要传递配置参数

trainer = torch.optim.Adam
d2l.train_concise_ch11(trainer,{'lr':0.01},data_iter)
loss: 0.247, 0.015 sec/epoch

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-N0TofKyJ-1663328054527)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209161925008.svg)]

3 - Yogi



def yogi(params,states,hyperparams):
    beta1,beta2,eps = 0.9,0.999,1e-3
    for p,(v,s) in zip(params,states):
        with torch.no_grad():
            v[:] = beta1 * v + (1 - beta1) * p.grad
            s[:] = s + (1 - beta2) * torch.sign(torch.square(p.grad) -s ) * torch.square(p.grad)
            v_bias_corr = v / (1 - beta1 ** hyperparams['t'])
            s_bias_corr = s / (1 - beta2 ** hyperparams['t'])
            p[:] -= hyperparams['lr'] * v_bias_corr / (torch.sqrt(s_bias_corr) + eps)
        p.grad.data.zero_()
    hyperparams['t'] += 1
    
data_iter,feature_dim = d2l.get_data_ch11(batch_size=10)
d2l.train_ch11(yogi,init_adam_states(feature_dim),{'lr':0.01,'t':1},data_iter,feature_dim)
loss: 0.244, 0.007 sec/epoch





([0.006999015808105469,
  0.01399993896484375,
  0.02099919319152832,
  0.026999235153198242,
  0.03399944305419922,
  0.0410001277923584,
  0.04800128936767578,
  0.05700254440307617,
  0.06400370597839355,
  0.07200503349304199,
  0.07900643348693848,
  0.08518671989440918,
  0.09218716621398926,
  0.10018706321716309,
  0.10867691040039062],
 [0.3831201309363047,
  0.30505007115999855,
  0.27388086752096813,
  0.25824862279494604,
  0.248792000691096,
  0.24663881778717042,
  0.24533938866853713,
  0.24811744292577106,
  0.2440877826611201,
  0.24333851114908855,
  0.24304762629667917,
  0.24334035567442577,
  0.24402384889125825,
  0.24259794521331787,
  0.2435852948029836])


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cptXQbPI-1663328054528)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209161925009.svg)]

4 - 小结

  • Adam算法将许多优化算法的功能结合到了相当强大的更新规则中
  • Adam算法在RMSProp算法基础上创建的,还在小批量的随机梯度上使用EWMA
  • 在估计动量和二次矩时,Adam算法使用偏差校正来调整缓慢的启动速度
  • 对于具有显著差异的梯度,我们可能会遇到收敛性问题,我们可以通过使用更大的小批量或者切换到改进的估计值 s t s_t st来修正它们。Yogi提供了这样的替代方案

相关文章:

  • Open3D (C++) 点云变换
  • 黑白照片修复彩色软件免费有哪些?分享这三个实用的软件给你
  • CSS基础入门手册
  • python-- for循环的基础语法
  • _Linux进程控制
  • vue3.0--1.vue3.0环境集成、setup、ref函数、reactive函数、计算属性(computed)
  • 基于Opencv5.x(C++)流媒体视频流实现网页浏览器人脸检测
  • 网络安全——XSS跨站脚本攻击
  • AT24C02存储与读取数据
  • Linux高级编程--gdb调试
  • 家校协同小程序实战教程
  • 沉睡者C - 想要通过网上来赚钱,悟性很重要
  • Java集合面试小结(2)
  • 【uiautomation】微信群发消息,可发送文本 文件
  • 【network】windows 获取Adapter 名称
  • IE9 : DOM Exception: INVALID_CHARACTER_ERR (5)
  • 《用数据讲故事》作者Cole N. Knaflic:消除一切无效的图表
  • 【mysql】环境安装、服务启动、密码设置
  • CNN 在图像分割中的简史:从 R-CNN 到 Mask R-CNN
  • el-input获取焦点 input输入框为空时高亮 el-input值非法时
  • iOS动画编程-View动画[ 1 ] 基础View动画
  • JAVA之继承和多态
  • Js基础知识(四) - js运行原理与机制
  • JS数组方法汇总
  • linux学习笔记
  • October CMS - 快速入门 9 Images And Galleries
  • Webpack 4 学习01(基础配置)
  • windows下mongoDB的环境配置
  • 力扣(LeetCode)22
  • 买一台 iPhone X,还是创建一家未来的独角兽?
  • 漫谈开发设计中的一些“原则”及“设计哲学”
  • 想使用 MongoDB ,你应该了解这8个方面!
  • 用jquery写贪吃蛇
  • 1.Ext JS 建立web开发工程
  • 阿里云ACE认证学习知识点梳理
  • (9)YOLO-Pose:使用对象关键点相似性损失增强多人姿态估计的增强版YOLO
  • (Matalb时序预测)WOA-BP鲸鱼算法优化BP神经网络的多维时序回归预测
  • (ros//EnvironmentVariables)ros环境变量
  • (二)c52学习之旅-简单了解单片机
  • (原創) 如何優化ThinkPad X61開機速度? (NB) (ThinkPad) (X61) (OS) (Windows)
  • (转)详解PHP处理密码的几种方式
  • (转载)(官方)UE4--图像编程----着色器开发
  • ***检测工具之RKHunter AIDE
  • .NET CORE Aws S3 使用
  • .net core 微服务_.NET Core 3.0中用 Code-First 方式创建 gRPC 服务与客户端
  • .Net Web窗口页属性
  • .net6 webapi log4net完整配置使用流程
  • .net遍历html中全部的中文,ASP.NET中遍历页面的所有button控件
  • .net和jar包windows服务部署
  • .net生成的类,跨工程调用显示注释
  • .NET项目中存在多个web.config文件时的加载顺序
  • [.net]官方水晶报表的使用以演示下载
  • [@Controller]4 详解@ModelAttribute
  • [2016.7.Test1] T1 三进制异或
  • [2019.3.20]BZOJ4573 [Zjoi2016]大森林