当前位置: 首页 > news >正文

[PyTorch][chapter 8][李宏毅深度学习][Back propagation]

前言:

              反向传播算法(英:Backpropagation algorithm,简称:BP算法)是一种监督学习算法,常被用来训练多层感知机。 它用于计算梯度计算中,降低误差。

      

目录:

  1.     链式法则
  2.     模型简介(Model)
  3.     损失函数,梯度
  4.     手写例子
  5.     min-batch

一  链式法则

      链式法则是反向传播算法里面的核心。

     case1: y=g(x),z=h(y), x,y,z 都是scalar

                       

                     \frac{dz }{dx}=\frac{dz }{dy}\frac{dy }{dx}        

      case2:  x=g(s),y=h(s),z=k(x,y),s,x,y,z 都是scalar

                   

                       \frac{dz}{ds}=\frac{dz}{dy}\frac{dy}{ds}+\frac{dz}{dx}\frac{dx}{ds}

      case3:   x,y,z 都是向量vector

                   x\rightarrow y\rightarrow z

                    \frac{dz }{dx}=\frac{dz }{dy}\frac{dy }{dx}


二  模型(Model)

以常用的网络模型DNN 为例:

 激活函数为 \sigma

 总的层数为 L


三    损失函数,梯度

       3.1 损失函数

           J(w,b)=||a^{L}-y||_2^{2}

       3.2 梯度更新

               梯度计算分为两步:

   Forward pass, Backward pass

         a Forward pass

               假设 \delta^{l}=\frac{\partial J}{\partial z^l}:

            利用微分和迹的关系很容易得到

         

          b  Backward pass  

               假设为最后一层L

                 \delta^{L}=(\frac{\partial a^L}{\partial z^L})^T\frac{\partial J}{\partial a^L}

                       =diag(\sigma^{'}(z^{L}))(a^{L}-\hat{y})

                      =(a^{L}-\hat{y})\odot \sigma{'}(z^{L})

            我们用数学归纳法,第L层的\delta^{L}已经求出, 假设第l+1层的\delta^{l+1}已经求出来了,那么我们如何求出第l层的\delta^{l}呢?

                \delta^{l}=\frac{\partial J}{\partial z^{l}}

                    =(\frac{\partial z^{l+1}}{\partial z^{l}})^T\frac{\partial J}{\partial z^{l+1}}

                    =(\frac{\partial z^{l+1}}{\partial a^l}\frac{\partial a^{l}}{\partial z^l})^T \delta^{l+1}

                    =(diag(\sigma^{'}(z^l)(w^{l+1})^T)\delta^{l+1}

                    =(w^{l+1})^T\delta^{t+1}\odot \sigma^{'}(z^l)


四   简单DNN 网络例子

 4.1 说明:

          这里面随机生成5张图形,分别对应手写数字1,2,3,4,5。

简单的了解一下如何快速搭建一个DNN Model, 梯度如何计算,更新的.

 

# -*- coding: utf-8 -*-
"""
Created on Fri Dec 15 17:21:35 2023@author: chengxf2
"""import torch 
from torch import nn
from torch import optimclass DNN(nn.Module):'''它是一个序列容器,是nn.Module的子类。 `nn.Sequential` 中的层是有顺序的,而且严格按照其顺序执行相邻两个层连接必须保证前一个层的输出与后一个层的输入相匹配。'''def __init__(self):super(DNN, self).__init__()self.net = nn.Sequential(nn.Linear(in_features=28*28, out_features=500),nn.Sigmoid(),nn.Linear(in_features=500, out_features=10),nn.Sigmoid())def forward(self, input):output = self.net(input)return outputdef train():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = DNN()criteon = torch.nn.CrossEntropyLoss(reduction='mean')optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)batch_size= 5data = torch.rand((batch_size,28*28))epochs = 2target = torch.tensor([0,1,2,3,4])target = target.to(device)for epoch in range(epochs):yHat = model(data)loss = criteon(yHat, target)loss.backward()print("\n loss ",loss)optimizer.step()if __name__ == "__main__":train()

 


五  min-batch

  在深度学习训练中,数据集我们通常采用min-batch 方案

    我们采用随机梯度方法,是为了加快运算速度。

但是GPU 可以并行运算,所以可以采用min-batch 方法进行梯度计算。

   使用min-batch 有个限制:

    1: 硬件限制 batch 不能超过硬件大小

    2:    batch 不能太大,否则容易陷入到局部极小值点,采用小的batch 可以有一定的随机性

每次出发点都不一样,一定概率跳过局部极小值点

参考:

7: Backpropagation_哔哩哔哩_bilibili

https://www.cnblogs.com/pinard/p/6422831.html

CSDN

8-1: “Hello world” of deep learning_哔哩哔哩_bilibili

相关文章:

  • SpringSecurity入门
  • 支付宝单笔转账开发
  • 压力测试过程中内存溢出(堆溢出、栈溢出、持久代溢出)情况如何解决
  • 【LeetCode刷题笔记(12-1)】【Python】【有效的字母异位词】【排序/字符统计】【简单】
  • Tomcat 部署论坛
  • 【Go】基于GoFiber从零开始搭建一个GoWeb后台管理系统(四)用户管理、部门管理模块
  • 华为云Stack 8.X 流量模型分析(一)
  • 87 GB 模型种子,GPT-4 缩小版,超越ChatGPT3.5,多平台在线体验
  • 云原生之深入解析K8S 1.27新特性如何简化状态服务跨集群平滑迁移
  • 实验4.2 默认路由和浮动静态路由的配置
  • C#监听端口报错“以一种访问权限不允许的方式做了访问套接字的尝试”
  • 【网络安全】-Linux操作系统—CentOS安装、配置
  • Flink系列之:Table API Connectors之Debezium
  • Apache Doris 在奇富科技的统一 OLAP 场景探索实践
  • MATLAB 点云中心化 (40)
  • iOS | NSProxy
  • JS变量作用域
  • JS创建对象模式及其对象原型链探究(一):Object模式
  • JS进阶 - JS 、JS-Web-API与DOM、BOM
  • Kibana配置logstash,报表一体化
  • Linux学习笔记6-使用fdisk进行磁盘管理
  • PyCharm搭建GO开发环境(GO语言学习第1课)
  • SegmentFault 技术周刊 Vol.27 - Git 学习宝典:程序员走江湖必备
  • SpingCloudBus整合RabbitMQ
  • web标准化(下)
  • 从 Android Sample ApiDemos 中学习 android.animation API 的用法
  • 关于Flux,Vuex,Redux的思考
  • 将回调地狱按在地上摩擦的Promise
  • 前端_面试
  • 树莓派 - 使用须知
  • 推荐一个React的管理后台框架
  • 栈实现走出迷宫(C++)
  • 追踪解析 FutureTask 源码
  • ​​​​​​​​​​​​​​汽车网络信息安全分析方法论
  • ​LeetCode解法汇总518. 零钱兑换 II
  • #100天计划# 2013年9月29日
  • #每日一题合集#牛客JZ23-JZ33
  • #预处理和函数的对比以及条件编译
  • $.each()与$(selector).each()
  • (JS基础)String 类型
  • (react踩过的坑)Antd Select(设置了labelInValue)在FormItem中initialValue的问题
  • (非本人原创)史记·柴静列传(r4笔记第65天)
  • (规划)24届春招和25届暑假实习路线准备规划
  • (理论篇)httpmoudle和httphandler一览
  • (欧拉)openEuler系统添加网卡文件配置流程、(欧拉)openEuler系统手动配置ipv6地址流程、(欧拉)openEuler系统网络管理说明
  • (强烈推荐)移动端音视频从零到上手(上)
  • (三) prometheus + grafana + alertmanager 配置Redis监控
  • (十三)Java springcloud B2B2C o2o多用户商城 springcloud架构 - SSO单点登录之OAuth2.0 根据token获取用户信息(4)...
  • (转)mysql使用Navicat 导出和导入数据库
  • *_zh_CN.properties 国际化资源文件 struts 防乱码等
  • .apk 成为历史!
  • .NET 8 中引入新的 IHostedLifecycleService 接口 实现定时任务
  • .NET CF命令行调试器MDbg入门(三) 进程控制
  • .Net Core与存储过程(一)
  • .Net MVC + EF搭建学生管理系统