当前位置: 首页 > news >正文

YOLOv9改进策略【模型轻量化】| ShufflenetV2,通过通道划分构建高效网络

一、本文介绍

本文记录的是基于ShufflenetV2的YOLOv9目标检测轻量化改进方法研究FLOPs是评价模型复杂独的重要指标,但其无法考虑到模型的内存访问成本和并行度,因此本文在YOLOv9的基础上引入ShufflenetV2使其在在保持准确性的同时提高模型的运行效率

模型参数量计算量推理速度(bs=32)
YOLOv9-c50.69M236.6GFLOPs32.1ms
Improved42.88M194.5GFLOPs23.2ms

文章目录

  • 一、本文介绍
  • 二、ShuffleNet V2设计原理
  • 三、ShuffleNet V2基础模块的实现代码
  • 四、添加步骤
    • 4.1 修改common.py
    • 4.2 修改yolo.py
  • 五、yaml模型文件
    • 5.1 模型改进⭐
  • 六、成功运行结果


二、ShuffleNet V2设计原理

ShuffleNet V2是一种高效的卷积神经网络架构,其模型结构及优势如下:

  1. 模型结构
    • 回顾ShuffleNet v1ShuffleNet是一种广泛应用于低端设备的先进网络架构,为增加在给定计算预算下的特征通道数量,采用了点组卷积和瓶颈结构,但这增加了内存访问成本(MAC),且过多的组卷积和元素级“Add”操作也存在问题。
    • 引入Channel Split和ShuffleNet V2:为解决上述问题,引入了名为Channel Split的简单操作。在每个单元开始时,将 c c c个特征通道的输入分为两个分支,分别具有 c − c ′ c - c' cc c ′ c' c个通道。一个分支保持不变,另一个分支由三个具有相同输入和输出通道的卷积组成,以满足G1(平衡卷积,即相等的通道宽度可最小化MAC)。两个 1 × 1 1 \times 1 1×1卷积不再是组式的,这部分是为了遵循G2(避免过多的组卷积增加MAC),部分是因为拆分操作已经产生了两个组。卷积后,两个分支连接,通道数量保持不变,并使用与ShuffleNet v1相同的“通道洗牌”操作来实现信息通信。对于空间下采样,单元进行了略微修改,删除了通道拆分操作,使输出通道数量加倍。
    • 整体网络结构:通过反复堆叠构建块来构建整个网络,设置 c ′ = c / 2 c' = c/2 c=c/2,整体网络结构与ShuffleNet v1相似,并在全局平均池化之前添加了一个额外的 1 × 1 1 \times 1 1×1卷积层来混合特征。
  2. 优势
    • 高效且准确:遵循了高效网络设计的所有准则,每个构建块的高效率使其能够使用更多的特征通道和更大的网络容量,并且在每个块中,一半的特征通道直接通过块并加入下一个块,实现了一种特征重用模式,类似于DenseNet,但更高效。
    • 速度优势明显:在与其他网络架构的比较中,ShuffleNet v2在速度方面表现出色,特别是在GPU上明显快于其他网络(如MobileNet v2、ShuffleNet v1和Xception)。在ARM上,ShuffleNet v1、Xception和ShuffleNet v2的速度相当,但MobileNet v2较慢,这是因为MobileNet v2的MAC较高。
    • 兼容性好:可以与其他技术(如Squeeze - and - excitation模块)结合进一步提高性能。

论文:https://arxiv.org/pdf/1807.11164.pdf
源码:https://gitcode.com/gh_mirrors/sh/ShuffleNet-Series/blob/master/ShuffleNetV2/blocks.py?utm_source=csdn_github_accelerator&isLogin=1

三、ShuffleNet V2基础模块的实现代码

ShuffleNet V2基础模块的实现代码如下:

