当前位置: 首页 > news >正文

论文阅读NAM:Normalization-based Attention Module

Abstarct

识别不太显著的特征是模型压缩的关键。然而,在革命性的注意力机制中却没有对其进行研究。在这项工作中,我们提出了一种新的基于归一化的注意力模块(NAM),它抑制了不太显著的权重。它对注意力模块应用了权重稀疏性惩罚,从而使它们在保持类似性能的同时具有更高的计算效率。与Resnet和Mobilenet上的其他三种注意力机制的比较表明,我们的方法具有更高的准确性。

Introduction

注意机制是近年来研究的热点之一 (Wang et al.[2017], Hu et al. [2018], Park et al. [2018], Woo et al. [2018], Gao et al. [2019]).)。它有助于深度神经网络抑制不太显著的像素或通道。先前的许多研究都集中在通过注意力操作捕捉显著特征上(Zhang et al. [2020], Misra et al. [2021])。这些方法成功地利用了来自不同维度特征的相互信息。然而,它们缺乏对权重的贡献因素的考虑,这能够进一步抑制不重要的通道或像素。受Liu et al. [2017]的启发,我们旨在利用权重的贡献因素来改善注意力机制。我们使用批量归一化的比例因子,该比例因子使用标准偏差来表示权重的重要性。这可以避免添加SE、BAM和CBAM中使用的完全连接层和卷积层。因此,我们提出了一种有效的注意力机制——基于归一化的注意力模块(NAM)。

Related work

许多先前的工作试图通过抑制不重要的权重来提高神经网络的性能。挤压和激励网络(SENet)(Hu et al[2018])将空间信息集成到通道特征响应中,并使用两个多层感知器(MLP)层计算相应的注意力。后来,瓶颈注意力模块(BAM)(Park et al. [2018]) b并行构建了分离的空间和通道子模块,它们可以嵌入到每个瓶颈块中。卷积块注意力模块(CBAM) (Woo et al. [2018]) 提供了一种按顺序嵌入通道和空间注意力子模块的解决方案,为了避免忽视跨维度交互,三重注意力模块(TAM)) (Misra et al. [2021]) 通过旋转特征图来考虑维度相关性。然而,这些工作忽略了来自训练的调谐权重的信息。因此,我们的目标是通过利用训练的模型权重的方差测量来突出显著特征。

Methodology

我们提出了NAM作为一种高效和轻量级的注意机制。我们采用了CBAM的模块集成(Woo et al[2018]),并重新设计了通道和空间注意力子模块。然后,在每个网络块的末端嵌入一个NAM模块。对于残差网络,它嵌入在残差结构的末端。对于通道注意力子模块,我们使用批量归一化(BN)的比例因子(Ioffe and Szegedy [2015]),如公式(1)所示。比例因子测量信道的方差并指示它们的重要性。

B_{out}=BN(BN_{in} )=\gamma \frac{B_{in}-\mu\mathcal{_{B}}}{\sigma _{\mathcal{_{B}}}^{2}+\epsilon}                                   (1)

其中\mu\mathcal{_{B}}\sigma\mathcal{_{B}}分别为小批量\mathcal{B}的平均值和标准偏差;γ和β是可训练的仿射变换参数(尺度和偏移)(Ioffe and Szegedy [2015])。通道注意力子模块如图1和方程(2)所示,其中M_c表示输出特征。γ是每个通道的比例因子,权重为W_{\gamma } =\gamma _{i} / {\textstyle \sum_{j=0}^{}\gamma _{j} }。我们还将BN的比例因子应用于空间维度,以测量像素的重要性。我们将其命名为像素归一化。相应的空间注意力子模块如图2和方程(3)所示,其中输出表示为M_s\lambda是比例因子,权重为W_{\lambda } =\lambda _{i} / {\textstyle \sum_{j=0}^{}\lambda _{j} }。为了抑制不太显著的权重,我们将正则化项添加到损失函数中,如方程(4)所示(Liu et al[2017]),其x表示输入,γ是输出;表示网络权重;l(\cdot )是损失函数;g(\cdot )l_1范数罚函数;p是平衡g(\gamma)g(\lambda)的惩罚。

M_c=sigmoid(W_\gamma(BN(F_1)) )                   (2)

M_s=sigmoid(W_\lambda(BN(F_2)) )                   (3)

Loss=\sum_{(x,y)}^{} l(f(x,W),g)+p\sum g(\gamma )+p\sum g(\lambda )                    (4)

Experiment

在本节中,我们比较了NAM与SE、BAM、CBAM和TAM在ResNet和MobileNet中的性能。我们在一个集群上使用四个Nvidia Tesla V100 GPU来评估每种方法。我们首先在CIFAR-100上运行ResNet50(Krizhevsky等人[2009]),并使用与CBAM相同的预处理和训练配置(Woo等人[2018]),p为0.0001。表1中的比较表明,单独使用通道或空间注意力的NAM优于其他四种注意力机制。然后,我们在ImageNet上运行MobileNet(Deng等人[2009]),因为它是图像分类基准的标准数据集之一。我们将p设置为0.001,其余配置与CBAM相同。表2中的比较表明,信道和空间注意力相结合的NAM优于其他三种计算复杂度相似的NAM。

