当前位置: 首页 > news >正文

YOLOv8改进 | 细节创新篇 | iAFF迭代注意力特征融合助力多目标细节涨点

一、本文介绍

本文给大家带来的改进机制是iAFF(迭代注意力特征融合),其主要思想是通过改善特征融合过程来提高检测精度。传统的特征融合方法如加法或串联简单,未考虑到特定对象的融合适用性。iAFF通过引入多尺度通道注意力模块(我个人觉得这个改进机制就算融合了注意力机制的求和操作),更好地整合不同尺度和语义不一致的特征。该方法属于细节上的改进,并不影响任何其它的模块,非常适合大家进行融合改进,单独使用也是有一定的涨点效果。

推荐指数:⭐⭐⭐⭐

涨点效果:⭐⭐⭐⭐

专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备    

训练结果对比图-> 

目录

一、本文介绍

二、iAFF的基本框架原理

三、iAFF的核心代码

四、手把手教你添加iAFF

4.1 iAFF添加步骤

4.1.1 步骤一

4.1.2 步骤二

4.1.3 步骤三

五、C2f_iAFF的yaml文件和运行记录

5.1 C2f_iAFF的yaml文件

5.2 C2f_iAFF的训练过程截图 

六、本文总结


二、iAFF的基本框架原理

官方论文地址: 官方论文地址点击即可跳转

官方代码地址: 官方代码地址点击即可跳转


iAFF的主要思想在于通过更精细的注意力机制来改善特征融合,从而增强卷积神经网络。它不仅处理了由于尺度和语义不一致而引起的特征融合问题,还引入了多尺度通道注意力模块,提供了一种统一且通用的特征融合方案。此外,iAFF通过迭代注意力特征融合来解决特征图初始整合可能成为的瓶颈。这种方法使得模型即使在层数或参数较少的情况下,也能取得到较好的效果。 

iAFF的创新点主要包括:

1. 注意力特征融合:提出了一种新的特征融合方式,利用注意力机制来改善传统的简单特征融合方法(如加和或串联)。

2. 多尺度通道注意力模块:解决了在不同尺度上融合特征时出现的问题,特别是语义和尺度不一致的特征融合问题。

3. 迭代注意力特征融合(iAFF):通过迭代地应用注意力机制来改善特征图的初步整合,克服了初步整合可能成为性能瓶颈的问题。

​ 

这张图片是关于所提出的AFF(注意力特征融合)和iAFF(迭代注意力特征融合)的示意图。图中展示了两种结构:

(a) AFF: 展示了一个通过多尺度通道注意力模块(MS-CAM)来融合不同特征的基本框架。特征图X和Y通过MS-CAM和其他操作融合,产生输出Z。

(b) iAFF: 与AFF类似,但添加了迭代结构。在这里,输出Z回馈到输入,与X和Y一起再次经过MS-CAM和融合操作,以进一步细化特征融合过程。

(这两种方法都是文章中提出的我仅使用了iAFF也就是更复杂的版本,大家对于AFF有兴趣的可以按照我的该法进行相似添加即可)


三、iAFF的核心代码

该代码的使用方式需要两个图片,有人去用其替换Concat操作,但是它的两个输入必须是相同shape,但是YOLOv8中我们Concat一般两个输入在图像宽高上都不一样,所以我用其替换Bottlenekc中的残差相加操作,算是一种比较细节上的创新。

