当前位置: 首页 > news >正文

Pytorch获取中间变量的梯度grad

      为了节约显存,pytorch在反向传播的过程中只保留了计算图中的叶子结点的梯度值,而未保留中间节点的梯度

import torch
 
x = torch.tensor(3., requires_grad=True)
y = x ** 2
z = 4 * y
 
z.backward()
print(x.grad)   # tensor(24.)
print(y.grad)   # None

可以看到当进行反向传播后,只保留了x的梯度tensor(24.),而y的梯度没有保留所以为None。

但有时我们需要得到模型中间变量的梯度(如绘制Grad-CAM图时),怎么办呢。

有2种方法

①torch.autograd.grad(outputs, inputs)

import torch
import torch.autograd as autograd
 
x = torch.tensor(3., requires_grad=True)
y = x ** 2
z = 4 * y
 
x_grad = autograd.grad(z, x, retain_graph=True)[0]
y_grad = autograd.grad(z, y, retain_graph=True)[0]
print(x_grad)   # tensor(24.)
print(y_grad)   # tensor(4.)

可以看到此时x和y的梯度都可以获得,使用此方法时不用执行.backward()。

②torch.Tensor.register_hook()

import torch
 
x = torch.tensor(3., requires_grad=True)
y = x ** 2
z = 4 * y
features_grad = 0.
 
 
# 为了读取模型中间参数变量的梯度而定义的辅助函数
def extract(g):
    global features_grad
    features_grad = g
 
 
y.register_hook(extract)
z.backward()
y_grad = features_grad
 
print(x.grad)   # tensor(24.)
print(y_grad)   # tensor(4.)

在执行反向传播之前,对需要求梯度的中间变量执行.register_hook(),便可获得该中间变量的梯度值。

相关文章:

  • Pytorch梯度裁剪 nn.utils.clip_grad_norm_()
  • Layer Normalization(LN) 层标准化
  • TF_CPP_MIN_LOG_LEVEL
  • Python sys.argv
  • pytorch模型可复现设置(cudnn.benchmark 加速卷积运算 cudnn.deterministic)
  • Python sys.stdout
  • Python vars()函数
  • Python类的self
  • Python输出numpy array带逗号和不带逗号
  • center loss 中心损失
  • torch与lua的关系
  • Python类super(super().__init__())
  • 自回归模型(Autoregressive model)(auto)
  • Pytorch tensorboard与tensorboardX的区别
  • Pytorch中的BN和IN(affine仿射, track_running_stats)
  • [分享]iOS开发 - 实现UITableView Plain SectionView和table不停留一起滑动
  • 2018以太坊智能合约编程语言solidity的最佳IDEs
  • Android开发 - 掌握ConstraintLayout(四)创建基本约束
  • Bootstrap JS插件Alert源码分析
  • docker python 配置
  • iBatis和MyBatis在使用ResultMap对应关系时的区别
  • Laravel Mix运行时关于es2015报错解决方案
  • LeetCode18.四数之和 JavaScript
  • leetcode378. Kth Smallest Element in a Sorted Matrix
  • Linux快速配置 VIM 实现语法高亮 补全 缩进等功能
  • MySQL Access denied for user 'root'@'localhost' 解决方法
  • Node 版本管理
  • node 版本过低
  • Web Storage相关
  • WordPress 获取当前文章下的所有附件/获取指定ID文章的附件(图片、文件、视频)...
  • 程序员最讨厌的9句话,你可有补充?
  • 将 Measurements 和 Units 应用到物理学
  • 解决jsp引用其他项目时出现的 cannot be resolved to a type错误
  • 解析 Webpack中import、require、按需加载的执行过程
  • 开发了一款写作软件(OSX,Windows),附带Electron开发指南
  • Redis4.x新特性 -- 萌萌的MEMORY DOCTOR
  • 好程序员大数据教程Hadoop全分布安装(非HA)
  • 数据可视化之下发图实践
  • ​Distil-Whisper:比Whisper快6倍,体积小50%的语音识别模型
  • ​LeetCode解法汇总1276. 不浪费原料的汉堡制作方案
  • ​直流电和交流电有什么区别为什么这个时候又要变成直流电呢?交流转换到直流(整流器)直流变交流(逆变器)​
  • ###C语言程序设计-----C语言学习(3)#
  • %3cli%3e连接html页面,html+canvas实现屏幕截取
  • (AngularJS)Angular 控制器之间通信初探
  • (附源码)ssm智慧社区管理系统 毕业设计 101635
  • (算法)求1到1亿间的质数或素数
  • (一)Mocha源码阅读: 项目结构及命令行启动
  • (一)为什么要选择C++
  • (转)ABI是什么
  • .apk文件,IIS不支持下载解决
  • .bat批处理出现中文乱码的情况
  • .class文件转换.java_从一个class文件深入理解Java字节码结构
  • .NET Remoting学习笔记(三)信道
  • .NET 应用启用与禁用自动生成绑定重定向 (bindingRedirect),解决不同版本 dll 的依赖问题
  • .NET 中 GetHashCode 的哈希值有多大概率会相同(哈希碰撞)