当前位置: 首页 > news >正文

pytorch利用hook【钩子】获取torch网络每层结构【附代码】

写本文的目的是为了方便在剪枝中或其他应用中获取网络结构,如何有效的利用hook获取每层的结构来判断是否可以剪枝。

要对网络进行trace,或者获取网络结构,需要知道“grad_fn”。我们知道在pytorch中导数对应的关键词为“grad”。对一个变量我们可以设置requires_grad为True或者False来设置该变量是否求偏导。

目录

grad_fn

 hook获取网络结构


grad_fn

grad_fn: grad_fn用来记录变量变化的过程,方便计算梯度,比如:y = x*2,grad_fn记录了y由x计算的过程。

这里举个例子:设置一个x,并设置其可求导,就也是后面要对他求偏导。

x = torch.ones(2,2, requires_grad=True)
x
Out[4]: 
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)

当我们输出x的grad_fn时输出为None,这是因为这里的tensor x是直接给出的,他并没有经过任何的运算。

print(x.grad_fn)
None

当我们设置一个简单的二次函数y=2 * x,可以得到如下结果,可以看到grad_fn现在显示的MulBackward0,意思就是用了乘法。

y = 2*x
Out[10]: 
tensor([[2., 2.],
        [2., 2.]], grad_fn=<MulBackward0>)

 hook获取网络结构

通过理解了grad_fn,那么我们就可以对网络进行trace,获取每层的网络结构了。

这里以YOLOv5s为例。先附上代码,PRUNABLE_MODULES列表放了Conv,BN以及PReLu。

grad_fn_to_module字典是用来通过grad_fn获取网络每层的结构,也就是如果grad_fn不为None的时候就放入字典中。

visited用来记录每层出现的次数。

这里会用到一个关键的函数:register_forward_hook。

该函数的作用是在不改变torch网络的情况下获取每层的输出。该方法需要传入一个func,其中包含module,inputs,outputs。也就是我下面代码中定义的_record_module_grad_fn。

import torch
import torch.nn as nn
PRUNABLE_MODULES = [ nn.modules.conv._ConvNd, nn.modules.batchnorm._BatchNorm, nn.Linear, nn.PReLU]
grad_fn_to_module = {}  # 如果获取不到是无法剪枝的
visited = {}  # visited会记录每层出现的次数
def _record_module_grad_fn(module, inputs, outputs): # 记录model的grad_fn
    if module not in visited:
        visited[module] = 1
    else:
        visited[module] += 1
    grad_fn_to_module[outputs.grad_fn] = module
model = torch.load('../runs/train/exp/weights/best.pt')['model'].float().cpu()
for para in model.parameters():
    para.requires_grad = True
x = torch.ones(1, 3, 640, 640)
for m in model.modules():
    if isinstance(m, tuple(PRUNABLE_MODULES)):
        hooks = [m.register_forward_hook(_record_module_grad_fn)]
out = model(x)
for hook in hooks:
    hook.remove()
print(grad_fn_to_module)

这里需要注意:在代码运行到out = model(x)之前的过程中,grad_fn_to_module字典一直为空。通过debug也可以看到。

 

 但是!!当我用样例x将我的mode跑了一遍获得out的时候,此刻grad_fn_to_module就开始将网络从头到尾开始记录了。该字典内容如下,可以看到针对第一个key为Convolution操作,所以记录下了Conv2d(3,32,.....)这一层,后面都是如此。

{<MkldnnConvolutionBackward object at 0x000001A38543D408>: Conv2d(3, 32, kernel_size=(6, 6), stride=(2, 2), padding=(2, 2), bias=False), <NativeBatchNormBackward object at 0x000001A38543DD48>: BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), <MkldnnConvolutionBackward object at 0x000001A38543D788>: Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False), <NativeBatchNormBackward object at 0x000001A38543DD08>: BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), <MkldnnConvolutionBackward object at 0x000001A385436688>: Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False), <NativeBatchNormBackward object at 0x000001A3854361C8>: BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), <MkldnnConvolutionBackward object at 0x000001A385436388>: Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False), 

通过上面的方法我们就可以通过hook获取网络的每层结构。这个就可以用来做剪枝操作。

注意!在我的代码中可以看到有这么一行:

for para in model.parameters():
    para.requires_grad = True