Conclusion

我们提出了一个NAM模块,该模块通过抑制不太显著的特征来提高效率。我们的实验表明,NAM在ResNet和MobileNet上都提供了效率增益。我们正在对NAM在积分变化和超参数调整方面的性能进行详细分析。我们还计划利用不同的模型压缩技术对 NAM 进行优化,以提高其效率。未来,我们将研究它对其他深度学习架构和应用的影响。

Code

import torch.nn as nn
import torch
from torch.nn import functional as Fclass Channel_Att(nn.Module):def __init__(self, channels, t=16):super(Channel_Att, self).__init__()self.channels = channelsself.bn2 = nn.BatchNorm2d(self.channels, affine=True)def forward(self, x):residual = xx = self.bn2(x)weight_bn = self.bn2.weight.data.abs() / torch.sum(self.bn2.weight.data.abs())x = x.permute(0, 2, 3, 1).contiguous()x = torch.mul(weight_bn, x)x = x.permute(0, 3, 1, 2).contiguous()x = torch.sigmoid(x) * residual #return xclass Att(nn.Module):def __init__(self, channels,shape, out_channels=None, no_spatial=True):super(Att, self).__init__()self.Channel_Att = Channel_Att(channels)def forward(self, x):x_out1=self.Channel_Att(x)return x_out1  

相关文章:

  • 错误:comparison method violates its general contract
  • 智慧应急:构建全方位、立体化的安全保障网络
  • vue使用gitshot生成gif
  • 【Langchain多Agent实践】一个有推销功能的旅游聊天机器人
  • 如何在Window系统部署BUG管理软件并结合内网穿透实现远程管理本地BUG
  • SpringMVC 学习(二)之第一个 SpringMVC 案例
  • 解释什么是内连接、左连接和右连接,并给出每种连接的SQL示例
  • day03_登录注销(前端接入登录,异常处理, 图片验证码,获取用户信息接口,退出功能)
  • 【pytorch矩阵应用】
  • 哈工大中文mistral介绍(Chinese-Mixtral-8x7B)
  • Redis实现滑动窗口限流
  • 微服务之qiankun主项目+子项目搭建
  • C++:封装
  • Pyglet综合应用|推箱子游戏之关卡图片载入内存
  • JMETER与它的组件们
  • [NodeJS] 关于Buffer
  • 2018一半小结一波
  • Android开发 - 掌握ConstraintLayout(四)创建基本约束
  • Android框架之Volley
  • codis proxy处理流程
  • css系列之关于字体的事
  • CSS中外联样式表代表的含义
  • el-input获取焦点 input输入框为空时高亮 el-input值非法时
  • gops —— Go 程序诊断分析工具
  • JavaScript HTML DOM
  • JavaScript服务器推送技术之 WebSocket
  • nodejs实现webservice问题总结
  • SpriteKit 技巧之添加背景图片
  • 从setTimeout-setInterval看JS线程
  • 如何选择开源的机器学习框架?
  • 如何用Ubuntu和Xen来设置Kubernetes?
  • 如何用vue打造一个移动端音乐播放器
  • 原生JS动态加载JS、CSS文件及代码脚本
  • 终端用户监控:真实用户监控还是模拟监控?
  • JavaScript 新语法详解:Class 的私有属性与私有方法 ...
  • ​LeetCode解法汇总1410. HTML 实体解析器
  • ​TypeScript都不会用,也敢说会前端?
  • (多级缓存)缓存同步
  • (附源码)ssm户外用品商城 毕业设计 112346
  • (附源码)计算机毕业设计SSM教师教学质量评价系统
  • (机器学习-深度学习快速入门)第一章第一节:Python环境和数据分析
  • (考研湖科大教书匠计算机网络)第一章概述-第五节1:计算机网络体系结构之分层思想和举例
  • (亲测有效)解决windows11无法使用1500000波特率的问题
  • (十六)一篇文章学会Java的常用API
  • (算法)Travel Information Center
  • (提供数据集下载)基于大语言模型LangChain与ChatGLM3-6B本地知识库调优:数据集优化、参数调整、Prompt提示词优化实战
  • (转)AS3正则:元子符,元序列,标志,数量表达符
  • (转)mysql使用Navicat 导出和导入数据库
  • (转)清华学霸演讲稿:永远不要说你已经尽力了
  • ***微信公众号支付+微信H5支付+微信扫码支付+小程序支付+APP微信支付解决方案总结...
  • **python多态
  • .Family_物联网
  • .net 8 发布了,试下微软最近强推的MAUI
  • .NET I/O 学习笔记:对文件和目录进行解压缩操作
  • .net php 通信,flash与asp/php/asp.net通信的方法