当前位置: 首页 > news >正文

【pytorch01】简单回归问题

1.梯度下降(Gradient Descent)

梯度下降
y = x 2 ∗ s i n ( x ) y=x^{2}*sin(x) y=x2sin(x)
y ′ = 2 ∗ x ∗ s i n ( x ) + x 2 ∗ c o s ( x ) y'=2*x*sin(x) + x^{2}*cos(x) y=2xsin(x)+x2cos(x)
求最小值要求导

梯度下降定义:梯度下降要迭代计算,每一次得到一个导数以后,用原来的x减去该x处导数的值,得到一个新的x的值就是这样一个迭代的过程

x t = x t − 1 − η ∂ y ∂ x t − 1 x_{t}=x_{t-1}-η\frac{\partial{y}}{\partial x_{t-1}} xt=xt1ηxt1y

η就是learning rate(学习率),可以通过调整学习率够使目标函数在合适的时间内收敛到局部最小值。

  • y = w ∗ x + b y=w*x+b y=wx+b
    • 1.567 = w ∗ 1 + b 1.567 = w * 1 + b 1.567=w1+b
    • 3.043 = w ∗ 2 + b 3.043 = w * 2 + b 3.043=w2+b

w = 1.477
b = 0.089
通过消元法,此时w和b是一个准确解,被称之为Closed Form Solution

其实现实生活中可以精确求解的东西不多,我们现实生活中拿到的数据都是有一定偏差的,因此对于实际的问题,与其说求一个Closed Form Solution(封闭解),不如求得一个近似解,这个近似解在经验上可行,这样就可以达到我们的目的

用高斯噪声(均值为0.01,方差为1)模仿偏差(现实生活中拿到的数据都是带有一定噪声的)
y = w ∗ x + b + ϵ y=w *x+b + \epsilon y=wx+b+ϵ
ϵ ∼ N ( 0.01 , 1 ) \epsilon\sim N(0.01,1) ϵN(0.01,1)
1.567 = w ⋆ 1 + b + e p s 3.043 = w ⋆ 2 + b + e p s 4.519 = w ⋆ 3 + b + e p s . . . 1.567=w^{\star}1+b+eps\\3.043=w^{\star}2+b+eps\\4.519=w^{\star}3+b+eps\\... 1.567=w1+b+eps3.043=w2+b+eps4.519=w3+b+eps...
观测一组数据,通过观测这一组数据来求解,这一组数据中整体表现比较好的解,虽然不是Closed Form Solution,但是证明了有良好的表现,可以达到需求。

y = x 2 ∗ s i n ( x ) y=x^{2}*sin(x) y=x2sin(x)使用梯度下降算法是求这个函数的最小值

但是对于 y = w ∗ x + b y=w*x+b y=wx+b这个方程来说并不是要求y的最小值,而是要求真实的y和 w ∗ x + b w*x+b wx+b的差最小,因为希望 w ∗ x + b w*x+b wx+b更加接近真实的y的值

可以通过求 l o s s = ( w ∗ x + b − y ) 2 loss=(w*x+b -y)^2 loss=(wx+by)2的极小值,可以达到接近的目的,获取此时的w和b的值

图片

2.实战

l o s s = ( W X + b − y ) 2 loss=(WX+b-y)^2 loss=(WX+by)2

# 返回average loss
def compute_error_for_line_given_points(w,b,points):lossTotal = 0for i in range(len(points)):x = points[i,0]y = points[i,1]lossTotal += (y - (w * x + b))** 2return lossTotal / float(len(points))

w ′ = w − l r ∗ ∇ l o s s ∇ w w'=w-lr*\frac{\nabla loss}{\nabla w} w=wlrwloss

# 要求loss的极小值,对w和b分别梯度下降
def step_gradient(b_current,w_current,points,learningRate):b_gradient = 0w_gradient = 0N = float(len(points))for i in range(len(points)):x = points[i, 0]y = points[i, 1]# loss函数分别对w和b求导# 多了N的原因是因为对所有点的导数累加起来,这样就不用做average了# 此时获得的w和b是所有点average之后的梯度w_gradient += -(2/N) * x * (y - (w_current * x + b_current))b_gradient += -(2/N) * (y - (w_current * x + b_current))new_b = b_current - (learningRate * b_gradient)new_w = w_current - (learningRate * w_gradient)return [new_w,new_b]

经过多次梯度下降得到最优解

