当前位置: 首页 > news >正文

YOLOv8改进 | 注意力机制 | 结合静态和动态上下文信息的注意力机制

秋招面试专栏推荐 :深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转


💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡


专栏目录 :《YOLOv8改进有效涨点》专栏介绍 & 专栏目录 | 目前已有50+篇内容,内含各种Head检测头、损失函数Loss、Backbone、Neck、NMS等创新点改进——点击即可跳转


上下文Transformer(CoT)块是一种新颖的Transformer风格模块,用于视觉识别。它充分利用输入键之间的上下文信息来指导动态注意力矩阵的学习,从而加强了视觉表示的能力。CoT块首先通过3×3卷积对输入键进行上下文化编码,得到输入的静态上下文表示。然后,将编码后的键与输入查询连接起来,通过两个连续的1×1卷积来学习动态的多头注意力矩阵。最后,将静态和动态上下文表示的融合作为输出。CoT块可以轻松替换ResNet架构中的每个3×3卷积,产生一个名为上下文Transformer网络(CoTNet)的Transformer风格的主干网络。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改并将修改后的完整代码放在文章的最后方便大家一键运行小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

专栏地址YOLOv8改进——更新各种有效涨点方法——点击即可跳转

目录

1.原理

2. 将CoTAttention添加到YOLOv8中

2.1 CoTAttention代码实现

2.2 更改init.py文件

2.3 添加yaml文件

2.4 在task.py中进行注册

2.5 执行程序

3. 完整代码分享

4. GFLOPs

5. 进阶

6.总结


1.原理

论文地址:Contextual Transformer Networks for Visual Recognition——点击即可跳转

官方代码:官方代码仓库——点击即可跳转

上下文 Transformer (CoT) 注意力是一种新颖的 Transformer 式模块,旨在增强视觉识别任务。以下是根据提供的文档对其主要原理的解释:

CoT 注意力的主要原理

1. 键的上下文编码

  • CoT 首先使用 3×3 卷积对输入键进行上下文编码。此步骤捕获输入特征图中本地邻居之间的静态上下文,从而产生静态上下文表示。

2. 动态注意力矩阵

  • 然后将上下文化的键与输入查询连接起来。此组合表示通过两个连续的 1×1 卷积来学习动态多头注意力矩阵。此步骤结合了查询-键关系和静态上下文以进行自注意力学习。

3. 动态上下文表示

  • 学习到的注意力矩阵用于加权输入值,从而产生动态上下文表示,从输入中捕获动态上下文。

4. 静态和动态上下文融合

  • 静态和动态上下文表示融合在一起,形成 CoT 块的最终输出。这种组合利用了通过自注意力学习到的局部邻域信息和更广泛的上下文。

优势和实现

  • 与 ResNet 集成

  • CoT 块可以替代 ResNet 架构中的 3×3 卷积,而无需增加参数数量或计算开销,从而创建了一个名为上下文 Transformer 网络 (CoTNet) 的新主干。

  • 性能提升

  • 与传统卷积网络和其他基于 Transformer 的架构相比,CoTNet 在各种任务(包括图像识别、对象检测和实例分割)中表现出色。

与传统自注意力的比较

  • 传统自注意力

  • 根据每个空间位置上的孤立查询键对来测量注意力,通常忽略相邻键之间的丰富上下文。

  • CoT Attention

  • 通过 3×3 卷积整合相邻键的静态上下文,并通过 1×1 卷积考虑组合查询和上下文化键来增强动态上下文学习。

视觉表示

  • 传统自注意力模块

  • 通常涉及使用查询和键之间的成对交互来计算注意力矩阵,而不考虑键之间的空间上下文。

  • CoT 模块

  • 涉及额外的 3×3 卷积步骤以进行键之间的上下文挖掘,然后进行动态注意力矩阵学习和上下文融合。

通过利用静态和动态上下文信息,CoT Attention 可以更全面地理解输入特征图,从而提高视觉识别能力。

2. 将CoTAttention添加到YOLOv8中

