当前位置: 首页 > news >正文

图解梯度下降背后的数学原理

2019独角兽企业重金招聘Python工程师标准>>> hot3.png

敏捷在软件开发过程中是一个非常著名的术语,它背后的基本思想很简单:快速构建一些东西,然后得到一些反馈,根据反馈做出改变,重复此过程。目标是让产品更贴合用,让用户做出反馈,以获得设计开发出的产品与优秀的产品二者之间误差最小,梯度下降算法背后的原理和这基本一样。

目的

梯度下降算法是一个迭代过程,它将获得函数的最小值。下面的公式将整个梯度下降算法汇总在一行中。

b566c1de14c7b3b352a965804b82aa54527.jpg
但是这个公式是如何得出的呢?实际上很简单,只需要具备一些高中的数学知识即可理解。本文将尝试讲解这个公式,并以线性回归模型为例,构建此类公式。

机器学习模型

  • 考虑二维空间中的一堆数据点。假设数据与一组学生的身高和体重有关。试图预测这些数量之间的某种关系,以便我们可以预测一些新生的体重。这本质上是一种有监督学习的简单例子。
  • 现在在空间中绘制一条穿过其中一些数据点的任意直线,该直线方程的形如Y=mX+b,其中m是斜率,b是其在Y轴的截距。

ff103dafdacfe2a60d5ace79e2569ebb4f9.jpg

预测

给定一组已知的输入及其相应的输出,机器学习模型试图对一组新的输入做出一些预测。

e6360c9010659fe5fe75c91cddbdac2ca7c.jpg
两个预测之间的差异即为错误。

4777ce0938738d8249c4646b776ac1cf7ba.jpg
这涉及成本函数或损失函数的概念(cost function or loss function)。

成本函数

成本函数/损失函数用来评估机器学习算法的性能。二者的区别在于,损失函数计算单个训练示例的错误,而成本函数是整个训练集上错误的平均值。

成本函数基本上能告诉我们模型在给定m和b的值时,其预测能“有多好”。

比方说,数据集中总共有N个点,我们想要最小化所有N个数据点的误差。因此,成本函数将是总平方误差,即

afc1ca099d01ca3f510004300247b281bb2.jpg

为什么采取平方差而不是绝对差?因为平方差使得导出回归线更容易。实际上,为了找到这条直线,我们需要计算成本函数的一阶导数,而计算绝对值的导数比平方值更难。

最小化成本函数

任何机器学习算法的目标都是最小化成本函数。

这是因为实际值和预测值之间的误差对应着表示算法在学习方面的性能。由于希望误差值最小,因此尽量使得那些mb值能够产生尽可能小的误差。

如何最小化一个任意函数?

仔细观察上述的成本函数,其形式为Y=X²。在笛卡尔坐标系中,这是一个抛物线方程,用图形表示如下:

6a404857663d7db5ef41a823b8a746f7854.jpg
为了最小化上面的函数,需要找到一个x,函数在该点能产生小值Y,即图中的红点。由于这是一个二维图像,因此很容易找到其最小值,但是在维度比较大的情况下,情况会更加复杂。对于种情况,需要设计一种算法来定位最小值,该算法称为梯度下降算法(Gradient Descent)。
 

梯度下降

梯度下降是优化模型的方法中最流行的算法之一,也是迄今为止优化神经网络的最常用方法。它本质上是一种迭代优化算法,用于查找函数的最小值。

表示

假设你是沿着下面的图表走,目前位于曲线'绿'点处,而目标是到达最小值,即点位置,但你是无法看到该最低点。

da82a932941746eeb2ed43a9a5a644dd8ef.jpg
可能采取的行动:

  • 可能向上或向下;
  • 如果决定走哪条路,可能会采取更大的步伐或小的步伐来到达目的地;

从本质上讲,你应该知道两件事来达到最小值,即走哪条和走多远。

梯度下降算法通过使用导数帮助我们有效地做出这些决策。导数是来源于积分,用于计算曲线特定点处的斜率。通过在该点处绘制图形的切线来描述斜率。因此,如果能够计算出这条切线,可能就能够计算达到最小值的所需方向。