我这里将模型的所有参数均设置为可导的,为什么要这里设置呢,这是因为我在对官方代码yolov5 6.0代码剪枝的时候,发现backbone无法剪枝,比如我想对第一层进行剪枝,会给我报KeyError的错误,最后通过仔细研究发现,在官方提供的v5模型中backbone的grad_fn均为None,利用hook无法获得网络,只能获得head部分的结构,下面显示是backbone的grad_fn为None记录的结构,:解决的办法也很简单,就是加入我上面的代码,并设置参数可导即可。

{None: BatchNorm2d(512, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), <MkldnnConvolutionBackward object at 0x0000022594845C08>: Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False), <NativeBatchNormBackward object at 0x00000225948455C8>: BatchNorm2d(256, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), <MkldnnConvolutionBackward object at 0x0000022594845088>: Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False), <NativeBatchNormBackward object at 0x0000022594845108>: BatchNorm2d(128, eps=0.001, momentum=0.03, affine=True, track_running_stats=True), <MkldnnConvolutionBackward object at 0x0000022594845248>: Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False), <NativeBatchNormBackward object at 0x0000022594845608>:  


你学会了吗~ 

         

 

相关文章:

  • 快速了解Nginx的基本介绍
  • 字符串统计:strlen函数的讲解,及其模拟实现
  • Linux——什么是环境变量?
  • 关于软件定时器的一些讨论
  • 睿智的目标检测60——Pytorch搭建YoloV7目标检测平台
  • Vue教程-监听路由ve-router变化,命名视图,路由嵌套,路由参数,路由高亮,router-link,redirect,创建路由,
  • 知识点杂记
  • 微信小程序入门与实战之rpx响应式单位与flex布局
  • @RequestMapping用法详解
  • 【MATLAB教程案例20】关于优化类算法的改进方向探索及matlab仿真对比分析
  • [ vulhub漏洞复现篇 ] Apache Flink目录遍历(CVE-2020-17519)
  • mysql的聚簇索引和非聚簇索引
  • 【React项目】从0搭建项目,项目准备和基础构建
  • markdown数学公式编辑指令大全
  • ContentProvider 之 监听共享数据变化
  • -------------------- 第二讲-------- 第一节------在此给出链表的基本操作
  • #Java异常处理
  • (十五)java多线程之并发集合ArrayBlockingQueue
  • 《Javascript数据结构和算法》笔记-「字典和散列表」
  • 0x05 Python数据分析,Anaconda八斩刀
  • ABAP的include关键字,Java的import, C的include和C4C ABSL 的import比较
  • DataBase in Android
  • Java 内存分配及垃圾回收机制初探
  • Javascript 原型链
  • Java比较器对数组,集合排序
  • PAT A1092
  • Perseus-BERT——业内性能极致优化的BERT训练方案
  • PHP的类修饰符与访问修饰符
  • Promise初体验
  • python 学习笔记 - Queue Pipes,进程间通讯
  • SegmentFault 2015 Top Rank
  • 阿里云前端周刊 - 第 26 期
  • 多线程事务回滚
  • 技术胖1-4季视频复习— (看视频笔记)
  • 码农张的Bug人生 - 初来乍到
  • 如何设计一个比特币钱包服务
  • 入口文件开始,分析Vue源码实现
  • 删除表内多余的重复数据
  • 算法---两个栈实现一个队列
  • 延迟脚本的方式
  • - 转 Ext2.0 form使用实例
  • Play Store发现SimBad恶意软件,1.5亿Android用户成受害者 ...
  • 翻译 | The Principles of OOD 面向对象设计原则
  • (10)Linux冯诺依曼结构操作系统的再次理解
  • (33)STM32——485实验笔记
  • (ISPRS,2023)深度语义-视觉对齐用于zero-shot遥感图像场景分类
  • (八)Spring源码解析:Spring MVC
  • (附源码)spring boot儿童教育管理系统 毕业设计 281442
  • (附源码)ssm户外用品商城 毕业设计 112346
  • (附源码)计算机毕业设计高校学生选课系统
  • (论文阅读31/100)Stacked hourglass networks for human pose estimation
  • (亲测成功)在centos7.5上安装kvm,通过VNC远程连接并创建多台ubuntu虚拟机(ubuntu server版本)...
  • (实战)静默dbca安装创建数据库 --参数说明+举例
  • (新)网络工程师考点串讲与真题详解
  • (转)从零实现3D图像引擎:(8)参数化直线与3D平面函数库