当前位置: 首页 > news >正文

爆改YOLOv8|使用MobileNetV3替换Backbone

1, 本文介绍

MobileNetV3 是 Google 团队于 2019 年提出的,相关论文为《Searching for MobileNetV3》。相较于 MobileNetV2,MobileNetV3 在性能上有显著提升:在 ImageNet 分类任务中,模型的准确率提高了 3.2%,同时计算延时减少了 20%。这些改进得益于对 Block (bneck) 的更新、使用了神经架构搜索 (NAS) 技术来优化参数以及重新设计了耗时层的结构。

关于MobileNetV3的详细介绍可以看论文:https://arxiv.org/pdf/1905.02244v5.pdf

本文将讲解如何将MobileNetV3融合进yolov8

话不多说,上代码!

2, 将MobileNetV3融合进yolov8

2.1 步骤一

首先找到如下的目录'ultralytics/nn/modules',然后在这个目录下创建一个MobileNetV3.py文件,文件名字可以根据你自己的习惯起,然后将MobileNetV3的核心代码复制进去。

"""A from-scratch implementation of MobileNetV3 paper ( for educational purposes ).
PaperSearching for MobileNetV3 - https://arxiv.org/abs/1905.02244v5
author : shubham.aiengineer@gmail.com
"""import torch
from torch import nn
from torchsummary import summaryclass SqueezeExitationBlock(nn.Module):def __init__(self, in_channels: int):"""Constructor for SqueezeExitationBlock.Args:in_channels (int): Number of input channels."""super().__init__()self.pool1 = nn.AdaptiveAvgPool2d(1)self.linear1 = nn.Linear(in_channels, in_channels // 4)  # divide by 4 is mentioned in the paper, 5.3. Large squeeze-and-exciteself.act1 = nn.ReLU()self.linear2 = nn.Linear(in_channels // 4, in_channels)self.act2 = nn.Hardsigmoid()def forward(self, x):"""Forward pass for SqueezeExitationBlock."""identity = xx = self.pool1(x)x = torch.flatten(x, 1)x = self.linear1(x)x = self.act1(x)x = self.linear2(x)x = self.act2(x)x = identity * x[:, :, None, None]return xclass ConvNormActivationBlock(nn.Module):def __init__(self,in_channels: int,out_channels: int,kernel_size: list,stride: int = 1,padding: int = 0,groups: int = 1,bias: bool = False,activation: torch.nn = nn.Hardswish,):"""Constructs a block containing a convolution, batch normalization and activation layerArgs:in_channels (int): number of input channelsout_channels (int): number of output channelskernel_size (list): size of the convolutional kernelstride (int, optional): stride of the convolutional kernel. Defaults to 1.padding (int, optional): padding of the convolutional kernel. Defaults to 0.groups (int, optional): number of groups for depthwise seperable convolution. Defaults to 1.bias (bool, optional): whether to use bias. Defaults to False.activation (torch.nn, optional): activation function. Defaults to nn.Hardswish."""super().__init__()self.conv = nn.Conv2d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,groups=groups,bias=bias,)self.norm = nn.BatchNorm2d(out_channels)self.activation = activation()def forward(self, x):"""Perform forward pass."""x = self.conv(x)x = self.norm(x)x = self.activation(x)return xclass InverseResidualBlock(nn.Module):def __init__(self,in_channels: int,out_channels: int,kernel_size: int,expansion_size: int = 6,stride: int = 1,squeeze_exitation: bool = True,activation: nn.Module = nn.Hardswish,):"""Constructs a inverse residual blockArgs:in_channels (int): number of input channelsout_channels (int): number of output channelskernel_size (int): size of the convolutional kernelexpansion_size (int, optional): size of the expansion factor. Defaults to 6.stride (int, optional): stride of the convolutional kernel. Defaults to 1.squeeze_exitation (bool, optional): whether to add squeeze and exitation block or not. Defaults to True.activation (nn.Module, optional): activation function. Defaults to nn.Hardswish."""super().__init__()self.residual = in_channels == out_channels and stride == 1self.squeeze_exitation = squeeze_exitationself.conv1 = (ConvNormActivationBlock(in_channels, expansion_size, (1, 1), activation=activation)if in_channels != expansion_sizeelse nn.Identity())  # If it's not the first layer, then we need to add a 1x1 convolutional layer to expand the number of channelsself.depthwise_conv = ConvNormActivationBlock(expansion_size,expansion_size,(kernel_size, kernel_size),stride=stride,padding=kernel_size // 2,groups=expansion_size,activation=activation,)if self.squeeze_exitation:self.se = SqueezeExitationBlock(expansion_size)self.conv2 = nn.Conv2d(expansion_size, out_channels, (1, 1), bias=False)  # bias is false because we are using batch normalization, which already has biasself.norm = nn.BatchNorm2d(out_channels)def forward(self, x):"""Perform forward pass."""identity = xx = self.conv1(x)x = self.depthwise_conv(x)if self.squeeze_exitation:x = self.se(x)x = self.conv2(x)x = self.norm(x)if self.residual:x = x + identityreturn xclass MobileNetV3(nn.Module):def __init__(self,n_classes: int = 1000,input_channel: int = 3,config: str = "large",dropout: float = 0.8,):"""Constructs MobileNetV3 architectureArgs:`n_classes`: An integer count of output neuron in last layer, default 1000`input_channel`: An integer value input channels in first conv layer, default is 3.`config`: A string value indicating the configuration of MobileNetV3, either `large` or `small`, default is `large`.`dropout` [0, 1] : A float parameter for dropout in last layer, between 0 and 1, default is 0.8."""super().__init__()# The configuration of MobileNetv3.# input channels, kernel size, expension size, output channels, squeeze exitation, activation, strideRE = nn.ReLUHS = nn.Hardswishconfigs_dict = {"small": ((16, 3, 16, 16, True, RE, 2),(16, 3, 72, 24, False, RE, 2),(24, 3, 88, 24, False, RE, 1),(24, 5, 96, 40, True, HS, 2),(40, 5, 240, 40, True, HS, 1),(40, 5, 240, 40, True, HS, 1),(40, 5, 120, 48, True, HS, 1),(48, 5, 144, 48, True, HS, 1),(48, 5, 288, 96, True, HS, 2),(96, 5, 576, 96, True, HS, 1),(96, 5, 576, 96, True, HS, 1),),"large": ((16, 3, 16, 16, False, RE, 1),(16, 3, 64, 24, False, RE, 2),(24, 3, 72, 24, False, RE, 1),(24, 5, 72, 40, True, RE, 2),(40, 5, 120, 40, True, RE, 1),(40, 5, 120, 40, True, RE, 1),(40, 3, 240, 80, False, HS, 2),(80, 3, 200, 80, False, HS, 1),(80, 3, 184, 80, False, HS, 1),(80, 3, 184, 80, False, HS, 1),(80, 3, 480, 112, True, HS, 1),(112, 3, 672, 112, True, HS, 1),(112, 5, 672, 160, True, HS, 2),(160, 5, 960, 160, True, HS, 1),(160, 5, 960, 160, True, HS, 1),),}self.model = nn.Sequential(ConvNormActivationBlock(input_channel, 16, (3, 3), stride=2, padding=1, activation=nn.Hardswish),)for (in_channels,kernel_size,expansion_size,out_channels,squeeze_exitation,activation,stride,) in configs_dict[config]:self.model.append(InverseResidualBlock(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,expansion_size=expansion_size,stride=stride,squeeze_exitation=squeeze_exitation,activation=activation,))hidden_channels = 576 if config == "small" else 960_out_channel = 1024 if config == "small" else 1280self.model.append(ConvNormActivationBlock(out_channels,hidden_channels,(1, 1),bias=False,activation=nn.Hardswish,))if config == 'small':self.index = [16, 24, 48, 576]else:self.index = [24, 40, 112, 960]self.width_list = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]def forward(self, x):"""Perform forward pass."""results = [None, None, None, None]for model in self.model:x = model(x)if x.size(1) in self.index:position = self.index.index(x.size(1))  # Find the position in the index listresults[position] = x# results.append(x)return resultsif __name__ == "__main__":# Generating Sample imageimage_size = (1, 3, 640, 640)image = torch.rand(*image_size)# Modelmobilenet_v3 = MobileNetV3(config="large")# summary(#     mobilenet_v3,#     input_data=image,#     col_names=["input_size", "output_size", "num_params"],#     device="cpu",#     depth=2,# )out = mobilenet_v3(image)print(out)

2.2 步骤二

在task.py导入我们的模块

2.3 步骤三

如下图标注框所示,添加两行代码

2.4 步骤四

在task.py如下图所示位置,添加标注框内所示代码

2.5 步骤五

在task.py如下图所示位置,添加标注框内所示代码

2.6 步骤六

在task.py如下图所示位置的代码需要替换

替换为下图所示代码

代码

        if verbose:LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}')  # printsave.extend(x % (i + 4 if backbone else i) for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelistlayers.append(m_)if i == 0:ch = []if isinstance(c2, list):ch.extend(c2)if len(c2) != 5:ch.insert(0, 0)else:ch.append(c2)