最小值

在下图中,在绿点处绘制切线,如果向上移动,就将远离最小值,反之亦然。此外,切线也能让我们感觉到斜坡的陡峭程度。

9dd4cfc01fc99d0b48b34777dc05871a39d.jpg
蓝点处的斜率比绿点处的斜率低,这意味着从蓝点到绿点所需的步长要小得多。

成本函数的数学解释

现在将上述内容纳入数学公式中。在等式y=mX+b中,mb是其参数。在训练过程中,其值也会发生微小变化,用δ表示这个小的变化。参数值将分别更新为m = m-δm 和b = b-δb。最终目标是找到mb的值,以使得y=mx+b 的误差最小,即最小化成本函数。
重写成本函数:

59ed19ee6f5a810967374310bedd0edf7b8.jpg

想法是,通过计算函数的导数/斜率,就可以找到函数的最小值。

学习率

达到最小值或最低值所采取的步长大小称为学习率。学习率可以设置的比较大,但有可能会错过最小值。而另一方面,小的学习率将花费大量时间训练以达到最低点。
下面的可视化给出了学习率的基本概念。在第三个图中,以最小步数达到最小点,这表明该学习率是此问题的最佳学习率。

aa8c724ab20e7121ee3435cede24aa59944.jpg

从上图可以看到,当学习率太低时,需要花费很长训练时间才能收敛。而另一方面,当学习率太高时,梯度下降未达到最小值,如下面所示:

11

导数

机器学习在优化问题中使用导数。梯度下降等优化算法使用导数来决定是增加还是减少权重,进而增加或减少目标函数。
如果能够计算出函数的导数,就可以知道在哪个方向上能到达最小化。
主要处理方法源自于微积分中的两个基本概念:

  • 指数法则
    指数法则求导公式:

2f553e08fc5204b55483682cf9f949dd7ba.jpg

  • 链式法则
    链式法则用于计算复合函数的导数,如果变量z取决于变量y,且它本身也依赖于变量x,因此y和z是因变量,那么z对x的导数也与y有,这称为链式法则,在数学上写为:

7d8b07e7f08fb5d8e262f9d9c348877d00c.jpg

举个例子加强理解:

5f6f309e2450efc6692c46f6168f44e6403.jpg
使用指数法则和链式发规,计算成本函数相对于m和c的变化方式。这涉及偏导数的概念,即如果存在两个变量的函数,那么为了找到该函数对其中一个变量的偏导数,需将另一个变量视为常数。举个例子加强理解:

7bcf8e47e6a5681b7af63f2c8c0e6203f16.jpg

计算梯度下降

现在将这些微积分法则的知识应用到原始方程中,并找到成本函数的导数,即mb。修改成本函数方程:

293a47d5153390c8acc0c3c13b4d755fef1.jpg
为简单起见,忽略求和符号。求和部分其实很重要,尤其是随机梯度下降(SGD)与批量梯度下降的概念。在批量梯度下降期间,我们一次查看所有训练样例的错误,而在SGD中一次只查看其中的一个错误。这里为了简单起见,假设一次只查看其中的一个错误:

bf753ec11098c6d96beac74411174510251.jpg
现在计算误差对m和b的梯度:

bea349c8199862a470ffbd372750382dae5.jpg
将值对等到成本函数中并将其乘以学习率:

e4ac8f04fc028b29e457af0c68c3a27e6f6.jpg
其中这个等式中的系数项2是一个常数,求导时并不重要,这里将其忽略。因此,最终,整篇文章归结为两个简单的方程式,它们代表了梯度下降的方程。

483c8164a95f93b8cc01b5995136b49e01f.jpg
其中是下一个位置的参数;m⁰b⁰是当前位置的参数。

因此,为了求解梯度,使用新的mb值迭代数据点并计算偏导数。这个新的梯度会告诉我们当前位置的成本函数的斜率以及我们应该更新参数的方向。另外更新参数的步长由学习率控制。

