当前位置: 首页 > news >正文

【论文笔记】独属于CV的注意力机制CBAM-Convolutional Block Attention Module

目录

写在前面

一、基数和宽度

二、通道注意力模块(Channel Attention Module)

三、空间注意力模块(Spatial Attention Module)

四、CBAM(Convolutional Block Attention Module)

五、总结


写在前面

        CBAM论文地址:https://arxiv.org/abs/1807.06521

        CBAM(Convolutional Block Attention Module)是2018年被提出的,不同于ViT的Attention,CBAM是为CNN量身定做的Attention模块,实现简单、效果好,你值得拥有。

        为了提高CNN的性能,我们可以从深度(depth)、宽度(width)和基数(cardinality)三个方面入手。深度很好理解,就是模型的层数,层数越多模型越深。下面说一说基数和宽度的区别。

一、基数和宽度

基数(cardinality):指的是并行分支的数量。

宽度(width):指每一层卷积的卷积核数量(即输出特征图的通道数)。

举个GoogLeNet的例子:

        GoogLeNet 的设计中,Inception模块通过组合多个不同大小的卷积核(例如 1x1、3x3、5x5)和池化操作来提取不同尺度的特征。

        增加“宽度”的效果: 我们有一个 Inception 模块,包含 1x1、3x3 和 5x5 的卷积层,以及一个 3x3 的最大池化层。如果我们在该 Inception 模块中增加每个卷积操作的通道数,例如将 1x1 卷积层的输出通道数从 32 增加到 64,将 3x3 卷积层的输出通道数从 64 增加到 128,这种操作就增加了网络的“宽度”。增加“宽度”意味着每个 Inception 模块可以提取更多的特征信息,但同时也增加了计算成本。

        增加基数: 如果我们将每个卷积和池化操作进一步拆分为多个组,例如在每个卷积操作中使用组卷积(group convolution),那么这些并行组卷积操作的数量就类似于增加了“基数”。每个组卷积操作都是一个独立的路径,这些路径的数量增加就代表了基数的增加。基数不仅节省了参数的总数,而且比深度和宽度这两个因素具有更强的表示能力。

        可以看下图,蓝色的线表示模型的基数,红色的数字表示宽度。

        CBAM由两个顺序的子模块组成:通道注意力模块(Channel Attention Module)和空间注意力模块(Spatial Attention Module)。CAM解决的问题是重要的信息是什么(‘what’ is meaningful given an input image)、SAM解决重要的信息在哪里(‘where’ is an informative part)。这两个模块都使用了增加基数的方式,提升模型的表达能力。

二、通道注意力模块(Channel Attention Module

        我们知道CNN的每个通道可以提取不同的特征(也就是Feature Map),通道注意力模块(Channel Attention Module)的主要作用是自适应地调整和增强输入Feature Map中每个通道的重要性。它通过学习每个通道对于当前任务的重要性权重,从而对不同通道进行加权,增强关键信息的表达,同时抑制不相关或冗余的特征。这种机制能够使神经网络更高效地利用信息,提高模型的性能和表达能力。

        通俗的说,CAM就是判断每个Feature Map的重要程度。即下图中相同颜色的部分会有一个标记重要程度的权重。

        CAM使用并行的平均池化和最大池化,每个池化分别经过两个卷积模块,然后相加经过sigmoid得到注意力概率图,公式如下:

        其中σ为Sigmoid型函数,MLP这里使用的是两个卷积层,权重W_0W_1是共享的。

        公式不直观,示意图如下,假设输入是一个32通道244x244的Feature Map,输出是32x1x1,表示32个Feature Map的重要程度。

        使用pooling可以捕获每个通道特征的强度。平均池化(Average Pooling) 能够捕获每个通道的整体激活分布信息。它反映了一个通道中所有特征点的平均响应,能够平滑地代表特征的整体强度。最大池化(Max Pooling) 捕获的是特征图中的最强激活信号。它能够突出特征图中最显著的特征,强调特征中的极端值。

        平均池化提供了特征分布的全局性信息,而最大池化提供了最显著特征的信息。通过结合这两种池化方法,通道注意力模块能够更好地理解哪些特征通道在给定输入图像中最重要。

        两种池化之后的卷积层共享参数,而且两个卷积层的中间维度使用in_planes//16,最大限度的减少参数,因为注意力模块只提供一个注意力概率图,提取特征并不是它的首要任务,所以不需要太多参数。

        代码示例:

class ChannelAttention(nn.Module):def __init__(self, in_planes):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),nn.ReLU(),nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.fc(self.avg_pool(x))max_out = self.fc(self.max_pool(x))out = avg_out + max_outreturn self.sigmoid(out)

        