2.1 CoTAttention代码实现

关键步骤一: 将下面代码粘贴到在/ultralytics/ultralytics/nn/modules/block.py中,并在该文件的__all__中添加“CoTAttention”

class CoTAttention(nn.Module):def __init__(self, dim=512, kernel_size=3):super().__init__()self.dim = dimself.kernel_size = kernel_sizeself.key_embed = nn.Sequential(nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),nn.BatchNorm2d(dim),nn.SiLU())self.value_embed = nn.Sequential(nn.Conv2d(dim, dim, 1, bias=False),nn.BatchNorm2d(dim))factor = 4self.attention_embed = nn.Sequential(nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),nn.BatchNorm2d(2 * dim // factor),nn.SiLU(),nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1))def forward(self, x):bs, c, h, w = x.shapek1 = self.key_embed(x)  # bs,c,h,wv = self.value_embed(x).view(bs, c, -1)  # bs,c,h,wy = torch.cat([k1, x], dim=1)  # bs,2c,h,watt = self.attention_embed(y)  # bs,c*k*k,h,watt = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)att = att.mean(2, keepdim=False).view(bs, c, -1)  # bs,c,h*wk2 = F.softmax(att, dim=-1) * vk2 = k2.view(bs, c, h, w)return k1 + k2

上下文转换器 (CoT) 注意力机制通过整合输入键之间的上下文信息来增强图像处理。以下是使用 CoT 注意力机制进行图像处理的主要工作流程的详细说明:

使用 CoT 注意力机制进行图像处理的主要工作流程

1. 输入特征图

  • 从大小为 (H \times W \times C) 的输入特征图 (X) 开始,其中 (H) 为高度,(W) 为宽度,(C) 为通道数。

2. 键的上下文编码

  • 对输入键应用 3×3 卷积以捕获本地邻居之间的静态上下文。这会产生一个表示上下文化键的新特征图: K_{contextual} = \text{Conv3x3}(X)

3. 与查询连接

  • 将上下文化键 (K{contextual}) 与输入查询 (Q) 连接起来。这种组合表示结合了原始输入和上下文信息: Q{concat} = \text{Concat}(Q, K_{contextual})

4. 动态注意矩阵学习

  • 将连接表示 (Q{concat}) 传递到两个连续的 1×1 卷积,以学习动态多头注意矩阵: A{dynamic} = \text{Conv1x1}(\text{Conv1x1}(Q_{concat}))

5. 动态上下文表示

  • 使用学习到的注意矩阵 (A{dynamic}) 加权输入值 (V),产生动态上下文表示。此步骤根据查询和键之间的关系捕获动态上下文:V{dynamic} = A_{dynamic} \cdot V

6. 静态和动态上下文融合

  • 将静态上下文表示 (K{contextual}) 与动态上下文表示 (V{dynamic}) 相结合以形成最终输出。此融合利用了局部和更广泛的上下文信息:\text{Output} = \text{Fuse}(K{contextual}, V{dynamic})

详细步骤

3×3 卷积用于上下文编码

  • 3×3 卷积扫描输入特征图以捕获相邻键之间的空间关系,从而创建反映局部依赖关系的静态上下文。

1×1 卷积用于注意力矩阵:

  • 两个连续的 1×1 卷积对连接的查询和上下文化键进行操作,以学习动态注意力矩阵,这有助于根据上下文相关性对输入值进行加权。

注意力机制:

  • CoT 中的注意力机制与传统的自注意力不同,它将静态上下文纳入动态注意力计算中,从而产生更强大、更能感知上下文的注意力矩阵。

融合机制:

  • 最后的融合步骤结合了静态和动态表示,确保模型既能从局部上下文(通过 3×3 卷积)中受益,也能从动态交互(通过学习注意力)中受益。

2.2 更改init.py文件

关键步骤二:修改modules文件夹下的__init__.py文件,先导入函数

然后在下面的__all__中声明函数

2.3 添加yaml文件

关键步骤三:在/ultralytics/ultralytics/cfg/models/v8下面新建文件yolov8_CoTA.yaml文件,粘贴下面的内容

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [ 0.33, 0.25, 1024 ]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [ -1, 1, Conv, [ 64, 3, 2 ] ]  # 0-P1/2- [ -1, 1, Conv, [ 128, 3, 2 ] ]  # 1-P2/4- [ -1, 3, C2f, [ 128, True ] ]- [ -1, 1, Conv, [ 256, 3, 2 ] ]  # 3-P3/8- [ -1, 6, C2f, [ 256, True ] ]- [ -1, 1, Conv, [ 512, 3, 2 ] ]  # 5-P4/16- [ -1, 6, C2f, [ 512, True ] ]- [ -1, 1, Conv, [ 1024, 3, 2 ] ]  # 7-P5/32- [ -1, 3, C2f, [ 1024, True ] ]- [ -1, 1, SPPF, [ 1024, 5 ] ]  # 9# YOLOv8.0n head
head:- [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ]- [ [ -1, 6 ], 1, Concat, [ 1 ] ]  # cat backbone P4- [ -1, 3, C2f, [ 512 ] ]  # 12- [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ]- [ [ -1, 4 ], 1, Concat, [ 1 ] ]  # cat backbone P3- [ -1, 3, C2f, [ 256 ] ]  # 15 (P3/8-small)- [ -1, 1, Conv, [ 256, 3, 2 ] ]- [ [ -1, 12 ], 1, Concat, [ 1 ] ]  # cat head P4- [ -1, 3, C2f, [ 512 ] ]  # 18 (P4/16-medium)- [ -1, 1, CoTAttention, [ 512 ] ]    # CoTAttention https://arxiv.org/abs/2107.12292- [ -1, 1, Conv, [ 512, 3, 2 ] ]- [ [ -1, 9 ], 1, Concat, [ 1 ] ]  # cat head P5- [ -1, 3, C2f, [ 1024 ] ]  # 21 (P5/32-large)- [ -1, 1, CoTAttention, [ 1024 ] ]- [ [ 15, 19, 23 ], 1, Detect, [ nc ] ]  # Detect(P3, P4, P5)

温馨提示:本文只是对yolov8基础上添加模块,如果要对yolov8n/l/m/x进行添加则只需要指定对应的depth_multiple 和 width_multiple。


# YOLOv8n
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
max_channels: 1024 # max_channels# YOLOv8s
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
max_channels: 1024 # max_channels# YOLOv8l 
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
max_channels: 512 # max_channels# YOLOv8m
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple
max_channels: 768 # max_channels# YOLOv8x
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple
max_channels: 512 # max_channels

2.4 在task.py中进行注册

关键步骤四:在task.py的parse_model函数中进行注册,

elif m is CoTAttention:c1, c2 = ch[f], args[0]if c2 != nc:c2 = make_divisible(min(c2, max_channels) * width, 8)args = [c1, *args[1:]]

2.5 执行程序

关键步骤五:在ultralytics文件中新建train.py,将model的参数路径设置为yolov8_CoTA.yaml的路径即可

from ultralytics import YOLO# Load a model
# model = YOLO('yolov8n.yaml')  # build a new model from YAML
# model = YOLO('yolov8n.pt')  # load a pretrained model (recommended for training)model = YOLO(r'/projects/ultralytics/ultralytics/cfg/models/v8/yolov8_CoTA.yaml')  # build from YAML and transfer weights# Train the model
model.train( batch=16)

 🚀运行程序,如果出现下面的内容则说明添加成功🚀

3. 完整代码分享

https://pan.baidu.com/s/1wyZnAKNVMhMhuaH2lDJ7jA?pwd=yzye 

 提取码:yzye 

4. GFLOPs

关于GFLOPs的计算方式可以查看:百面算法工程师 | 卷积基础知识——Convolution

未改进的YOLOv8nGFLOPs

img

改进后的GFLOPs

5. 进阶

可以结合损失函数或者卷积模块进行多重改进

6.总结

上下文变换注意 (CoTAttention) 是一种新颖的机制,旨在通过整合静态和动态上下文信息来增强视觉识别任务。它首先对输入键应用 3×3 卷积,以捕获本地邻居之间的静态上下文。然后将上下文化的键与输入查询连接起来,并将此组合表示通过两个连续的 1×1 卷积来学习动态多头注意矩阵。此矩阵用于加权输入值,从而产生动态上下文表示。最后,将静态和动态上下文表示融合以形成最终输出。此过程使 CoTAttention 能够利用通过自注意力学习到的局部邻域信息和更广泛的上下文,从而提高图像识别、对象检测和实例分割任务的性能。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 2024年6月份找工作和面试总结
  • RabbitMQ 更改服务端口号
  • 力扣1895.最大的幻方
  • 51单片机嵌入式开发:3、STC89C52操作8八段式数码管原理
  • NativeMemoryTracking查看java内存信息
  • udp发送数据如果超过1个mtu时,抓包所遇到的问题记录说明
  • 9 redis,memcached,nginx网络组件
  • 单/多线程--协程--异步爬虫
  • 洛谷 P2141 [NOIP2014 普及组] 珠心算测验
  • Harris点云关键点检测
  • 三、docker配置阿里云镜像仓库并配置docker代理
  • 使用瀚高数据库开发管理工具进行数据的备份与恢复---国产瀚高数据库工作笔记008
  • 使用Python绘制堆积柱形图
  • Ubuntu22.04使用/etc/rc.local开机启动程序
  • Stable Diffusion:最全详细图解
  • 时间复杂度分析经典问题——最大子序列和
  • $translatePartialLoader加载失败及解决方式
  • (三)从jvm层面了解线程的启动和停止
  • [js高手之路]搞清楚面向对象,必须要理解对象在创建过程中的内存表示
  • 【108天】Java——《Head First Java》笔记(第1-4章)
  • 【402天】跃迁之路——程序员高效学习方法论探索系列(实验阶段159-2018.03.14)...
  • 【翻译】Mashape是如何管理15000个API和微服务的(三)
  • JS 面试题总结
  • Python学习之路13-记分
  • Quartz初级教程
  • Redis 中的布隆过滤器
  • Sequelize 中文文档 v4 - Getting started - 入门
  • Vim Clutch | 面向脚踏板编程……
  • 服务器之间,相同帐号,实现免密钥登录
  • 力扣(LeetCode)965
  • 面试题:给你个id,去拿到name,多叉树遍历
  • 容器化应用: 在阿里云搭建多节点 Openshift 集群
  • 使用putty远程连接linux
  • 微信开源mars源码分析1—上层samples分析
  • 移动端 h5开发相关内容总结(三)
  • 用Visual Studio开发以太坊智能合约
  • 与 ConTeXt MkIV 官方文档的接驳
  • - 语言经验 - 《c++的高性能内存管理库tcmalloc和jemalloc》
  • ​520就是要宠粉,你的心头书我买单
  • %3cscript放入php,跟bWAPP学WEB安全(PHP代码)--XSS跨站脚本攻击
  • (1)SpringCloud 整合Python
  • (3)(3.5) 遥测无线电区域条例
  • (aiohttp-asyncio-FFmpeg-Docker-SRS)实现异步摄像头转码服务器
  • (function(){})()的分步解析
  • (js)循环条件满足时终止循环
  • (NSDate) 时间 (time )比较
  • (第8天)保姆级 PL/SQL Developer 安装与配置
  • (二)构建dubbo分布式平台-平台功能导图
  • (接上一篇)前端弄一个变量实现点击次数在前端页面实时更新
  • (六)DockerCompose安装与配置
  • (七)Activiti-modeler中文支持
  • (四)进入MySQL 【事务】
  • (转)拼包函数及网络封包的异常处理(含代码)
  • (自适应手机端)响应式服装服饰外贸企业网站模板
  • . Flume面试题