2.7 步骤七

这次修改在base_model的predict_once方法里面,在task.py的前面部分代码中。

在task.py如下图所示位置的代码需要替换

替换为下图所示代码

代码如下,复制使用

  def _predict_once(self, x, profile=False, visualize=False, embed=None):y, dt, embeddings = [], [], []  # outputsfor m in self.model:if m.f != -1:  # if not from previous layerx = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layersif profile:self._profile_one_layer(m, x, dt)if hasattr(m, 'backbone'):x = m(x)if len(x) != 5:  # 0 - 5x.insert(0, None)for index, i in enumerate(x):if index in self.save:y.append(i)else:y.append(None)x = x[-1]  # 最后一个输出传给下一层else:x = m(x)  # runy.append(x if m.i in self.save else None)  # save outputif visualize:feature_visualization(x, m.type, m.i, save_dir=visualize)if embed and m.i in embed:embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flattenif m.i == max(embed):return torch.unbind(torch.cat(embeddings, 1), dim=0)return x

2.8 步骤八

将下图所示代码注释掉,在ultralytics/utils/torch_utils.py中

修改为下图所示

到这里完成修改,但是这里面细节很多,大家一定要注意,仔细修改,步骤比较多,出现错误很难找出来

复制下面的yaml文件运行即可