import torch
import torch.nn as nndef autopad(k, p=None, d=1):  # kernel, padding, dilation"""Pad to 'same' shape outputs."""if d > 1:k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-sizeif p is None:p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-padreturn pclass Conv(nn.Module):"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""default_act = nn.SiLU()  # default activationdef __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):"""Initialize Conv layer with given arguments including activation."""super().__init__()self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)self.bn = nn.BatchNorm2d(c2)self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()def forward(self, x):"""Apply convolution, batch normalization and activation to input tensor."""return self.act(self.bn(self.conv(x)))def forward_fuse(self, x):"""Perform transposed convolution of 2D data."""return self.act(self.conv(x))class iAFF(nn.Module):'''多特征融合 iAFF'''def __init__(self, channels=64, r=2):super(iAFF, self).__init__()inter_channels = int(channels // r)# 本地注意力self.local_att = nn.Sequential(nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(inter_channels),nn.ReLU(inplace=True),nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(channels),)# 全局注意力self.global_att = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),nn.ReLU(inplace=True),nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),)# 第二次本地注意力self.local_att2 = nn.Sequential(nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(inter_channels),nn.ReLU(inplace=True),nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(channels),)# 第二次全局注意力self.global_att2 = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(inter_channels),nn.ReLU(inplace=True),nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(channels),)self.sigmoid = nn.Sigmoid()def forward(self, x, residual):xa = x + residualxl = self.local_att(xa)xg = self.global_att(xa)xlg = xl + xgwei = self.sigmoid(xlg)xi = x * wei + residual * (1 - wei)xl2 = self.local_att2(xi)xg2 = self.global_att(xi)xlg2 = xl2 + xg2wei2 = self.sigmoid(xlg2)xo = x * wei2 + residual * (1 - wei2)return xoclass AFF(nn.Module):'''多特征融合 AFF'''def __init__(self, channels=64, r=4):super(AFF, self).__init__()inter_channels = int(channels // r)self.local_att = nn.Sequential(nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(inter_channels),nn.ReLU(inplace=True),nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(channels),)self.global_att = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(inter_channels),nn.ReLU(inplace=True),nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(channels),)self.sigmoid = nn.Sigmoid()def forward(self, x, residual):xa = x + residualxl = self.local_att(xa)xg = self.global_att(xa)xlg = xl + xgwei = self.sigmoid(xlg)xo = 2 * x * wei + 2 * residual * (1 - wei)return xoclass C2f_iAFF(nn.Module):"""Faster Implementation of CSP Bottleneck with 2 convolutions."""def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):"""Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,expansion."""super().__init__()self.c = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, 2 * self.c, 1, 1)self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))def forward(self, x):"""Forward pass through C2f layer."""y = list(self.cv1(x).chunk(2, 1))y.extend(m(y[-1]) for m in self.m)return self.cv2(torch.cat(y, 1))def forward_split(self, x):"""Forward pass using split() instead of chunk()."""y = list(self.cv1(x).split((self.c, self.c), 1))y.extend(m(y[-1]) for m in self.m)return self.cv2(torch.cat(y, 1))class Bottleneck(nn.Module):"""Standard bottleneck."""def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):"""Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, andexpansion."""super().__init__()c_ = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, c_, k[0], 1)self.cv2 = Conv(c_, c2, k[1], 1, g=g)self.add = shortcut and c1 == c2self.iAFF = iAFF(c2)def forward(self, x):"""'forward()' applies the YOLO FPN to input data."""if self.add:results =  self.iAFF(x , self.cv2(self.cv1(x)))else:results = self.cv2(self.cv1(x))return resultsif __name__ == '__main__':x = torch.ones(8, 64, 32, 32)channels = x.shape[1]model = C2f_iAFF(channels, channels, True)output = model(x)print(output.shape)


四、手把手教你添加iAFF

4.1 iAFF添加步骤

4.1.1 步骤一

首先我们找到如下的目录'ultralytics/nn/modules',然后在这个目录下创建一个py文件,名字可以根据你自己的习惯起,然后将iAFF的核心代码复制进去。

4.1.2 步骤二

之后我们找到'ultralytics/nn/tasks.py'文件,在其中注册我们的iAFF模块。

首先我们需要在文件的开头导入我们的iAFF模块,如下图所示->

4.1.3 步骤三

我们找到parse_model这个方法,可以用搜索也可以自己手动找,大概在六百多行吧。 我们找到如下的地方,然后将C2f_iAFF添加进去即可,模仿我添加即可。

到此我们就注册成功了,可以修改yaml文件中输入C2f_iAFF使用这个模块了。