def gradient_descent_runner(points,starting_w,starting_b,learning_rate,num_iterations):w = starting_wb = starting_bfor i in range(num_iterations):w,b = step_gradient(w,b,np.array(points),learning_rate)return [w,b]
def run():points = np.genfromtxt("data.csv",delimiter=",")print(points[:10])learning_rate = 0.0001initial_w = 0initial_b = 0num_iterations = 1000print("Starting gradient descent at w = {0},b = {1},error = {2}".format(initial_w,initial_b,compute_error_for_line_given_points(initial_w,initial_b,points)))print("Running...")[w,b] = gradient_descent_runner(points,initial_w,initial_b,learning_rate,num_iterations)print("After {0} iterations w = {1},b = {2},error = {3}".format(num_iterations,w, b,compute_error_for_line_given_points(w, b, points)))if __name__ == '__main__':run()

结果
最终的数据与Closed Form Solution非常接近

相关文章:

  • 空间复杂度 线性表,顺序表尾插。
  • 离线linux通过USB连接并使用手机网络
  • 初学者应该掌握的MySQL数据库的基本组成部分及概念
  • 【Docker】——安装镜像和创建容器,详解镜像和Dockerfile
  • 【Qt】QList<QVariantMap>中数据修改
  • ic基础|功耗篇03:ic设计人员如何在代码中降低功耗?一文带你了解行为级以及RTL级低功耗技术
  • 指纹浏览器与虚拟机的区别及在跨境电商中的应用
  • LeetCode 每日一题 2024/6/17-2024/6/23
  • ChatGPT 简介
  • 日语 13 14
  • ping命令返回结果实例分析
  • K8S - 理解ClusterIP - 集群内部service之间的反向代理和loadbalancer
  • 深入解析Linux Bridge:原理、架构、操作与持久化配置
  • PAL: Program-aided Language Models
  • Python爬虫实战案例之——MySql数据入库
  • canvas 绘制双线技巧
  • E-HPC支持多队列管理和自动伸缩
  • ES6系统学习----从Apollo Client看解构赋值
  • hadoop入门学习教程--DKHadoop完整安装步骤
  • JavaScript 奇技淫巧
  • JavaScript 是如何工作的:WebRTC 和对等网络的机制!
  • Javascript 原型链
  • leetcode378. Kth Smallest Element in a Sorted Matrix
  • Mocha测试初探
  • PhantomJS 安装
  • Redis学习笔记 - pipline(流水线、管道)
  • 阿里中间件开源组件:Sentinel 0.2.0正式发布
  • 关于使用markdown的方法(引自CSDN教程)
  • 前端技术周刊 2019-02-11 Serverless
  • 浅谈JavaScript的面向对象和它的封装、继承、多态
  • 如何打造100亿SDK累计覆盖量的大数据系统
  • 通过来模仿稀土掘金个人页面的布局来学习使用CoordinatorLayout
  • 用Canvas画一棵二叉树
  • 用jquery写贪吃蛇
  • 自制字幕遮挡器
  • python最赚钱的4个方向,你最心动的是哪个?
  • Spark2.4.0源码分析之WorldCount 默认shuffling并行度为200(九) ...
  • 扩展资源服务器解决oauth2 性能瓶颈
  • ​3ds Max插件CG MAGIC图形板块为您提升线条效率!
  • #QT(TCP网络编程-服务端)
  • (1)STL算法之遍历容器
  • (delphi11最新学习资料) Object Pascal 学习笔记---第13章第6节 (嵌套的Finally代码块)
  • (done) 两个矩阵 “相似” 是什么意思?
  • (ibm)Java 语言的 XPath API
  • (转)ORM
  • **PHP二维数组遍历时同时赋值
  • .NET Core WebAPI中使用Log4net 日志级别分类并记录到数据库
  • .net 生成二级域名
  • .NET 自定义中间件 判断是否存在 AllowAnonymousAttribute 特性 来判断是否需要身份验证
  • .netcore 获取appsettings
  • .NetCore项目nginx发布
  • .NET国产化改造探索(三)、银河麒麟安装.NET 8环境
  • .pub是什么文件_Rust 模块和文件 - 「译」
  • .secret勒索病毒数据恢复|金蝶、用友、管家婆、OA、速达、ERP等软件数据库恢复
  • .set 数据导入matlab,设置变量导入选项 - MATLAB setvaropts - MathWorks 中国