
Paper Reading: YOLO-World: Real-Time Open-Vocabulary Object Detection

Key ideas:

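  • A real-time, open-vocabulary YOLO detector.
  • The Re-parameterizable Vision-Language Path Aggregation Network (RepVL-PAN), a vision-language feature aggregation module that can be re-parameterized for inference.
  • What makes it real-time: a lightweight detector, plus an offline vocabulary whose text embeddings are computed ahead of time and re-parameterized into the model at inference.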

Method

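Pre-training scheme: instance annotations are reformulated as region-text pairs, and the model is pre-trained on large-scale detection, grounding, and image-text data.
Model architecture: YOLO-World consists of a YOLO detector, a text encoder, and RepVL-PAN, and uses cross-modal fusion to enhance both the text and the image representations.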

Basic structure

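  • YOLO detector: YOLOv8, i.e. a Darknet backbone + PAN neck + detection head.
  • Text Encoder: the CLIP text encoder, with an n-gram algorithm to extract noun phrases when the input is a caption rather than a list of category names.
  • Text Contrastive Head: two 3x3 convolutions regress the bounding boxes and the object embeddings. The similarity between each object embedding and the text embeddings is computed and used for the contrastive loss.
  • Inference with Offline Vocabulary: the prompts are fixed in advance and their embeddings are computed ahead of time, then re-parameterized into the PAN module.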
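
The similarity step of the Text Contrastive Head and the idea of an offline vocabulary can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the repository's implementation: the function names, the 2-D toy embeddings, and the `alpha` / `beta` scale and shift arguments are all introduced here for the example. The point it shows is that the prompt embeddings are normalized once, and every later detection only needs a dot product against the cached vectors.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit L2 norm (the epsilon avoids division by zero)."""
    norm = math.sqrt(sum(v * v for v in vec)) + 1e-12
    return [v / norm for v in vec]


def build_offline_vocabulary(text_embeddings):
    """Normalize the prompt embeddings once, before inference starts."""
    return [l2_normalize(w) for w in text_embeddings]


def region_text_similarity(object_embedding, vocabulary, alpha=1.0, beta=0.0):
    """Cosine similarity of one object embedding against every cached prompt.

    alpha and beta are assumed scale/shift factors (hypothetical defaults).
    """
    e = l2_normalize(object_embedding)
    return [alpha * sum(a * b for a, b in zip(e, w)) + beta for w in vocabulary]


# Hypothetical 2-D embeddings for two prompts, e.g. "cat" and "dog".
vocabulary = build_offline_vocabulary([[1.0, 0.0], [0.0, 2.0]])
scores = region_text_similarity([3.0, 4.0], vocabulary)
best_prompt = max(range(len(scores)), key=scores.__getitem__)
```

Because the vocabulary is built before inference, the text encoder does not have to run at detection time in this sketch; only `region_text_similarity` is evaluated per object.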

3.3. Re-parameterizable Vision-Language PAN

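RepVL-PAN is built on the multi-scale image features {C3, C4, C5} and uses top-down and bottom-up paths to strengthen the interaction between image features and text features.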

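  • Text-guided CSPLayer (text -> image): the text embeddings produce max-sigmoid attention weights that reweight the neck features, and the weighted features are concatenated with the original features.
  • Image-Pooling Attention (image -> text): the multi-level image features are pooled, the text embeddings attend to them, and the result is added back to the text embeddings.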
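
The max-sigmoid weighting used by the Text-guided CSPLayer can be illustrated with a small dependency-free sketch. It assumes a feature map stored as nested lists and hypothetical toy prompt embeddings. It also leaves out what the RepConvMaxSigmoidAttnBlock listing in this post adds on top (multiple heads, division by the square root of the head channels, a learned per-head bias, and a projection convolution). Each pixel is scored against every prompt, the strongest response is kept, and its sigmoid becomes the weight for that pixel.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def max_sigmoid_attention(feature_map, text_embeddings):
    """Reweight each pixel by sigmoid(max_j <pixel, text_j>).

    feature_map: H x W grid of C-dimensional pixel vectors (nested lists).
    text_embeddings: N vectors with the same dimension C.
    """
    output = []
    for row in feature_map:
        out_row = []
        for pixel in row:
            # Response of this pixel to every prompt; keep the strongest one.
            logits = [sum(p * t for p, t in zip(pixel, text))
                      for text in text_embeddings]
            weight = sigmoid(max(logits))
            out_row.append([p * weight for p in pixel])
        output.append(out_row)
    return output


# Hypothetical 1 x 2 feature map with 2 channels, and two prompt embeddings.
features = [[[1.0, 0.0], [0.0, 0.0]]]
prompts = [[2.0, 0.0], [0.0, 1.0]]
weighted = max_sigmoid_attention(features, prompts)
```

A pixel that matches some prompt strongly keeps most of its magnitude, while a pixel that matches nothing is scaled by about one half or less, which is how the text guides the image features in this simplified form.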

Results

Fast and accurate: on LVIS, YOLO-World reaches 35.4 AP at 52 FPS on a V100.

Core code:

# Imports assumed from the referenced yolo_bricks.py (MMCV / MMDetection / MMEngine).
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule, Linear
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from mmengine.model import BaseModule


class RepConvMaxSigmoidAttnBlock(BaseModule):
    """Max Sigmoid attention block."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 embed_channels: int,
                 guide_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 num_heads: int = 1,
                 use_depthwise: bool = False,
                 with_scale: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN',
                                             momentum=0.03,
                                             eps=0.001),
                 init_cfg: OptMultiConfig = None,
                 use_einsum: bool = True) -> None:
        super().__init__(init_cfg=init_cfg)
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule

        assert (out_channels % num_heads == 0 and
                embed_channels % num_heads == 0), \
            'out_channels and embed_channels should be divisible by num_heads.'
        self.num_heads = num_heads
        self.head_channels = out_channels // num_heads
        self.use_einsum = use_einsum

        # Optional 1x1 conv that maps the input to the embedding width.
        self.embed_conv = ConvModule(
            in_channels,
            embed_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None) if embed_channels != in_channels else None
        self.bias = nn.Parameter(torch.zeros(num_heads))
        self.split_channels = embed_channels // num_heads
        # One bias-free 1x1 conv per head takes the place of the text embeddings.
        self.guide_convs = nn.ModuleList(
            nn.Conv2d(self.split_channels, guide_channels, 1, bias=False)
            for _ in range(num_heads))
        self.project_conv = conv(in_channels,
                                 out_channels,
                                 kernel_size,
                                 stride=1,
                                 padding=padding,
                                 conv_cfg=conv_cfg,
                                 norm_cfg=norm_cfg,
                                 act_cfg=None)

    def forward(self, x: Tensor, txt_feats: Tensor = None) -> Tensor:
        """Forward process."""
        B, C, H, W = x.shape

        embed = self.embed_conv(x) if self.embed_conv is not None else x
        embed = list(embed.split(self.split_channels, 1))
        # Bx(MxN)xHxW (H*c=C, H: heads)
        attn_weight = torch.cat(
            [conv(x) for conv, x in zip(self.guide_convs, embed)], dim=1)
        # BxMxNxHxW
        attn_weight = attn_weight.view(B, self.num_heads, -1, H, W)
        # attn_weight = torch.stack(
        #     [conv(x) for conv, x in zip(self.guide_convs, embed)])
        # BxMxNxHxW -> BxMxHxW
        attn_weight = attn_weight.max(dim=2)[0] / (self.head_channels**0.5)
        attn_weight = (attn_weight + self.bias.view(1, -1, 1, 1)).sigmoid()
        # .transpose(0, 1)
        # BxMx1xHxW
        attn_weight = attn_weight[:, :, None]
        x = self.project_conv(x)
        # BxHxCxHxW
        x = x.view(B, self.num_heads, -1, H, W)
        x = x * attn_weight
        x = x.view(B, -1, H, W)
        return x

ImagePoolingAttentionModule

class ImagePoolingAttentionModule(nn.Module):

    def __init__(self,
                 image_channels: List[int],
                 text_channels: int,
                 embed_channels: int,
                 with_scale: bool = False,
                 num_feats: int = 3,
                 num_heads: int = 8,
                 pool_size: int = 3,
                 use_einsum: bool = True):
        super().__init__()

        self.text_channels = text_channels
        self.embed_channels = embed_channels
        self.num_heads = num_heads
        self.num_feats = num_feats
        self.head_channels = embed_channels // num_heads
        self.pool_size = pool_size
        self.use_einsum = use_einsum
        if with_scale:
            self.scale = nn.Parameter(torch.tensor([0.]), requires_grad=True)
        else:
            self.scale = 1.0
        self.projections = nn.ModuleList([
            ConvModule(in_channels, embed_channels, 1, act_cfg=None)
            for in_channels in image_channels
        ])
        self.query = nn.Sequential(nn.LayerNorm(text_channels),
                                   Linear(text_channels, embed_channels))
        self.key = nn.Sequential(nn.LayerNorm(embed_channels),
                                 Linear(embed_channels, embed_channels))
        self.value = nn.Sequential(nn.LayerNorm(embed_channels),
                                   Linear(embed_channels, embed_channels))
        self.proj = Linear(embed_channels, text_channels)

        self.image_pools = nn.ModuleList([
            nn.AdaptiveMaxPool2d((pool_size, pool_size))
            for _ in range(num_feats)
        ])

    def forward(self, text_features, image_features):
        B = image_features[0].shape[0]
        assert len(image_features) == self.num_feats
        num_patches = self.pool_size**2
        mlvl_image_features = [
            pool(proj(x)).view(B, -1, num_patches)
            for (x, proj, pool
                 ) in zip(image_features, self.projections, self.image_pools)
        ]
        mlvl_image_features = torch.cat(mlvl_image_features,
                                        dim=-1).transpose(1, 2)
        q = self.query(text_features)
        k = self.key(mlvl_image_features)
        v = self.value(mlvl_image_features)

        q = q.reshape(B, -1, self.num_heads, self.head_channels)
        k = k.reshape(B, -1, self.num_heads, self.head_channels)
        v = v.reshape(B, -1, self.num_heads, self.head_channels)
        if self.use_einsum:
            attn_weight = torch.einsum('bnmc,bkmc->bmnk', q, k)
        else:
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 3, 1)
            attn_weight = torch.matmul(q, k)

        attn_weight = attn_weight / (self.head_channels**0.5)
        attn_weight = F.softmax(attn_weight, dim=-1)
        if self.use_einsum:
            x = torch.einsum('bmnk,bkmc->bnmc', attn_weight, v)
        else:
            v = v.permute(0, 2, 1, 3)
            x = torch.matmul(attn_weight, v)
            x = x.permute(0, 2, 1, 3)
        x = self.proj(x.reshape(B, -1, self.embed_channels))
        return x * self.scale + text_features
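
To make the arithmetic of ImagePoolingAttentionModule easier to follow, here is a stripped-down sketch in plain Python. It is an illustration under simplifying assumptions rather than a port of the module: it uses a single head, skips the LayerNorm and the learned query/key/value/output projections, and works on hypothetical toy vectors. With the listing's defaults (`num_feats=3`, `pool_size=3`), the text embeddings would attend to 3 x 9 = 27 pooled patch tokens; two patches are enough here.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def image_pooling_attention(text_embeddings, pooled_patches):
    """Refine each text embedding with an attention-weighted mix of patches.

    Single head and identity projections: both are simplifications.
    """
    dim = len(text_embeddings[0])
    updated = []
    for query in text_embeddings:
        # Scaled dot-product score of this text embedding against each patch.
        logits = [sum(q * k for q, k in zip(query, patch)) / math.sqrt(dim)
                  for patch in pooled_patches]
        weights = softmax(logits)
        context = [sum(w * patch[c] for w, patch in zip(weights, pooled_patches))
                   for c in range(dim)]
        # Residual connection: the original text embedding is kept and refined.
        updated.append([q + c for q, c in zip(query, context)])
    return updated


# Two hypothetical pooled patch tokens and one text embedding.
patches = [[1.0, 0.0], [0.0, 1.0]]
texts = [[0.0, 0.0]]
refined = image_pooling_attention(texts, patches)
```

The residual addition at the end mirrors the `x * self.scale + text_features` line of the listing with `scale` fixed to 1.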

Reference: https://github.com/AILab-CVC/YOLO-World/blob/master/yolo_world/models/layers/yolo_bricks.py