yaml文件


# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOP# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, MobileNetV3, []]  # 4- [-1, 1, SPPF, [1024, 5]]  # 5# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 6- [[-1, 3], 1, Concat, [1]]  # 7 cat backbone P4- [-1, 3, C2f, [512]]  # 8- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 9- [[-1, 2], 1, Concat, [1]]  # 10 cat backbone P3- [-1, 3, C2f, [256]]  # 11 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]] # 12- [[-1, 8], 1, Concat, [1]]  # 13 cat head P4- [-1, 3, C2f, [512]]  # 14 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]] # 15- [[-1, 5], 1, Concat, [1]]  # 16 cat head P5- [-1, 3, C2f, [1024]]  # 17 (P5/32-large)- [[11, 14, 17], 1, Detect, [nc]]  # Detect(P3, P4, P5)

# 今天这个修改的地方比较多,大家一定要仔细检查

不知不觉已经看完了哦,动动小手留个点赞吧--_--

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • leetcode 169 多数元素
  • 如何用AP525 测试输入信号的相位,频响,延时,Pop和卡顿
  • 2.Easy-Paas部署
  • 设计模式2个黄鹂鸣翠柳-《分析模式》漫谈23
  • strace 简介和使用
  • 图解 Elasticsearch 的 Fielddata Cache 使用与优化
  • 【HuggingFace Transformers】BertIntermediate 和 BertPooler源码解析
  • 使用 OpenCV 组合和缩放多张图像
  • 【网络基础】DNS协议详解:从背景到解析过程及`dig`工具的使用
  • Java核心概念之(线程、进程、同步、互斥)
  • (十二)Flink Table API
  • 给自闭症孩子家长的建议:携手同行,共筑爱的桥梁
  • Docker常见命令和参数
  • Cmake相关概念
  • HikariCP源码分析之源码环境搭建
  • [译] 理解数组在 PHP 内部的实现(给PHP开发者的PHP源码-第四部分)
  • 【翻译】babel对TC39装饰器草案的实现
  • CentOS学习笔记 - 12. Nginx搭建Centos7.5远程repo
  • css系列之关于字体的事
  • GDB 调试 Mysql 实战(三)优先队列排序算法中的行记录长度统计是怎么来的(上)...
  • Git初体验
  • js作用域和this的理解
  • magento2项目上线注意事项
  • MySQL常见的两种存储引擎:MyISAM与InnoDB的爱恨情仇
  • opencv python Meanshift 和 Camshift
  • webpack项目中使用grunt监听文件变动自动打包编译
  • 分享一个自己写的基于canvas的原生js图片爆炸插件
  • 浅谈JavaScript的面向对象和它的封装、继承、多态
  • 浅谈web中前端模板引擎的使用
  • 数组大概知多少
  • 算法---两个栈实现一个队列
  • 学习ES6 变量的解构赋值
  • ionic入门之数据绑定显示-1
  • MyCAT水平分库
  • Nginx惊现漏洞 百万网站面临“拖库”风险
  • 曜石科技宣布获得千万级天使轮投资,全方面布局电竞产业链 ...
  • ​​​​​​​STM32通过SPI硬件读写W25Q64
  • ​iOS安全加固方法及实现
  • #我与Java虚拟机的故事#连载08:书读百遍其义自见
  • (03)光刻——半导体电路的绘制
  • (C语言)fread与fwrite详解
  • (Mac上)使用Python进行matplotlib 画图时,中文显示不出来
  • (Redis使用系列) Springboot 实现Redis 同数据源动态切换db 八
  • (Redis使用系列) SpringBoot中Redis的RedisConfig 二
  • (层次遍历)104. 二叉树的最大深度
  • (二)基于wpr_simulation 的Ros机器人运动控制,gazebo仿真
  • (仿QQ聊天消息列表加载)wp7 listbox 列表项逐一加载的一种实现方式,以及加入渐显动画...
  • (附源码)springboot 校园学生兼职系统 毕业设计 742122
  • (一)Neo4j下载安装以及初次使用
  • (原创)攻击方式学习之(4) - 拒绝服务(DOS/DDOS/DRDOS)
  • (转)Android学习笔记 --- android任务栈和启动模式
  • (转)EXC_BREAKPOINT僵尸错误
  • (转)淘淘商城系列——使用Spring来管理Redis单机版和集群版
  • *p=a是把a的值赋给p,p=a是把a的地址赋给p。
  • .net core 客户端缓存、服务器端响应缓存、服务器内存缓存