当前位置: 首页 > news >正文

SMA2:代码实现详解——Image Encoder篇(Hiera章)

SMA2:代码实现详解——Image Encoder篇(Hiera)

写在前面

大家在SMA2:代码实现详解——Image Encoder篇(FpnNeck)下的留言我已收到,感谢大家的支持,后面如果遇到比较难以讲清的部分可能会使用视频的形式。博主最近要准备秋招,更新可能会慢许多,希望大家能谅解。

言归正传,在SMA2:代码实现详解——Image Encoder篇(FpnNeck)中,我们已经知道了SMA2的整体架构,并且介绍了Image Encoder组件中的FpnNeck。这一篇博客我们就来详细介绍Image Encoder的基本骨架backbone——Hiera

Hiera介绍

Hiera是文章Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles中提出的一种分层视觉Transformer架构。它不仅可以处理图像,而且这个架构可以应用于视频。Hiera是一个纯粹的简单分层ViT模型,不存在任何卷积、移位或者十字窗口操作,仅有Transformer结构组件。它比之前跨多个模型大小、领域和任务的工作更快、更准确。
在这里插入图片描述

Hiera与MAE(Masked AutoEncoder)

MAE(Masked AutoEncoder, 掩码自编码器)

图像MAE由论文Masked Autoencoders Are Scalable Vision Learners提出,它表明,MAE是计算机视觉的可扩展自监督学习器。方法非常简单:屏蔽输入图像的随机Patch并重建丢失的像素。它基于两个核心设计。首先,作者开发了一种非对称编码器-解码器架构,其中的编码器仅对Patch的可见子集(没有掩码标记)进行操作,而轻量级解码器可根据潜在表示和掩码标记重建原始图像。作者发现屏蔽高比例的输入图像(例如 75%)会产生一项不简单且有意义的自我监督任务。将这两种设计结合起来能够高效且有效地训练大型模型:加速训练(3 倍或更多)并提高准确性。可扩展方法允许学习泛化良好的高容量模型:例如,在仅使用 ImageNet-1K 数据的方法中,普通 ViT-Huge 模型实现了最佳准确率 (87.8%)。下游任务中的传输性能优于监督预训练,并显示出有希望的扩展行为。

Hiera便使用了MAE的方式进行训练。

Hiera架构

在这里插入图片描述

选择使用像MAE(如图所示)这样的强代理任务(pretext task)来教导模型。 Hiera完全由标准ViT块组成。为了提高效率,在前两个阶段使用“掩模单元”内的局部注意力,其余阶段使用全局注意力(Global Attention)。在每个阶段转换中,Q和跳跃连接的特征通过线性层加倍,空间维度通过2×2最大池池化。

SMA2中Hiera(HieraDet)的实现

class Hiera(nn.Module):"""Reference: https://arxiv.org/abs/2306.00989"""def __init__(self, ...):...self.blocks = nn.ModuleList()for i in range(depth):dim_out = embed_dim...block = MultiScaleBlock(dim=embed_dim,dim_out=dim_out,num_heads=num_heads,drop_path=dpr[i],q_stride=self.q_stride if i in self.q_pool_blocks else None,window_size=window_size,)embed_dim = dim_outself.blocks.append(block)def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:h, w = hwwindow_embed = self.pos_embed_windowpos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])pos_embed = pos_embed.permute(0, 2, 3, 1)return pos_embeddef forward(self, x: torch.Tensor) -> List[torch.Tensor]:x = self.patch_embed(x)# x: (B, H, W, C)# Add pos embedx = x + self._get_pos_embed(x.shape[1:3])outputs = []for i, blk in enumerate(self.blocks):x = blk(x)if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):feats = x.permute(0, 3, 1, 2)outputs.append(feats)return outputs

首先,Hiera先将图片划分并映射为patch嵌入向量(上述代码62行),然后计算位置信息并相加(代码第66行)。值得注意的是,SMA2在实现Hiera中位置嵌入时,参照了Window Attention is Bugged: How not to Interpolate Position Embeddings一文,他们发现在使用窗口注意力的同时插值位置嵌入是错误的。Hiera和ViTDet两者确实都存在此错误。于是作者提出了一种简单的绝对窗口位置嵌入策略,它彻底解决了Hiera中的错误,并提高了ViTDet中模型的速度和性能。

代码的68-75行实际上就是Hiera主体ViT块的处理,值得关注的只有带有Q pooling的ViT块,这是在MultiScaleBlock中实现的。

class PatchEmbed(nn.Module):"""Image to Patch Embedding."""def __init__(self,kernel_size: Tuple[int, ...] = (7, 7),stride: Tuple[int, ...] = (4, 4),padding: Tuple[int, ...] = (3, 3),in_chans: int = 3,embed_dim: int = 768,):"""Args:kernel_size (Tuple): kernel size of the projection layer.stride (Tuple): stride of the projection layer.padding (Tuple): padding size of the projection layer.in_chans (int): Number of input image channels.embed_dim (int):  embed_dim (int): Patch embedding dimension."""super().__init__()self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)def forward(self, x: torch.Tensor) -> torch.Tensor:x = self.proj(x)# B C H W -> B H W Cx = x.permute(0, 2, 3, 1)return x