结论

本文的重点是展示梯度下降的基本概念,并以线性回归为例讲解梯度下降算法。通过绘制最佳拟合线来衡量学生身高和体重之间的关系。但是,这里为了简单起见,举的例子是机器学习算法中较简单的线性回归模型,读者也可以将其应用到其它机器学习方法中。


原文链接
本文为云栖社区原创内容,未经允许不得转载。

转载于:https://my.oschina.net/u/1464083/blog/3028452

相关文章:

  • CentOS 6.9下PXE+Kickstart无人值守安装操作系统附常见问题
  • 3月27日云栖精选夜读 | 从 “城市大脑”实践,瞭望未来城市源起 ...
  • ES6 proxy
  • GocatorSDK学习笔记
  • 好程序员web前端分享JavaScript学习指南
  • MYSQL一个优化的过程
  • 小白应该如何快速入门阿里云服务器,新手使用ECS的方法 ...
  • React:输入框新增/取消一行如何处理(X-mind图)
  • K8S 生态周报| 2019.03.25~2019.03.31
  • 无人车制胜关键:Apollo决策系统全面剖析
  • 机器人开始自主学习,是人类福祉,还是定时炸弹? ...
  • pika开源:替代WebPack的全新JS构建工具
  • MySql批量插入与唯一索引问题
  • CF451E Devu and Flowers
  • Android之RxJava详解
  • 2018天猫双11|这就是阿里云!不止有新技术,更有温暖的社会力量
  • create-react-app做的留言板
  • iOS帅气加载动画、通知视图、红包助手、引导页、导航栏、朋友圈、小游戏等效果源码...
  • Java Agent 学习笔记
  • Java反射-动态类加载和重新加载
  • Linux Process Manage
  • Linux后台研发超实用命令总结
  • Promise面试题2实现异步串行执行
  • Spring-boot 启动时碰到的错误
  • Vue 动态创建 component
  • vue-router 实现分析
  • Zsh 开发指南(第十四篇 文件读写)
  • 区块链分支循环
  • 使用前端开发工具包WijmoJS - 创建自定义DropDownTree控件(包含源代码)
  • 数据仓库的几种建模方法
  • 我感觉这是史上最牛的防sql注入方法类
  • 策略 : 一文教你成为人工智能(AI)领域专家
  • ​ 全球云科技基础设施:亚马逊云科技的海外服务器网络如何演进
  • ​secrets --- 生成管理密码的安全随机数​
  • #100天计划# 2013年9月29日
  • #使用清华镜像源 安装/更新 指定版本tensorflow
  • (C++)栈的链式存储结构(出栈、入栈、判空、遍历、销毁)(数据结构与算法)
  • (day6) 319. 灯泡开关
  • (done) NLP “bag-of-words“ 方法 (带有二元分类和多元分类两个例子)词袋模型、BoW
  • (阿里巴巴 dubbo,有数据库,可执行 )dubbo zookeeper spring demo
  • (初研) Sentence-embedding fine-tune notebook
  • (二) Windows 下 Sublime Text 3 安装离线插件 Anaconda
  • (附源码)springboot“微印象”在线打印预约系统 毕业设计 061642
  • (区间dp) (经典例题) 石子合并
  • (译)2019年前端性能优化清单 — 下篇
  • (转)Linux整合apache和tomcat构建Web服务器
  • (转)Sql Server 保留几位小数的两种做法
  • .360、.halo勒索病毒的最新威胁:如何恢复您的数据?
  • .NET 使用 ILMerge 合并多个程序集,避免引入额外的依赖
  • .考试倒计时43天!来提分啦!
  • /etc/shadow字段详解
  • @CacheInvalidate(name = “xxx“, key = “#results.![a+b]“,multi = true)是什么意思
  • [145] 二叉树的后序遍历 js
  • [④ADRV902x]: Digital Filter Configuration(发射端)
  • [AIGC] Spring Interceptor 拦截器详解