
[Deep Learning] Attention Mechanisms (Part 3)

This post walks through the implementation of several attention mechanisms: EMHSA, SA, SGE, AFT, and Outlook Attention.

[Deep Learning] Attention Mechanisms (Part 1)

[Deep Learning] Attention Mechanisms (Part 2)

[Deep Learning] Attention Mechanisms (Part 4)

[Deep Learning] Attention Mechanisms (Part 5)

Contents

1. EMHSA (Efficient Multi-Head Self-Attention)

2. SA (Shuffle Attention)

3. SGE (Spatial Group-wise Enhance)

4. AFT (Attention Free Transformer)

5. Outlook Attention


1. EMHSA (Efficient Multi-Head Self-Attention)

Paper: ResT: An Efficient Transformer for Visual Recognition

(Figure: EMSA module structure diagram, omitted here.)

Code (source link):

import numpy as np
import torch
from torch import nn
from torch.nn import init


class EMSA(nn.Module):

    def __init__(self, d_model, d_k, d_v, h, dropout=.1, H=7, W=7, ratio=3, apply_transform=True):
        super(EMSA, self).__init__()
        self.H = H
        self.W = W
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout = nn.Dropout(dropout)

        # spatial reduction: a depth-wise strided conv shrinks the key/value map
        self.ratio = ratio
        if self.ratio > 1:
            self.sr = nn.Sequential()
            self.sr_conv = nn.Conv2d(d_model, d_model, kernel_size=ratio + 1, stride=ratio,
                                     padding=ratio // 2, groups=d_model)
            self.sr_ln = nn.LayerNorm(d_model)

        # optional 1x1 conv + softmax + instance norm applied across heads
        self.apply_transform = apply_transform and h > 1
        if self.apply_transform:
            self.transform = nn.Sequential()
            self.transform.add_module('conv', nn.Conv2d(h, h, kernel_size=1, stride=1))
            self.transform.add_module('softmax', nn.Softmax(-1))
            self.transform.add_module('in', nn.InstanceNorm2d(h))

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        b_s, nq, c = queries.shape
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)

        if self.ratio > 1:
            x = queries.permute(0, 2, 1).view(b_s, c, self.H, self.W)  # (b_s, c, H, W)
            x = self.sr_conv(x)                                        # (b_s, c, h', w')
            x = x.contiguous().view(b_s, c, -1).permute(0, 2, 1)       # (b_s, n', c)
            x = self.sr_ln(x)
            k = self.fc_k(x).view(b_s, -1, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, n')
            v = self.fc_v(x).view(b_s, -1, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, n', d_v)
        else:
            k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)    # (b_s, h, d_k, nk)
            v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        if self.apply_transform:
            att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, n')
            att = self.transform(att)                     # (b_s, h, nq, n')
        else:
            att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, n')
            att = torch.softmax(att, -1)                  # (b_s, h, nq, n')

        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)

        att = self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out
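
To make the interface concrete, here is a minimal smoke test of my own (not from the original post). It assumes the tokens come from a 7x7 feature map, since the default H and W reshape the queries back to that grid when ratio > 1; all sizes below are illustrative:

import torch

# illustrative sizes: batch 2, 7 x 7 = 49 tokens, embedding width 512, 8 heads
x = torch.randn(2, 49, 512)
emsa = EMSA(d_model=512, d_k=64, d_v=64, h=8, H=7, W=7, ratio=3)
out = emsa(x, x, x)   # self-attention: queries, keys, and values are the same tokens
print(out.shape)      # torch.Size([2, 49, 512])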

2. SA (Shuffle Attention)

Paper: SA-Net: Shuffle Attention for Deep Convolutional Neural Networks

(Figure: Shuffle Attention module structure diagram, omitted here.)

The code is as follows (source link):

import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter


class ShuffleAttention(nn.Module):

    def __init__(self, channel=512, reduction=16, G=8):
        super().__init__()
        self.G = G
        self.channel = channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)
        # flatten
        x = x.reshape(b, -1, h, w)
        return x

    def forward(self, x):
        b, c, h, w = x.size()

        # group into subfeatures
        x = x.view(b * self.G, -1, h, w)  # (bs*G, c//G, h, w)

        # channel split
        x_0, x_1 = x.chunk(2, dim=1)  # (bs*G, c//(2*G), h, w)

        # channel attention
        x_channel = self.avg_pool(x_0)                     # (bs*G, c//(2*G), 1, 1)
        x_channel = self.cweight * x_channel + self.cbias  # (bs*G, c//(2*G), 1, 1)
        x_channel = x_0 * self.sigmoid(x_channel)

        # spatial attention
        x_spatial = self.gn(x_1)                           # (bs*G, c//(2*G), h, w)
        x_spatial = self.sweight * x_spatial + self.sbias  # (bs*G, c//(2*G), h, w)
        x_spatial = x_1 * self.sigmoid(x_spatial)          # (bs*G, c//(2*G), h, w)

        # concatenate along channel axis
        out = torch.cat([x_channel, x_spatial], dim=1)  # (bs*G, c//G, h, w)
        out = out.contiguous().view(b, -1, h, w)

        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out
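
A minimal usage sketch of my own, assuming a 512-channel feature map; the only real constraint is that the channel count be divisible by 2*G so the group split and the chunk both work:

import torch

x = torch.randn(2, 512, 16, 16)            # (batch, channels, h, w); shapes are illustrative
sa = ShuffleAttention(channel=512, G=8)    # 512 is divisible by 2 * 8
print(sa(x).shape)                         # torch.Size([2, 512, 16, 16]) -- shape is preserved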

3. SGE (Spatial Group-wise Enhance)

Paper: Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks

(Figure: SGE module structure diagram, omitted here.)

The code is as follows (source link):

import torch
import torch.nn as nn
from torch.nn import Parameter  # the class below uses Parameter directly


class SpatialGroupEnhance(nn.Module):

    def __init__(self, groups=64):
        super(SpatialGroupEnhance, self).__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.weight = Parameter(torch.zeros(1, groups, 1, 1))
        self.bias = Parameter(torch.ones(1, groups, 1, 1))
        self.sig = nn.Sigmoid()

    def forward(self, x):  # (b, c, h, w)
        b, c, h, w = x.size()
        x = x.view(b * self.groups, -1, h, w)

        # similarity between each position and the group's global pooled descriptor
        xn = x * self.avg_pool(x)
        xn = xn.sum(dim=1, keepdim=True)

        # normalize the similarity map within each group
        t = xn.view(b * self.groups, -1)
        t = t - t.mean(dim=1, keepdim=True)
        std = t.std(dim=1, keepdim=True) + 1e-5
        t = t / std
        t = t.view(b, self.groups, h, w)

        # per-group learnable scale and shift, then sigmoid gating
        t = t * self.weight + self.bias
        t = t.view(b * self.groups, 1, h, w)
        x = x * self.sig(t)
        x = x.view(b, c, h, w)
        return x
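
A minimal usage sketch, assuming the channel count is divisible by groups; the shapes below are illustrative:

import torch

x = torch.randn(2, 512, 16, 16)        # (batch, channels, h, w)
sge = SpatialGroupEnhance(groups=8)    # 512 channels split into 8 groups of 64
print(sge(x).shape)                    # torch.Size([2, 512, 16, 16]) -- shape is preserved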

4. AFT (Attention Free Transformer)

Paper: An Attention Free Transformer

(Figure: AFT architecture diagram, omitted here.)

The code is as follows (source link):

import torch, math
from torch import nn, einsum
import torch.nn.functional as F


class AFTFull(nn.Module):

    def __init__(self, max_seqlen, dim, hidden_dim=64):
        super().__init__()
        '''
        max_seqlen: the maximum number of timesteps (sequence length) to be fed in
        dim: the embedding dimension of the tokens
        hidden_dim: the hidden dimension used inside AFT Full

        Number of heads is 1 as done in the paper
        '''
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.to_q = nn.Linear(dim, hidden_dim)
        self.to_k = nn.Linear(dim, hidden_dim)
        self.to_v = nn.Linear(dim, hidden_dim)
        self.project = nn.Linear(hidden_dim, dim)
        self.wbias = nn.Parameter(torch.Tensor(max_seqlen, max_seqlen))
        nn.init.xavier_uniform_(self.wbias)

    def forward(self, x):
        B, T, _ = x.shape
        Q = self.to_q(x).view(B, T, self.hidden_dim)
        K = self.to_k(x).view(B, T, self.hidden_dim)
        V = self.to_v(x).view(B, T, self.hidden_dim)
        temp_wbias = self.wbias[:T, :T].unsqueeze(0)  # sequences can still be variable length

        '''
        From the paper
        '''
        Q_sig = torch.sigmoid(Q)
        temp = torch.exp(temp_wbias) @ torch.mul(torch.exp(K), V)
        weighted = temp / (torch.exp(temp_wbias) @ torch.exp(K))
        Yt = torch.mul(Q_sig, weighted)

        Yt = Yt.view(B, T, self.hidden_dim)
        Yt = self.project(Yt)
        return Yt


class AFTSimple(nn.Module):

    def __init__(self, max_seqlen, dim, hidden_dim=64):
        super().__init__()
        '''
        max_seqlen: the maximum number of timesteps (sequence length) to be fed in
        dim: the embedding dimension of the tokens
        hidden_dim: the hidden dimension used inside AFT Full

        Number of Heads is 1 as done in the paper.
        '''
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.to_q = nn.Linear(dim, hidden_dim)
        self.to_k = nn.Linear(dim, hidden_dim)
        self.to_v = nn.Linear(dim, hidden_dim)
        self.project = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        B, T, _ = x.shape
        Q = self.to_q(x).view(B, T, self.hidden_dim)
        K = self.to_k(x).view(B, T, self.hidden_dim)
        V = self.to_v(x).view(B, T, self.hidden_dim)

        '''
        From the paper
        '''
        weights = torch.mul(torch.softmax(K, 1), V).sum(dim=1, keepdim=True)
        Q_sig = torch.sigmoid(Q)
        Yt = torch.mul(Q_sig, weights)

        Yt = Yt.view(B, T, self.hidden_dim)
        Yt = self.project(Yt)
        return Yt


class AFTLocal(nn.Module):

    def __init__(self, max_seqlen, dim, hidden_dim=64, s=256):
        super().__init__()
        '''
        max_seqlen: the maximum number of timesteps (sequence length) to be fed in
        dim: the embedding dimension of the tokens
        hidden_dim: the hidden dimension used inside AFT Full
        s: the window size used for AFT-Local in the paper

        Number of heads is 1 as done in the paper
        '''
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.to_q = nn.Linear(dim, hidden_dim)
        self.to_k = nn.Linear(dim, hidden_dim)
        self.to_v = nn.Linear(dim, hidden_dim)
        self.project = nn.Linear(hidden_dim, dim)
        self.wbias = nn.Parameter(torch.Tensor(max_seqlen, max_seqlen))
        self.max_seqlen = max_seqlen
        self.s = s
        nn.init.xavier_uniform_(self.wbias)

    def forward(self, x):
        B, T, _ = x.shape
        Q = self.to_q(x).view(B, T, self.hidden_dim)
        K = self.to_k(x).view(B, T, self.hidden_dim)
        V = self.to_v(x).view(B, T, self.hidden_dim)

        # zero out the position bias outside the local window |i - j| < s
        self.wbias = nn.Parameter(torch.Tensor([[self.wbias[i][j] if math.fabs(i - j) < self.s else 0
                                                 for j in range(self.max_seqlen)]
                                                for i in range(self.max_seqlen)]))
        temp_wbias = self.wbias[:T, :T].unsqueeze(0)  # sequences can still be variable length

        '''
        From the paper
        '''
        Q_sig = torch.sigmoid(Q)
        temp = torch.exp(temp_wbias) @ torch.mul(torch.exp(K), V)
        weighted = temp / (torch.exp(temp_wbias) @ torch.exp(K))
        Yt = torch.mul(Q_sig, weighted)

        Yt = Yt.view(B, T, self.hidden_dim)
        Yt = self.project(Yt)
        return Yt
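
A minimal usage sketch of my own for the three variants, with illustrative sizes; T must not exceed max_seqlen, since the position bias wbias is sliced to its first T rows and columns:

import torch

x = torch.randn(2, 10, 128)   # (batch, seq_len, dim), seq_len <= max_seqlen
aft_full = AFTFull(max_seqlen=20, dim=128, hidden_dim=64)
aft_simple = AFTSimple(max_seqlen=20, dim=128, hidden_dim=64)
aft_local = AFTLocal(max_seqlen=20, dim=128, hidden_dim=64, s=5)
for aft in (aft_full, aft_simple, aft_local):
    print(aft(x).shape)       # torch.Size([2, 10, 128]) for each variant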

5. Outlook Attention

Paper: VOLO: Vision Outlooker for Visual Recognition

(Figure: Outlook Attention structure diagram, omitted here.)

The code is as follows (source link):

import torch
import torch.nn as nn
import torch.nn.functional as F  # F.fold is used in forward
import math                      # math.ceil is used in forward


class OutlookAttention(nn.Module):
    """
    Implementation of outlook attention
    --dim: hidden dim
    --num_heads: number of heads
    --kernel_size: kernel size in each window for outlook attention

    return: token features after outlook attention
    """

    def __init__(self, dim, num_heads, kernel_size=3, padding=1, stride=1,
                 qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.scale = qk_scale or head_dim ** -0.5

        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)

    def forward(self, x):
        B, H, W, C = x.shape

        v = self.v(x).permute(0, 3, 1, 2)  # B, C, H, W

        h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
        v = self.unfold(v).reshape(B, self.num_heads, C // self.num_heads,
                                   self.kernel_size * self.kernel_size,
                                   h * w).permute(0, 1, 4, 3, 2)  # B, H, N, kxk, C/H

        # the attention weights are generated directly from the pooled features
        attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        attn = self.attn(attn).reshape(
            B, h * w, self.num_heads, self.kernel_size * self.kernel_size,
            self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4)  # B, H, N, kxk, kxk
        attn = attn * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # aggregate values within each k x k window, then fold back to the feature map
        x = (attn @ v).permute(0, 1, 4, 3, 2).reshape(
            B, C * self.kernel_size * self.kernel_size, h * w)
        x = F.fold(x, output_size=(H, W), kernel_size=self.kernel_size,
                   padding=self.padding, stride=self.stride)

        x = self.proj(x.permute(0, 2, 3, 1))
        x = self.proj_drop(x)

        return x
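
A minimal usage sketch; note that, unlike the modules above, the forward pass expects a channels-last (B, H, W, C) input. The shapes and head count below are illustrative:

import torch

x = torch.randn(2, 14, 14, 512)   # channels-last layout: (B, H, W, C)
oa = OutlookAttention(dim=512, num_heads=8, kernel_size=3, padding=1, stride=1)
print(oa(x).shape)                # torch.Size([2, 14, 14, 512]) -- shape is preserved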
