当前位置: 首页 > news >正文

YOLOv7添加注意力机制和各种改进模块

YOLOv7添加注意力机制和各种改进模块代码免费下载:完整代码

添加的部分模块代码:

########CBAM
class ChannelAttentionModule(nn.Module):def __init__(self, c1, reduction=16):super(ChannelAttentionModule, self).__init__()mid_channel = c1 // reductionself.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.shared_MLP = nn.Sequential(nn.Linear(in_features=c1, out_features=mid_channel),nn.LeakyReLU(0.1, inplace=True),nn.Linear(in_features=mid_channel, out_features=c1))self.act = nn.Sigmoid()# self.act=nn.SiLU()def forward(self, x):avgout = self.shared_MLP(self.avg_pool(x).view(x.size(0), -1)).unsqueeze(2).unsqueeze(3)maxout = self.shared_MLP(self.max_pool(x).view(x.size(0), -1)).unsqueeze(2).unsqueeze(3)return self.act(avgout + maxout)class SpatialAttentionModule(nn.Module):def __init__(self):super(SpatialAttentionModule, self).__init__()self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)self.act = nn.Sigmoid()def forward(self, x):avgout = torch.mean(x, dim=1, keepdim=True)maxout, _ = torch.max(x, dim=1, keepdim=True)out = torch.cat([avgout, maxout], dim=1)out = self.act(self.conv2d(out))return outclass CBAM(nn.Module):def __init__(self, c1, c2):super(CBAM, self).__init__()self.channel_attention = ChannelAttentionModule(c1)self.spatial_attention = SpatialAttentionModule()def forward(self, x):out = self.channel_attention(x) * xout = self.spatial_attention(out) * outreturn out
##############CBAM
########SE
class SEAttention(nn.Module):def __init__(self, channel=512,reduction=16):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)
########SE
#######GAM
class GAMAttention(nn.Module):# https://paperswithcode.com/paper/global-attention-mechanism-retain-informationdef __init__(self, c1, c2, group=True, rate=4):super(GAMAttention, self).__init__()self.channel_attention = nn.Sequential(nn.Linear(c1, int(c1 / rate)),nn.ReLU(inplace=True),nn.Linear(int(c1 / rate), c1))self.spatial_attention = nn.Sequential(nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(c1, int(c1 / rate),kernel_size=7,padding=3),nn.BatchNorm2d(int(c1 / rate)),nn.ReLU(inplace=True),nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(int(c1 / rate), c2,kernel_size=7,padding=3),nn.BatchNorm2d(c2))def forward(self, x):b, c, h, w = x.shapex_permute = x.permute(0, 2, 3, 1).view(b, -1, c)x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)x_channel_att = x_att_permute.permute(0, 3, 1, 2)x = x * x_channel_attx_spatial_att = self.spatial_attention(x).sigmoid()x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffleout = x * x_spatial_attreturn outdef channel_shuffle(x, groups=2):  ##shuffle channel# RESHAPE----->transpose------->FlattenB, C, H, W = x.size()out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()out = out.view(B, C, H, W)return out
#######GAM
#####NAMAttention  该注意力机制只有通道注意力机制的代码,空间的没有
import torch.nn as nn
import torch
from torch.nn import functional as Fclass Channel_Att(nn.Module):def __init__(self, channels, t=16):super(Channel_Att, self).__init__()self.channels = channelsself.bn2 = nn.BatchNorm2d(self.channels, affine=True)def forward(self, x):residual = xx = self.bn2(x)weight_bn = self.bn2.weight.data.abs() / torch.sum(self.bn2.weight.data.abs())x = x.permute(0, 2, 3, 1).contiguous()x = torch.mul(weight_bn, x)x = x.permute(0, 3, 1, 2).contiguous()x = torch.sigmoid(x) * residual  #return xclass NAMAttention(nn.Module):def __init__(self, channels, out_channels=None, no_spatial=True):super(NAMAttention, self).__init__()self.Channel_Att = nn.Sequential(*(Channel_Att(channels)for _ in range(1)))def forward(self, x):# print(x.device)## device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')x_out1 = self.Channel_Att(x)return x_out1
#####NAMAttentionclass RepGhostBottleneck1(nn.Module):# RepGhostNeXt Bottleneckdef __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_outsuper().__init__()self.c_ = int(c2 * e)  # hidden channels# attention mechanism can be usedself.m = nn.Sequential(*(RepGhostBottleneck(c1, c2, 2*self.c_) for _ in range(n)))def forward(self, x):return self.m(x)