def channel_shuffle(x, groups):batchsize, num_channels, height, width = x.data.size()channels_per_group = num_channels // groups# reshapex = x.view(batchsize, groups,channels_per_group, height, width)x = torch.transpose(x, 1, 2).contiguous()# flattenx = x.view(batchsize, -1, height, width)return xclass conv_bn_relu_maxpool(nn.Module):def __init__(self, c1, c2):  # ch_in, ch_outsuper(conv_bn_relu_maxpool, self).__init__()self.conv = nn.Sequential(nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False),nn.BatchNorm2d(c2),nn.ReLU(inplace=True),)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)def forward(self, x):return self.maxpool(self.conv(x))class Shuffle_Block(nn.Module):def __init__(self, inp, oup, stride):super(Shuffle_Block, self).__init__()if not (1 <= stride <= 3):raise ValueError('illegal stride value')self.stride = stridebranch_features = oup // 2assert (self.stride != 1) or (inp == branch_features << 1)if self.stride > 1:self.branch1 = nn.Sequential(self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),nn.BatchNorm2d(inp),nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),nn.BatchNorm2d(branch_features),nn.ReLU(inplace=True),)self.branch2 = nn.Sequential(nn.Conv2d(inp if (self.stride > 1) else branch_features,branch_features, kernel_size=1, stride=1, padding=0, bias=False),nn.BatchNorm2d(branch_features),nn.ReLU(inplace=True),self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),nn.BatchNorm2d(branch_features),nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),nn.BatchNorm2d(branch_features),nn.ReLU(inplace=True),)@staticmethoddef depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)def forward(self, x):if self.stride == 1:x1, x2 = x.chunk(2, dim=1)  # 按照维度1进行splitout = torch.cat((x1, self.branch2(x2)), dim=1)else:out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)out = channel_shuffle(out, 2)return out

四、添加步骤

4.1 修改common.py

此处需要修改的文件是models/common.py

common.py中定义了网络结构的通用模块,我们想要加入新的模块就只需要将模块代码放到这个文件内即可。

此时需要将上方实现的代码添加到common.py中。

在这里插入图片描述

注意❗:在4.2小节中的yolo.py文件中需要声明的模块名称为:conv_bn_relu_maxpoolShuffle_Block

4.2 修改yolo.py

此处需要修改的文件是models/yolo.py

yolo.py用于函数调用,我们只需要将common.py中定义的新的模块名添加到parse_model函数下即可。

conv_bn_relu_maxpool模块以及Shuffle_Block模块添加后如下:

在这里插入图片描述


五、yaml模型文件

5.1 模型改进⭐

在代码配置完成后,配置模型的YAML文件。

此处以models/detect/yolov9-c.yaml为例,在同目录下创建一个用于自己数据集训练的模型文件yolov9-c-shufflenetv2.yaml

yolov9-c.yaml中的内容复制到yolov9-c-shufflenetv2.yaml文件下,修改nc数量等于自己数据中目标的数量。

📌 模型的修改方法是将YOLOv9的骨干网络替换成Shufflenet V2ShuffleNet V2 在设计上注重减少内存访问成本并提高并行度,这有助于在保持准确性的同时提高模型的运行效率。相比YOLOv9原骨干网络,ShuffleNet V2 具有更低的计算复杂度,能够在相同或更少的计算资源下完成推理,对于实时性要求较高的任务具有重要意义。

结构如下:

# YOLOv9# parameters
nc: 1  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()# anchors
anchors: 3# YOLOv9 backbone
backbone:[[-1, 1, Silence, []],  # conv down[-1, 1, conv_bn_relu_maxpool, [64, 3, 2]],  # 1-P1/2# conv down[-1, 1, Shuffle_Block, [ 128, 2 ]],  # 2-P2/4[-1, 3, Shuffle_Block, [ 128, 1 ]],  # 3[-1, 1, Shuffle_Block, [ 256, 2 ]],  # 4-P4/16 [-1, 7, Shuffle_Block, [ 256, 1 ]],  # 5[-1, 1, Shuffle_Block, [ 512, 2 ]],  # 6-P4/16[-1, 3, Shuffle_Block, [ 512, 1 ]],  # 7]# YOLOv9 head
head:[# elan-spp block[-1, 1, SPPELAN, [512, 256]],  # 10# up-concat merge[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 5], 1, Concat, [1]],  # cat backbone P4# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13# up-concat merge[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 3], 1, Concat, [1]],  # cat backbone P3# elan-2 block[-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)# avg-conv-down merge[-1, 1, ADown, [256]],[[-1, 11], 1, Concat, [1]],  # cat head P4# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 19 (P4/16-medium)# avg-conv-down merge[-1, 1, ADown, [512]],[[-1, 8], 1, Concat, [1]],  # cat head P5# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 22 (P5/32-large)# multi-level reversible auxiliary branch# routing[3, 1, CBLinear, [[256]]], # 23[5, 1, CBLinear, [[256, 512]]], # 24[7, 1, CBLinear, [[256, 512, 512]]], # 25# conv down[0, 1, Conv, [64, 3, 2]],  # 26-P1/2# conv down[-1, 1, Conv, [128, 3, 2]],  # 27-P2/4# elan-1 block[-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 28# avg-conv down fuse[-1, 1, ADown, [256]],  # 29-P3/8[[21, 22, 23, -1], 1, CBFuse, [[0, 0, 0]]], # 30  # elan-2 block[-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 31# avg-conv down fuse[-1, 1, ADown, [512]],  # 32-P4/16[[22, 23, -1], 1, CBFuse, [[1, 1]]], # 33 # elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 34# avg-conv down fuse[-1, 1, ADown, [512]],  # 35-P5/32[[23, -1], 1, CBFuse, [[2]]], # 36# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 37# detection head# detect[[29, 32, 35, 14, 17, 20], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)]

六、成功运行结果

分别打印网络模型可以看到Shuffle_Block已经加入到模型中,并可以进行训练了。

yolov9-c-shufflenetv2

                 from  n    params  module                                  arguments                     0                -1  1         0  models.common.Silence                   []                            1                -1  1      1856  models.common.conv_bn_relu_maxpool      [3, 64]                       2                -1  1     14080  models.common.Shuffle_Block             [64, 128, 2]                  3                -1  3     27456  models.common.Shuffle_Block             [128, 128, 1]                 4                -1  1     52736  models.common.Shuffle_Block             [128, 256, 2]                 5                -1  7    242816  models.common.Shuffle_Block             [256, 256, 1]                 6                -1  1    203776  models.common.Shuffle_Block             [256, 512, 2]                 7                -1  3    404736  models.common.Shuffle_Block             [512, 512, 1]                 8                -1  1    656896  models.common.SPPELAN                   [512, 512, 256]               9                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          10           [-1, 5]  1         0  models.common.Concat                    [1]                           11                -1  1   2988544  models.common.RepNCSPELAN4              [768, 512, 512, 256, 1]       12                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          13           [-1, 3]  1         0  models.common.Concat                    [1]                           14                -1  1    814336  models.common.RepNCSPELAN4              [640, 256, 256, 128, 1]       15                -1  1    164352  models.common.ADown                     [256, 256]                    16          [-1, 11]  1         0  models.common.Concat                    [1]                           17                -1  1   2988544  models.common.RepNCSPELAN4              [768, 512, 512, 256, 1]       18                -1  1    656384  models.common.ADown                     [512, 512]                    19           [-1, 8]  1         0  models.common.Concat                    [1]                           20                -1  1   3119616  models.common.RepNCSPELAN4              [1024, 512, 512, 256, 1]      21                 3  1     33024  models.common.CBLinear                  [128, [256]]                  22                 5  1    197376  models.common.CBLinear                  [256, [256, 512]]             23                 7  1    656640  models.common.CBLinear                  [512, [256, 512, 512]]        24                 0  1      1856  models.common.Conv                      [3, 64, 3, 2]                 25                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]               26                -1  1    212864  models.common.RepNCSPELAN4              [128, 256, 128, 64, 1]        27                -1  1    164352  models.common.ADown                     [256, 256]                    28  [21, 22, 23, -1]  1         0  models.common.CBFuse                    [[0, 0, 0]]                   29                -1  1    847616  models.common.RepNCSPELAN4              [256, 512, 256, 128, 1]       30                -1  1    656384  models.common.ADown                     [512, 512]                    31      [22, 23, -1]  1         0  models.common.CBFuse                    [[1, 1]]                      32                -1  1   2857472  models.common.RepNCSPELAN4              [512, 512, 512, 256, 1]       33                -1  1    656384  models.common.ADown                     [512, 512]                    34          [23, -1]  1         0  models.common.CBFuse                    [[2]]                         35                -1  1   2857472  models.common.RepNCSPELAN4              [512, 512, 512, 256, 1]       36[29, 32, 35, 14, 17, 20]  1  21542822  models.yolo.DualDDetect                 [1, [512, 512, 512, 256, 512, 512]]
yolov9-c-shufflenetv2 summary: 870 layers, 43094374 parameters, 43094342 gradients, 195.9 GFLOPs

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 基于matlab的行人和车辆检测系统
  • 模型 ACT心理灵活六边形
  • 不同浏览器JS对数组末尾多余的逗号的处理
  • AUTOSAR_EXP_ARAComAPI的5章笔记(1)
  • 八皇后问题代码实现(java,递归)
  • 选科组合(入门)
  • 微信陷阱丨警惕“间谍网勾”的迷魂汤
  • nginx部署前端vue项目
  • Python | Leetcode Python题解之第387题字符串中的第一个唯一字符
  • Spring之配置类解析源码解析
  • [数据集][目标检测]课堂行行为检测数据集VOC+YOLO格式4065张12类别
  • Python中排序算法之插入排序
  • LeetCode - 12 整数转罗马数字
  • 快速了解Git 文件的四种状态及其操作指令、如何忽略文件
  • 【随手记】excel中的text函数使用
  • php的引用
  • [nginx文档翻译系列] 控制nginx
  • Java,console输出实时的转向GUI textbox
  • JavaScript服务器推送技术之 WebSocket
  • js中的正则表达式入门
  • python docx文档转html页面
  • vue.js框架原理浅析
  • Work@Alibaba 阿里巴巴的企业应用构建之路
  • 百度贴吧爬虫node+vue baidu_tieba_crawler
  • 从零搭建Koa2 Server
  • 码农张的Bug人生 - 初来乍到
  • 扫描识别控件Dynamic Web TWAIN v12.2发布,改进SSL证书
  • 用jQuery怎么做到前后端分离
  • 再次简单明了总结flex布局,一看就懂...
  • 做一名精致的JavaScripter 01:JavaScript简介
  • 曾刷新两项世界纪录,腾讯优图人脸检测算法 DSFD 正式开源 ...
  • ​LeetCode解法汇总2304. 网格中的最小路径代价
  • ​queue --- 一个同步的队列类​
  • ​马来语翻译中文去哪比较好?
  • ​虚拟化系列介绍(十)
  • #、%和$符号在OGNL表达式中经常出现
  • (1)(1.8) MSP(MultiWii 串行协议)(4.1 版)
  • (1综述)从零开始的嵌入式图像图像处理(PI+QT+OpenCV)实战演练
  • (2/2) 为了理解 UWP 的启动流程,我从零开始创建了一个 UWP 程序
  • (CPU/GPU)粒子继承贴图颜色发射
  • (delphi11最新学习资料) Object Pascal 学习笔记---第8章第5节(封闭类和Final方法)
  • (十五)devops持续集成开发——jenkins流水线构建策略配置及触发器的使用
  • (算法)N皇后问题
  • (算法)大数的进制转换
  • (转)Android学习笔记 --- android任务栈和启动模式
  • (转)fock函数详解
  • (转)Windows2003安全设置/维护
  • (转)负载均衡,回话保持,cookie
  • (转)利用PHP的debug_backtrace函数,实现PHP文件权限管理、动态加载 【反射】...
  • .NET C# 配置 Options
  • .net framework profiles /.net framework 配置
  • .NET/C#⾯试题汇总系列:集合、异常、泛型、LINQ、委托、EF!(完整版)
  • .NET3.5下用Lambda简化跨线程访问窗体控件,避免繁复的delegate,Invoke(转)
  • /etc/apt/sources.list 和 /etc/apt/sources.list.d
  • [ 手记 ] 关于tomcat开机启动设置问题