三、空间注意力模块(Spatial Attention Module

        SAM首先沿着通道轴分别使用平均池化和最大池化操作,并将它们连接起来,然后经过一个7x7的卷积层将通道数变成1,最后是Sigmoid得到结果,是一个宽高和输入一致通道数为1的概率图。

        SAM就是判断Feature Map中每个部分(像素)的重要程度,需要综合每个部分所有通道的特征。即下图中相同颜色的部分会有一个标记重要程度的权重。

公式如下:

       其中,σ表示Sigmoid函数,f^{7*7}表示大小为7×7的卷积运算。

        补充一下,这里的卷积核大小是7x7,而上面CAM的是1x1,这是因为CAM关注的是整个特征图中的全局通道信息1x1 卷积核适合这种只需在通道维度上操作的情况。SAM关注的是特征图中的局部和全局空间信息7x7 卷积核适合捕捉空间维度上的局部和全局特征。

        示意图如下,仍然假设输入是一个32通道244x244的Feature Map,输出是1x244x244。

        同时使用平均池化和最大池化的原因和CAM类似,平均池化提供了特征分布的全局性信息,而最大池化提供了最显著特征的信息。通过结合这两种池化方法,SAM能够更好地理解Feature Map中哪些区域最重要。

四、CBAM(Convolutional Block Attention Module)

        有了CAM和SAM,就剩最后一个问题,这两个模块怎么摆放,是并行放置还是顺行方式呢?作者发现顺序排列比平行排列的效果更好。下面是CBAM完整的结构图,我们随意在CBAM之前放几个卷积层:

CBAM完整的代码:

import torch
import torch.nn as nnclass ChannelAttention(nn.Module):def __init__(self, in_planes):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),nn.ReLU(),nn.Conv2d(in_planes // 16, in_planes, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.fc(self.avg_pool(x))max_out = self.fc(self.max_pool(x))out = avg_out + max_outreturn self.sigmoid(out)class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x = torch.cat([avg_out, max_out], dim=1)x = self.conv1(x)return self.sigmoid(x)class DemoNet(nn.Module):expansion = 1def __init__(self, inplanes, planes, stride=1, downsample=None):super(DemoNet, self).__init__()self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)self.ca = ChannelAttention(planes)self.sa = SpatialAttention()def forward(self, x):out = self.conv1(x)out = self.relu(out)out = self.conv2(out)out = self.ca(out) * outout = self.sa(out) * outout = self.relu(out)return outif __name__ == '__main__':input = torch.randn(1, 3, 224, 224)demo_net = DemoNet(inplanes=3, planes=32)output = demo_net(input)print(output)

        这里还有一个CBAM应用到ResNet的完整例子:https://github.com/luuuyi/CBAM.PyTorch

五、总结

1.CBAM是一个轻量级和通用的模块,它可以无缝地集成到任何CNN架构中,而开销可以忽略不计,并且可以与基础CNN一起进行端到端训练;

2.通道注意力模块(Channel Attention Module)关注每个通道的Feature Map的重要程度;

3.空间注意力模块(Spatial Attention Module)关注Feature Map每个部分(像素)的重要程度;

4.CBAM由通道注意力模块和空间注意力模块两个模块组成,同时兼顾了通道与空间特征的表达,相比传统的卷积层参数更少、效果更好。

        CBAM就介绍到这里,关注不迷路(*^▽^*)

关注订阅号了解更多精品文章

交流探讨、商务合作请加微信

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • SpringBoot调用通义千问
  • Scratch编程新纪元:网络请求与数据解析的探索之旅
  • 13-springcloud gateway集成nacos实现负载均衡
  • 【0-1背包】力扣416. 分割等和子集
  • 大模型本地化部署2-Docker部署MaxKB
  • Unity(2022.3.41LTS) - 网格,纹理,材质
  • Clickhouse集群化(三)集群化部署
  • 云计算day32
  • Windows系统安装MySQL
  • 2024 Ollama 一站式解决在Windows系统安装、使用、定制服务与实战案例
  • 线性代数:如何由AB=E 推出 BA=AB?
  • 【有来开源组织】开发规范手册
  • 【开端】 进行页面升级或维护时不影响用户体验NGINX配置
  • 影像设备国产替代究竟有多重要?这家企业提前布局8K时代
  • object.defineProperty用法
  • 《Javascript数据结构和算法》笔记-「字典和散列表」
  • 0基础学习移动端适配
  • const let
  • Cumulo 的 ClojureScript 模块已经成型
  • es的写入过程
  • GitUp, 你不可错过的秀外慧中的git工具
  • golang 发送GET和POST示例
  • JavaScript设计模式与开发实践系列之策略模式
  • magento 货币换算
  • MySQL Access denied for user 'root'@'localhost' 解决方法
  • mysql innodb 索引使用指南
  • Redis提升并发能力 | 从0开始构建SpringCloud微服务(2)
  • windows下mongoDB的环境配置
  • 基于遗传算法的优化问题求解
  • 前端js -- this指向总结。
  • 让你成为前端,后端或全栈开发程序员的进阶指南,一门学到老的技术
  • 如何使用 OAuth 2.0 将 LinkedIn 集成入 iOS 应用
  • 数据结构java版之冒泡排序及优化
  • 数据科学 第 3 章 11 字符串处理
  • 体验javascript之美-第五课 匿名函数自执行和闭包是一回事儿吗?
  • 为视图添加丝滑的水波纹
  • 小程序滚动组件,左边导航栏与右边内容联动效果实现
  • 原生js练习题---第五课
  • 自动记录MySQL慢查询快照脚本
  • ​LeetCode解法汇总1276. 不浪费原料的汉堡制作方案
  • ​VRRP 虚拟路由冗余协议(华为)
  • ​zookeeper集群配置与启动
  • # 数仓建模:如何构建主题宽表模型?
  • #我与Java虚拟机的故事#连载04:一本让自己没面子的书
  • $.type 怎么精确判断对象类型的 --(源码学习2)
  • (003)SlickEdit Unity的补全
  • (13):Silverlight 2 数据与通信之WebRequest
  • (C++20) consteval立即函数
  • (CPU/GPU)粒子继承贴图颜色发射
  • (HAL)STM32F103C6T8——软件模拟I2C驱动0.96寸OLED屏幕
  • (不用互三)AI绘画工具应该如何选择
  • (附源码)php新闻发布平台 毕业设计 141646
  • (接口封装)
  • (贪心 + 双指针) LeetCode 455. 分发饼干
  • (一)Kafka 安全之使用 SASL 进行身份验证 —— JAAS 配置、SASL 配置