相关文章:

  • 微信聊天内容怎么监控? | 三款可以监控电脑微信聊天记录的软件大盘点
  • 每日两题 / 131. 分割回文串 42. 接雨水(LeetCode热题100)
  • HCIP的学习(24)
  • 数字化学校渠道的建造内容
  • EM算法最通俗理解
  • 【2024最新华为OD-C卷试题汇总】披萨大作战 (100分) - 支持在线评测+三语言AC题解(Python/Java/Cpp)
  • mysql仿照find_in_set写了一个replace_in_set函数,英文逗号拼接字符串指定替换
  • 《广告数据定量分析》读书笔记之理论/概论
  • [面经] 西山居非正式面试(C++)
  • 深入解析力扣166题:分数到小数(模拟长除法与字符串操作详解及模拟面试问答)
  • Nginx网页服务
  • Compose 中的 touch 事件
  • 【全开源】防伪溯源一体化管理系统源码(FastAdmin+ThinkPHP和Uniapp)
  • 【5】:三维到二维变换(模型、视图、投影)
  • 基于异构图的大规模微服务系统性能问题诊断
  • 时间复杂度分析经典问题——最大子序列和
  • Git学习与使用心得(1)—— 初始化
  • IDEA常用插件整理
  • LeetCode18.四数之和 JavaScript
  • Nodejs和JavaWeb协助开发
  • thinkphp5.1 easywechat4 微信第三方开放平台
  • webpack4 一点通
  • 闭包--闭包作用之保存(一)
  • 计算机在识别图像时“看到”了什么?
  • 技术胖1-4季视频复习— (看视频笔记)
  • 开源中国专访:Chameleon原理首发,其它跨多端统一框架都是假的?
  • 前端代码风格自动化系列(二)之Commitlint
  • 巧用 TypeScript (一)
  • 新手搭建网站的主要流程
  • 用quicker-worker.js轻松跑一个大数据遍历
  • LIGO、Virgo第三轮探测告捷,同时探测到一对黑洞合并产生的引力波事件 ...
  • shell使用lftp连接ftp和sftp,并可以指定私钥
  • ​ 全球云科技基础设施:亚马逊云科技的海外服务器网络如何演进
  • (3)选择元素——(17)练习(Exercises)
  • (arch)linux 转换文件编码格式
  • (二)windows配置JDK环境
  • (附源码)springboot猪场管理系统 毕业设计 160901
  • (附源码)计算机毕业设计高校学生选课系统
  • (论文阅读11/100)Fast R-CNN
  • (顺序)容器的好伴侣 --- 容器适配器
  • (一)SpringBoot3---尚硅谷总结
  • (一)Thymeleaf用法——Thymeleaf简介
  • (杂交版)植物大战僵尸
  • **PyTorch月学习计划 - 第一周;第6-7天: 自动梯度(Autograd)**
  • .CSS-hover 的解释
  • .htaccess配置常用技巧
  • .NET CF命令行调试器MDbg入门(三) 进程控制
  • .NET I/O 学习笔记:对文件和目录进行解压缩操作
  • .NET MAUI学习笔记——2.构建第一个程序_初级篇
  • .NET 中让 Task 支持带超时的异步等待
  • .NET多线程执行函数
  • .Net中ListT 泛型转成DataTable、DataSet
  • .pings勒索病毒的威胁:如何应对.pings勒索病毒的突袭?
  • [ C++ ] 继承
  • [ IOS ] iOS-控制器View的创建和生命周期