当前位置: 首页 > news >正文

Pytorch学习——梯度下降和反向传播 03 未完

文章目录

  • 1 梯度是什么
  • 2 判断模型好坏的方法
  • 3 前向传播
  • 4 反向传播
  • 5 Pytorch中反向传播和梯度计算的方法
    • 5.1 前向计算
    • 5.2 梯度计算
    • 5.3 torch.data
    • 5.4 tensor.numpy

1 梯度是什么

通俗的来说就是学习(参数更新)的方向。
简单理解,(对于低维函数来讲)就是导数(或者是变化最快的方向)

2 判断模型好坏的方法

  1. 回归损失
    l o s s = ( Y p r e d i c t − Y t r u e ) 2 loss = (Y_{predict} - Y_{true})^2 loss=(YpredictYtrue)2

  2. 分类损失

l o s s = Y t r u e ⋅ l o g ( Y p r e d i c t ) loss = Y_{true} · log(Y_{predict}) loss=Ytruelog(Ypredict)

3 前向传播

J ( a , b , c ) = 3 ( a + b c ) J(a, b, c) = 3(a + bc) J(a,b,c)=3(a+bc), 令 u = a + v u = a+v u=a+v v = b c v = bc v=bc,把它绘制成计算图可以表示为:

在这里插入图片描述
绘制成计算图之后,可以清楚的看到前向计算的过程。

4 反向传播

对每个节点求偏导可以有:

在这里插入图片描述
反向传播就是一个从右到左的过程,自变量 a , b , c a,b,c a,b,c各自的骗到就是连线上梯度的乘积

在这里插入图片描述

5 Pytorch中反向传播和梯度计算的方法

5.1 前向计算

对于Pytorch中的一个tensor,如果设置它的属性, .require_grad=True ,那么会追踪对于该张量的所有操作。
默认值为None

import torch

x = torch.ones(2, 2, requires_grad=True)
print(x)

y = x+2
print(y)

z = y*y*3
print(z)

输出:

tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward0>)
tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward0>)

总结:

(1)之后的每次计算都会修改其grad_fn属性,用来记录做过的操作;
(2)通过这个函数和grad_fn 可以生成计算图。


  • 注意

为了防止跟踪历史记录,可以将代码包装在with torch.no_grad 中。表示不需要追中这一块的计算。

import torch

x = torch.ones(2, 2, requires_grad=True)
print(x)

y = x+2
print(y)

z = y*y*3
print(z)

with torch.no_grad():
    u = x+y+z

print(u)
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward0>)
tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward0>)
tensor([[31., 31.],
        [31., 31.]])

5.2 梯度计算

可以使用backward() 方法来进行反向传播,计算梯度 out.backward() ,此时能方便求出导数。

调用x.grad() 可以获取导数值。

注意:
在输入是一个标量的情况下,可以调用输出tensorbackward()方法,但是在输出是一个向量的时候,调用backward时要传入其他参数。


5.3 torch.data

当tensor的require_grad 为false的时候,a.data 等同于 a

当tensor的require_grad 为True的时候,a.data 表示仅仅获取其中的数据


5.4 tensor.numpy

require_grad=True 不能够直接转换,需要用torch.detach().numpy()

detach相当于是深拷贝
相当于把原来的tensor数据“抽离”出来,并部影响原来的tensor,然后进行深拷贝,转化为numpy数据。

相关文章:

  • 一次实战压测流程及问题梳理
  • HTTP协议中常见的状态码及其含义
  • Go 语言 设计模式-工厂模式
  • 塑化行业SRM供应商管理系统:缩短采购周期时间,改善供应商采购管理
  • 【原创】基于SpringBoot的灾情救助系统(疫情援助系统)(SpringBoot毕业设计)
  • EasyExcel 导入导出Excel文件
  • python基础语法二(函数、列表)
  • Shopee店铺提高商品转化的方法,你get到了吗
  • Java筑基32-IO流02-节点流处理流
  • 【ffmpeg】音频采集
  • 【负荷预测】基于蚂蚁优化算法的BP神经网络在负荷预测中的应用研究(Matlab完整代码实现)
  • 前端例程20220913:粒子飘落效果动画背景
  • 狂神的springboot课程员工管理系统
  • 散列表(哈希表)概述
  • Linux命令之sed批量替换字符串
  • [rust! #004] [译] Rust 的内置 Traits, 使用场景, 方式, 和原因
  • 《Java编程思想》读书笔记-对象导论
  • Docker 1.12实践:Docker Service、Stack与分布式应用捆绑包
  • ES6系列(二)变量的解构赋值
  • iOS小技巧之UIImagePickerController实现头像选择
  • Java 多线程编程之:notify 和 wait 用法
  • Netty源码解析1-Buffer
  • overflow: hidden IE7无效
  • vuex 笔记整理
  • Webpack 4 学习01(基础配置)
  • 关于Flux,Vuex,Redux的思考
  • 浏览器缓存机制分析
  • Linux权限管理(week1_day5)--技术流ken
  • 东超科技获得千万级Pre-A轮融资,投资方为中科创星 ...
  • #《AI中文版》V3 第 1 章 概述
  • $NOIp2018$劝退记
  • (二)构建dubbo分布式平台-平台功能导图
  • (翻译)Entity Framework技巧系列之七 - Tip 26 – 28
  • (附源码)ssm跨平台教学系统 毕业设计 280843
  • (附源码)基于ssm的模具配件账单管理系统 毕业设计 081848
  • (一)认识微服务
  • (一)为什么要选择C++
  • (转) 深度模型优化性能 调参
  • (转)ABI是什么
  • .Net IE10 _doPostBack 未定义
  • .NET MVC之AOP
  • .NET6使用MiniExcel根据数据源横向导出头部标题及数据
  • .NetCore实践篇:分布式监控Zipkin持久化之殇
  • .set 数据导入matlab,设置变量导入选项 - MATLAB setvaropts - MathWorks 中国
  • @NestedConfigurationProperty 注解用法
  • @SuppressWarnings(unchecked)代码的作用
  • [ Algorithm ] N次方算法 N Square 动态规划解决
  • [ 渗透测试面试篇 ] 渗透测试面试题大集合(详解)(十)RCE (远程代码/命令执行漏洞)相关面试题
  • [AAuto]给百宝箱增加娱乐功能
  • [AS3]URLLoader+URLRequest+JPGEncoder实现BitmapData图片数据保存
  • [CareerCup] 13.1 Print Last K Lines 打印最后K行
  • [cocos creator]EditBox,editing-return事件,清空输入框
  • [hdu 4405] Aeroplane chess [概率DP 期望]
  • [HEOI2013]ALO
  • [Java] 模拟Jdk 以及 CGLib 代理原理