当前位置: 首页 > news >正文

pytorch(1)梯度下降

一、系列文章目录

(1)梯度下降
(2)手写数字识别引入&Pytorch 数据类型
(3)创建Tensor
(4)Broadcasting
(5)Tensor
(6)Tensor统计
(7)Where和Gather
(8)函数的梯度
(9)loss函数;自动求导


文章目录

  • 一、系列文章目录
  • 二、梯度下降
    • 1.什么是梯度下降
    • 2.梯度下降算法核心公式:
    • 3.求解方法
    • 4.实例
    • 5.实战


二、梯度下降

1.什么是梯度下降

梯度gradient,是深度学习的核心精髓,许多专家称深度学习为gradient programming,deep learning其实就是求解一个巨复杂的函数,求解工具就是梯度下降算法。

2.梯度下降算法核心公式:

xn+1= xn - f`(x)* (learning rate)

其中learning rate通常取0.001,手写数据一般0.01。它的取值会影响下降的速度(步长),取值过大会出现大的抖动。
x不断更新,当x在最优解(极值点)附近时梯度会接近于0,x取值在一个小范围内抖动。
可以加不同的限制额外的优化,例如考虑上一次的前进方向与这一次的一致,得到不同的求解器。以提高精度和速度。

  • SGD: 传统随机梯度下降
  • Adam:一阶优化算法https://www.cnblogs.com/yifdu25/p/8183587.html

3.求解方法

传统方法通常求解精确解,例如,y= w·x+b ,x分别取值1和2,做差精确求解出w和b
但是真实问题中数据往往有噪声,因此求取近似解即可
y= w·x+b + e
e ~ N(0.01, 1)高斯噪声,均值0.01,方差为1
目标函数: loss=(w·x+b-y)2 求解w和b使loss取极小值

4.实例

y=1.477x+0.089+e
先根据这个模型生成100个数据,然后假装不知道模型的参数,根据这些数据求解出模型。

根据数据散点图分布的特点,假设一个线性模型 y= w·x+b + e
目标:Minimize loss=Σi(w·xi + b - yi)2 求解 w b

这个模型为数低,有全局极小值,是凸函数,相关问题称为凸优化。对于一个凹函数,则求解局部极小值。

linear regression:y值是连续的问题,股票指数等等。
logistic regression:在linear regression 上加一个激活函数,压缩到(0,1)上。手写数字、硬币二分类等等涉及到概率的问题。

5.实战

import numpy as np

def data():
    e = np.random.normal(0.01, 1, 1) #normal是正态分布 (均值,标准差,输出shape) shape默认为none,只输出一个值
    points = []
    for x in range(16):
        y = 1.477 * x + 0.089 + e
        points.append([]) #points 中添加一个空list
        points[x].append(x) #points第x位添加x
        points[x].append(y[0]) 
    return(points)

def loss(b, w, points):   # 计算loss函数
    totalError = 0
    for i in range(0,len(points)):
        x = points[i,0]     # 数组,在后面用np生成
        y = points[i,1]
        totalError += (y - (w * x + b)) ** 2  #平方和
    return totalError/float(len(points)) #取均值

def step_gradient(b_current, w_current, points, learningRate): #参数更新
    b_gradient = 0 #梯度初始值
    w_gradient = 0 #梯度初始值
    N = float(len(points)) #数据总数
    for i in range(0, len(points)):  # 计算b和w的梯度
        x = points[i, 0] #取出x
        y = points[i, 1] #取出y
        b_gradient += -(2/N) * (y - ((w_current * x) + b_current))# 均值 ;wx+b-y 变为 y-wx-b 添加负号
        w_gradient += -(2/N) * x * (y - ((w_current * x) + b_current))# w'= w-lr * (loss'w)  (loss'w)指loss对w的偏导
    new_b = b_current - (learningRate * b_gradient)   #梯度下降
    new_w = w_current - (learningRate * w_gradient)
    return [new_b,new_w] #更新参数

def gd_runner(points, initial_b, initial_w, learning_rate, num_iterations):
    b_current = initial_b #从初始值开始
    w_current = initial_w #从初始值开始
    for i in range(num_iterations): #更新num次
        b_current,w_current = step_gradient(b_current, w_current, np.array(points), learning_rate)
    return [b_current, w_current] #更新num次之后的参数

def run():
    #points = np.genfromtxt("data.csv", delimiter = ",") #Numpy.genfromtxt-读取csv文件数据  没用这句是因为我还不会生成文件哈哈哈哈哈哈
    points = np.array(data())
    learning_rate = 0.0001 #设置lr
    initial_b = 0 #初始值b
    initial_w = 0 #初始值w
    num_iterations = 100000 #学习次数
    print("Starting gradient decent at b ={0}, w = {1}, erroe = {2}".format(initial_b, initial_w, loss(initial_b,initial_w,points))) #初始误差
    print("running…………")
    [b, w] = gd_runner(points, initial_b, initial_w, learning_rate, num_iterations) #参数更新了1000次
    print("after {0} iterations b = {1}, w = {2}, error = {3}".format(num_iterations, b, w, loss(b, w, points)))

if __name__ == '__main__':
    run()

可以看到,结果不错
可以看到,拟合结果还是不错的,嘿

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 姐妹情深
  • pytorch(2)手写数字识别引入Pytorch 数据类型
  • 一辑非常好的人体外拍
  • pytorch(3)创建Tensor
  • pytorch(4)Broadcasting
  • 李亚石先生生前力作欣赏---花草
  • pytorch(5)Tensor
  • 李亚石先生生前力作欣赏---桂林山水
  • pytorch(6)Tensor统计
  • 千万不要死于无知——前言
  • pytorch(7)Where和Gather
  • 千万不要死于无知——平衡饮食
  • pytorch(8)函数的梯度
  • MOOC python数据分析(1)Numpy
  • MOOC python数据分析(2)数据存取/随机数函数
  • 《微软的软件测试之道》成书始末、出版宣告、补充致谢名单及相关信息
  • CoolViewPager:即刻刷新,自定义边缘效果颜色,双向自动循环,内置垂直切换效果,想要的都在这里...
  • Eureka 2.0 开源流产,真的对你影响很大吗?
  • JavaScript 一些 DOM 的知识点
  • Java新版本的开发已正式进入轨道,版本号18.3
  • Shell编程
  • vue:响应原理
  • web标准化(下)
  • 等保2.0 | 几维安全发布等保检测、等保加固专版 加速企业等保合规
  • 今年的LC3大会没了?
  • 使用 Xcode 的 Target 区分开发和生产环境
  • 体验javascript之美-第五课 匿名函数自执行和闭包是一回事儿吗?
  • 吐槽Javascript系列二:数组中的splice和slice方法
  • 为什么要用IPython/Jupyter?
  • 我有几个粽子,和一个故事
  • Spring第一个helloWorld
  • Unity3D - 异步加载游戏场景与异步加载游戏资源进度条 ...
  • 阿里云API、SDK和CLI应用实践方案
  • 移动端高清、多屏适配方案
  • ​Base64转换成图片,android studio build乱码,找不到okio.ByteString接腾讯人脸识别
  • (02)Hive SQL编译成MapReduce任务的过程
  • (04)odoo视图操作
  • (27)4.8 习题课
  • (aiohttp-asyncio-FFmpeg-Docker-SRS)实现异步摄像头转码服务器
  • (Forward) Music Player: From UI Proposal to Code
  • (MATLAB)第五章-矩阵运算
  • (pojstep1.3.1)1017(构造法模拟)
  • (react踩过的坑)Antd Select(设置了labelInValue)在FormItem中initialValue的问题
  • (含react-draggable库以及相关BUG如何解决)固定在左上方某盒子内(如按钮)添加可拖动功能,使用react hook语法实现
  • (每日一问)计算机网络:浏览器输入一个地址到跳出网页这个过程中发生了哪些事情?(废话少说版)
  • (十二)Flink Table API
  • (一)十分简易快速 自己训练样本 opencv级联haar分类器 车牌识别
  • (中等) HDU 4370 0 or 1,建模+Dijkstra。
  • (转)用.Net的File控件上传文件的解决方案
  • (最优化理论与方法)第二章最优化所需基础知识-第三节:重要凸集举例
  • *(长期更新)软考网络工程师学习笔记——Section 22 无线局域网
  • *上位机的定义
  • ... fatal error LINK1120:1个无法解析的外部命令 的解决办法
  • ./mysql.server: 没有那个文件或目录_Linux下安装MySQL出现“ls: /var/lib/mysql/*.pid: 没有那个文件或目录”...
  • .java 9 找不到符号_java找不到符号