PatchEmbed模块将图片的形状(B,C,H,W)转化为更常见的适用于Transformer处理的形状(B, H, W, C),因为后面经过VIT块时会要求(B,L,C)的形式。实际上,这个模块的卷积映射继承了ViT的做法,直接利用了卷积的特性,通过指定Kernel_size与strides隐式划分了窗口,并且完成了线性变换得到patch enmbedding

值得注意的是位置嵌入的计算:

class Hiera(nn.Module):def __init__(...)super().__init__()...self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_sizeself.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))...def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:h, w = hwwindow_embed = self.pos_embed_windowpos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])pos_embed = pos_embed.permute(0, 2, 3, 1)return pos_embed

代码第18行是计算全局的可学习位置嵌入。第19行加号的右边window_embed.tile(...)是计算每个window内的局部位置编码,每个window的位置编码都是相同的。我们可以使用matplotlib做一个可视化的样例,可能更容易理解。示例如下(由于代码中是零初始化,不太好展示,这里我选择随机初始化来展示):

在这里插入图片描述

从左到右依次为全局编码、局部编码和最终位置编码。

接下来我们来看MultiScaleBlock的实现:

class MultiScaleBlock(nn.Module):def __init__(self,dim: int,dim_out: int,num_heads: int,mlp_ratio: float = 4.0,drop_path: float = 0.0,norm_layer: Union[nn.Module, str] = "LayerNorm",q_stride: Tuple[int, int] = None,act_layer: nn.Module = nn.GELU,window_size: int = 0,):super().__init__()if isinstance(norm_layer, str):norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)self.dim = dimself.dim_out = dim_outself.norm1 = norm_layer(dim)self.window_size = window_sizeself.pool, self.q_stride = None, q_strideif self.q_stride:self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)self.attn = MultiScaleAttention(dim,dim_out,num_heads=num_heads,q_pool=self.pool,)self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()self.norm2 = norm_layer(dim_out)self.mlp = MLP(dim_out,int(dim_out * mlp_ratio),dim_out,num_layers=2,activation=act_layer,)if dim != dim_out:self.proj = nn.Linear(dim, dim_out)def forward(self, x: torch.Tensor) -> torch.Tensor:shortcut = x  # B, H, W, Cx = self.norm1(x)# Skip connectionif self.dim != self.dim_out:shortcut = do_pool(self.proj(x), self.pool)# Window partitionwindow_size = self.window_sizeif window_size > 0:H, W = x.shape[1], x.shape[2]x, pad_hw = window_partition(x, window_size)# Window Attention + Q Pooling (if stage change)x = self.attn(x)if self.q_stride:# Shapes have changed due to Q poolingwindow_size = self.window_size // self.q_stride[0]H, W = shortcut.shape[1:3]pad_h = (window_size - H % window_size) % window_sizepad_w = (window_size - W % window_size) % window_sizepad_hw = (H + pad_h, W + pad_w)# Reverse window partitionif self.window_size > 0:x = window_unpartition(x, window_size, pad_hw, (H, W))x = shortcut + self.drop_path(x)# MLPx = x + self.drop_path(self.mlp(self.norm2(x)))return xdef window_partition(x, window_size):"""Partition into non-overlapping windows with padding if needed.Args:x (tensor): input tokens with [B, H, W, C].window_size (int): window size.Returns:windows: windows after partition with [B * num_windows, window_size, window_size, C].(Hp, Wp): padded height and width before partition"""B, H, W, C = x.shapepad_h = (window_size - H % window_size) % window_sizepad_w = (window_size - W % window_size) % window_sizeif pad_h > 0 or pad_w > 0:x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))Hp, Wp = H + pad_h, W + pad_wx = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)windows = (x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C))return windows, (Hp, Wp)def do_pool():... #(B, H, W, C) -> (B, H', W' C)

MultiScaleBlockMultiScaleAttentionMLP构成,有经验的小伙伴看到注意力机制和MLP,显然得出它是一个Transformer。第60-63行代码就是根据每个stage给定的window size划分patch。
在这里插入图片描述

而且针对于每个Stage的交界,都使用Q pooling,这在MultiScaleAttention中实现。