五、C2f_iAFF的yaml文件和运行记录

5.1 C2f_iAFF的yaml文件

下面的添加C2f_iAFF是我实验结果的版本。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOP# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f_iAFF, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f_iAFF, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f_iAFF, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f_iAFF, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

5.2 C2f_iAFF的训练过程截图 

下面是添加了C2f_iAFF的训练截图。

大家可以看下面的运行结果和添加的位置所以不存在我发的代码不全或者运行不了的问题大家有问题也可以在评论区评论我看到都会为大家解答(我知道的)。


六、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv8改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备

相关文章:

  • 关于IDEA中Git版本回滚整理
  • 爬虫工作量由小到大的思维转变---<第三十四章 Scrapy 的部署scrapyd+Gerapy>
  • Docker 数据持久化的三种方式
  • JS的this机制
  • 【面试题】写一个睡眠函数
  • leetcode-2.两数相加
  • 【数据倾斜笔记】
  • Pandas中concat的用法
  • JavaScript:正则表达式
  • 【线性代数】决定张成空间的最少向量线性无关吗?
  • uniapp+echarts开发APP版本教程
  • 5.2 显示窗口的内容(二)
  • JUnit 5和Mockito单元测试
  • css 用多个阴影做出光斑投影的效果 box-shadow
  • 学习笔记-MyBatis的工作原理。
  • 【前端学习】-粗谈选择器
  • Java读取Properties文件的六种方法
  • Webpack4 学习笔记 - 01:webpack的安装和简单配置
  • 飞驰在Mesos的涡轮引擎上
  • 工作手记之html2canvas使用概述
  • 悄悄地说一个bug
  • 腾讯视频格式如何转换成mp4 将下载的qlv文件转换成mp4的方法
  • 微信开放平台全网发布【失败】的几点排查方法
  • 主流的CSS水平和垂直居中技术大全
  • ​第20课 在Android Native开发中加入新的C++类
  • #【QT 5 调试软件后,发布相关:软件生成exe文件 + 文件打包】
  • #14vue3生成表单并跳转到外部地址的方式
  • ( 10 )MySQL中的外键
  • (31)对象的克隆
  • (Git) gitignore基础使用
  • (编程语言界的丐帮 C#).NET MD5 HASH 哈希 加密 与JAVA 互通
  • (二)构建dubbo分布式平台-平台功能导图
  • (考研湖科大教书匠计算机网络)第一章概述-第五节1:计算机网络体系结构之分层思想和举例
  • (力扣记录)1448. 统计二叉树中好节点的数目
  • (免费领源码)Java#Springboot#mysql农产品销售管理系统47627-计算机毕业设计项目选题推荐
  • (七)微服务分布式云架构spring cloud - common-service 项目构建过程
  • (使用vite搭建vue3项目(vite + vue3 + vue router + pinia + element plus))
  • (四)Controller接口控制器详解(三)
  • (学习日记)2024.04.04:UCOSIII第三十二节:计数信号量实验
  • (译)2019年前端性能优化清单 — 下篇
  • (转)visual stdio 书签功能介绍
  • .net Application的目录
  • .NET Core 将实体类转换为 SQL(ORM 映射)
  • @EnableConfigurationProperties注解使用
  • @media screen 针对不同移动设备
  • @RequestBody的使用
  • @SuppressWarnings注解
  • [AIGC] 如何建立和优化你的工作流?
  • [Contiki系列论文之2]WSN的自适应通信架构
  • [Docker]三.Docker 部署nginx,以及映射端口,挂载数据卷
  • [iOS]Win8下iTunes无法连接iPhone版本的解决方法
  • [MICROSAR Adaptive] --- Hello Adaptive World
  • [MT8766][Android12] 增加应用安装白名单或者黑名单
  • [MySQL FAQ]系列 -- 如何利用触发器实现账户权限审计
  • [NET].NET Framework 3.5 SP1 真正的离线安装(转)