class MultiScaleAttention(nn.Module):def __init__(self,dim: int,dim_out: int,num_heads: int,q_pool: nn.Module = None,):super().__init__()self.dim = dimself.dim_out = dim_outself.num_heads = num_headsself.q_pool = q_poolself.qkv = nn.Linear(dim, dim_out * 3)self.proj = nn.Linear(dim_out, dim_out)def forward(self, x: torch.Tensor) -> torch.Tensor:B, H, W, _ = x.shape# qkv with shape (B, H * W, 3, nHead, C)qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)# q, k, v with shape (B, H * W, nheads, C)q, k, v = torch.unbind(qkv, 2)# Q pooling (for downsample at stage changes)if self.q_pool:q = do_pool(q.reshape(B, H, W, -1), self.q_pool)H, W = q.shape[1:3]  # downsampled shapeq = q.reshape(B, H * W, self.num_heads, -1)# Torch's SDPA expects [B, nheads, H*W, C] so we transposex = F.scaled_dot_product_attention(q.transpose(1, 2),k.transpose(1, 2),v.transpose(1, 2),)# Transpose backx = x.transpose(1, 2)x = x.reshape(B, H, W, -1)x = self.proj(x)return x

代码19-23以及31-41都是比较传统的自注意力机制的计算了。
而所谓的Q pooling在26-30行,只是对Q向量转换为宽高的形状(B, H*W,)->(B, H, W, …),然后进行池化。其实对于H和W,它们应该是我们之前指定的window size。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Proxyless Service Mesh:下一代微服务架构体系
  • 【HarmonyOS NEXT】实现网络图片保存到手机相册
  • 音视频直播应用场景探讨之RTMP推流还是GB28181接入?
  • javase复习day22泛型、set、数据结构
  • USBCANFD卡在新能源BMS上位机的应用
  • Android CustomDialog圆角背景不生效的问题
  • String字符串
  • uniapp(H5)设置反向代理,设置成功后页面报错
  • AI教你学Python 第4天:函数和模块
  • MySQL下载安装
  • 可信多视图分类(TCM ETCM)算法实现数字序列的分类---基因致病的诊断
  • JAVA学习-练习试用Java实现“子集 II”
  • 代码随想录训练营 Day58打卡 图论part08 拓扑排序 dijkstra朴素版 + 堆优化版
  • 机器学习和深度学习的常见概念总结(多原创图)
  • 设计模式(Design Patterns)
  • Linux快速复制或删除大量小文件
  • Mac转Windows的拯救指南
  • PHP 的 SAPI 是个什么东西
  • python 学习笔记 - Queue Pipes,进程间通讯
  • React系列之 Redux 架构模式
  • vue从入门到进阶:计算属性computed与侦听器watch(三)
  • 阿里云购买磁盘后挂载
  • 代理模式
  • 观察者模式实现非直接耦合
  • 理清楚Vue的结构
  • 如何优雅的使用vue+Dcloud(Hbuild)开发混合app
  • 实战:基于Spring Boot快速开发RESTful风格API接口
  • 算法之不定期更新(一)(2018-04-12)
  • 移动端 h5开发相关内容总结(三)
  • 06-01 点餐小程序前台界面搭建
  • 阿里云重庆大学大数据训练营落地分享
  • 你学不懂C语言,是因为不懂编写C程序的7个步骤 ...
  • ​​​​​​​sokit v1.3抓手机应用socket数据包: Socket是传输控制层协议,WebSocket是应用层协议。
  • # Redis 入门到精通(八)-- 服务器配置-redis.conf配置与高级数据类型
  • ## 临床数据 两两比较 加显著性boxplot加显著性
  • #pragma pack(1)
  • #考研#计算机文化知识1(局域网及网络互联)
  • #我与Java虚拟机的故事#连载02:“小蓝”陪伴的日日夜夜
  • (2)关于RabbitMq 的 Topic Exchange 主题交换机
  • (2024,Vision-LSTM,ViL,xLSTM,ViT,ViM,双向扫描)xLSTM 作为通用视觉骨干
  • (22)C#传智:复习,多态虚方法抽象类接口,静态类,String与StringBuilder,集合泛型List与Dictionary,文件类,结构与类的区别
  • (4)STL算法之比较
  • (Redis使用系列) SpirngBoot中关于Redis的值的各种方式的存储与取出 三
  • (附源码)小程序儿童艺术培训机构教育管理小程序 毕业设计 201740
  • (九十四)函数和二维数组
  • (一)u-boot-nand.bin的下载
  • (已更新)关于Visual Studio 2019安装时VS installer无法下载文件,进度条为0,显示网络有问题的解决办法
  • (幽默漫画)有个程序员老公,是怎样的体验?
  • (转)mysql使用Navicat 导出和导入数据库
  • (转载)深入super,看Python如何解决钻石继承难题
  • .net core webapi Startup 注入ConfigurePrimaryHttpMessageHandler
  • .NET Core中如何集成RabbitMQ
  • .net安装_还在用第三方安装.NET?Win10自带.NET3.5安装
  • .Net程序帮助文档制作
  • .NET精简框架的“无法找到资源程序集”异常释疑