当前位置: 首页 > news >正文

DeformableAttention的原理解读和源码实现

本专栏主要是深度学习/自动驾驶相关的源码实现,获取全套代码请参考

目录

  • 原理
    • 第一步看看输入:
    • 第二步,准备工作:
      • 生成参考点的偏移量
      • 生成参考点的权重
      • 生成参考点
    • 第三步,工作:
  • 源码

原理

目前流行3D转2DBEV方案的都绕不开的transfomer变体-DeformableAttention.
在这里插入图片描述
传统transformer注意力关注全局特征,速度慢.而DeformableAttention注意力模块只关注一个目标周围的一小部分的关键采样点特征.原来的DETR需要很多个 epoch 才能找到特征,在Deformable DTER中可以更快,据说1/10的耗时。
原理:以DETR3D的做法为例.

第一步看看输入:

定义一个shape为(900,256)的query,代表900和目标,每个目标256维查询信息.
定义一个query_pos shape同query.
定义一个shape为(900,3)的reference_points,作为目标参考点.
输入为:pts_feats(1,43054,256),多尺度flatten结果,
多尺度特征图尺寸记录:spatial_shapes:([[180, 180],[ 90, 90],[ 45, 45],[ 23, 23]])
特征图在pts_feats起点记录:level_start_index:([ 0, 32400, 40500, 42525])
可自行验算下.

第二步,准备工作:

pts_feats reshape为(1,43054,8,32)

value = value.view(bs, num_value, self.num_heads, -1)

生成参考点的偏移量

query经过self.sampling_offsets线性映射再reshape输出:
sampling_offsets(torch.Size([1, 900, 8, 4, 4, 2]))
其中8是多头数量,4是特征层数, 4是采样点数, 2是采样点xy两个维度.意思是8次在4层特征图上分别采样4个点,这844个点的xy方向的偏移量.

生成参考点的权重

query经过self.attention_weights线性映射再reshape输出:
attention_weights(torch.Size([1, 900, 8, 4, 4]))
对应上述点的权重.

生成参考点

reference_points加上参考点的偏移量生成,真正的参考点.

sampling_location = reference_poins[:, :, None, None, None, :2] + sampling_offsets

sampling_locations(torch.Size([1, 900, 8, 4, 4, 2]))

说白就是,就是定义一个query_embed,它生成自己即将要去采样的点位置和采样点权重.

第三步,工作:

输入:
value shape(torch.Size([b,43054,8,32]))
sampling_locations(torch.Size([b, 900, 8, 4, 4, 2]))
attention_weights(torch.Size([b, 900, 8, 4, 4]))
spatial_shapes:([[180, 180],[ 90, 90],[ 45, 45],[ 23, 23]])

value 根据spatial_shapes分解出各个level:
[torch.Size([b,180180,8,32],torch.Size([b,9090,8,32])),torch.Size([b,4545,8,32])),torch.Size([b,2323,8,32]))]
reshape为正常图像torch.Size([b*8,32,180,180]

sampling_locations原本为采样点位置,范围为[0,1),为了适应F.grid_sample采样函数的用法,调整为[-1,1)分布,
调用F.grid_sample对每一层特征进行采样,输入value为torch.Size([b8,32,level_h,level_w]),采样点为sampling_grid:torch.Size([b8,900,4,2])
则输出为sampling_value:torch.Size([b8,32,900,4])
意思是,900个query在特征图(32,level_h,level_w)中各采样4个点,采样结果为900个对应的4个通道为32的像素特征.
将4层采样结果sampling_value拍在一起torch.Size([b
8,32,900,4*4])

attention_weights变成相同形式(torch.Size([b8, 1,900, 44])),然后对16个采样特征进行加权求和输出outputtorch.Size([b,32*8,900]).后续交给FFN对多头特征进行全连接融合.

源码

import torch
import torch.nn.functional as F
import torch.nn as nndef multi_scale_deformable_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights):batch, _, num_head, embeding_dim_perhead = value.shape_, query_size, _, level_num, sample_num, _ = sampling_locations.shapesplit_list = []for h, w in spatial_shapes:split_list.append(int(h * w))value_list = value.split(split_size=tuple(split_list), dim=1)# [0,1)分布变成 [-1,1)分布,因为要调用F.grid_sample函数sampling_grid = 2 * sampling_locations - 1output_list = []for level_id, (h, w) in enumerate(spatial_shapes):h = int(h)w = int(w)# batch, value_len, num_head, embeding_dim_perhead# batch, num_head, embeding_dim_perhead, value_len# batch*num_head, embeding_dim_perhead, h, wvalue_l = value_list[level_id].permute(0, 2, 3, 1).view(batch * num_head, embeding_dim_perhead, h, w)# batch,query_size,num_head,level_num,sample_num,2# batch,query_size,num_head,sample_num,2# batch,num_head,query_size,sample_num,2# batch*num_head,query_size,sample_num,2sampling_grid_l = sampling_grid[:, :, :, level_id, :, :].permute(0, 2, 1, 3, 4).view(batch * num_head,query_size, sample_num, 2)# batch*num_head embeding_dim,,query_size, sample_numoutput = F.grid_sample(input=value_l,grid=sampling_grid_l,mode='bilinear',padding_mode='zeros',align_corners=False)output_list.append(output)# batch*num_head, embeding_dim_perhead,query_size, level_num, sample_numoutputs = torch.stack(output_list, dim=-2)# batch,query_size,num_head,level_num,sample_num# batch,num_head,query_size,level_num,sample_num# batch*num_head,1,query_size,level_num,sample_numattention_weights = attention_weights.permute(0, 2, 1, 3, 4).view(batch * num_head, 1, query_size, level_num,sample_num)outputs = outputs * attention_weights# batch*num_head, embeding_dim_perhead,query_size# batch,num_head, embeding_dim_perhead,query_size# batch,query_size,num_head, embeding_dim_perhead# batch,query_size,num_head*embeding_dim_perheadoutputs = outputs.sum(-1).sum(-1).view(batch, num_head, embeding_dim_perhead, query_size).permute(0, 3, 1, 2). \view(batch, query_size, num_head * embeding_dim_perhead)return outputs.contiguous()if __name__ == '__main__':batch = 1num_head = 8embeding_dim = 256query_size = 900spatial_shapes = torch.Tensor([[180, 180], [90, 90], [45, 45], [23, 23]])value_len = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum().int()value = torch.rand(size=(batch, value_len, embeding_dim))query_embeding = torch.rand(size=(batch, query_size, embeding_dim * 2 + 3))query = query_embeding[..., :embeding_dim]query_pos = query_embeding[..., embeding_dim:2 * embeding_dim]reference_poins = query_embeding[..., 2 * embeding_dim:]# 讨论1:在deformale-att中这个query并不会和value交互生成att-weights,att-weights只和query有关,# 也就是推理过程att-weights(包括sampling_locations)是固定的.# 据作者解释这是因为采用前者的方式计算的attention权重存在退化问题,# 即最后得到的attention权重与并没有随key的变化而变化。# 因此,这两种计算attention权重的方式最终得到的结果相当,# 而后者耗时更短、计算代价更小,所以作者选择直接对query做projection得到attention权重。# 讨论2:在query固定情况下,第一个layer的att-weights无法改变,# 但是第二个layer的query与value有关,att-weights则会发生变化.so the self-att in frist layer is not nesscerarylevel_num = 4sample_num = 4sampling_offsets_net = nn.Linear(in_features=embeding_dim, out_features=num_head * level_num * sample_num * 2)sampling_offsets = sampling_offsets_net(query).view(batch, query_size, num_head, level_num, sample_num, 2)sampling_location = reference_poins[:, :, None, None, None, :2] + sampling_offsetsattention_weights_net = nn.Linear(in_features=embeding_dim, out_features=num_head * level_num * sample_num)attention_weights = attention_weights_net(query).view(batch, query_size, num_head, level_num * sample_num)attention_weights = attention_weights.softmax(dim=-1).view(batch, query_size, num_head, level_num,sample_num)  # sum of 16 points weight is equal to 1embeding_dim_perhead = embeding_dim // num_headvalue = value.view(batch, value_len, num_head, -1)output = multi_scale_deformable_attn_pytorch(value, spatial_shapes, sampling_location, attention_weights)pass

如需获取全套代码请参考

相关文章:

  • QML与C++通信
  • Python电梯楼层数字识别
  • STM32第九节(中级篇):RCC(第一节)——时钟树讲解
  • Tomcat的部署及调优,jvm调优
  • Java8 新特性
  • Java-并发编程--ThreadLocal、InheritableThreadLocal
  • 《LeetCode热题100》笔记题解思路技巧优化_Part_3
  • QT 状态机的使用
  • RocketMQ架构详解
  • 17个工作必备的Python自动化代码分享(上篇)
  • 泽众云真机-机型支持ADB调试功能即将上线
  • 进程学习--02
  • 有来团队后台项目-解析7
  • 鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:Web)中篇
  • Postman请求API接口测试步骤和说明
  • [微信小程序] 使用ES6特性Class后出现编译异常
  • 【162天】黑马程序员27天视频学习笔记【Day02-上】
  • Babel配置的不完全指南
  • Java 23种设计模式 之单例模式 7种实现方式
  • jdbc就是这么简单
  • Node 版本管理
  • PhantomJS 安装
  • React Transition Group -- Transition 组件
  • Shadow DOM 内部构造及如何构建独立组件
  • SpringBoot 实战 (三) | 配置文件详解
  • 浏览器缓存机制分析
  • 体验javascript之美-第五课 匿名函数自执行和闭包是一回事儿吗?
  • 王永庆:技术创新改变教育未来
  • 一个SAP顾问在美国的这些年
  • 一加3T解锁OEM、刷入TWRP、第三方ROM以及ROOT
  • 摩拜创始人胡玮炜也彻底离开了,共享单车行业还有未来吗? ...
  • ​【原创】基于SSM的酒店预约管理系统(酒店管理系统毕业设计)
  • (pt可视化)利用torch的make_grid进行张量可视化
  • (附源码)node.js知识分享网站 毕业设计 202038
  • (附源码)springboot优课在线教学系统 毕业设计 081251
  • (附源码)ssm码农论坛 毕业设计 231126
  • (黑客游戏)HackTheGame1.21 过关攻略
  • (剑指Offer)面试题34:丑数
  • (未解决)jmeter报错之“请在微信客户端打开链接”
  • (转)AS3正则:元子符,元序列,标志,数量表达符
  • (转)Linux NTP配置详解 (Network Time Protocol)
  • (转)memcache、redis缓存
  • ***利用Ms05002溢出找“肉鸡
  • *p++,*(p++),*++p,(*p)++区别?
  • .[hudsonL@cock.li].mkp勒索加密数据库完美恢复---惜分飞
  • .MSSQLSERVER 导入导出 命令集--堪称经典,值得借鉴!
  • .NET DevOps 接入指南 | 1. GitLab 安装
  • .net oracle 连接超时_Mysql连接数据库异常汇总【必收藏】
  • .NET 使用 ILRepack 合并多个程序集(替代 ILMerge),避免引入额外的依赖
  • .net反编译工具
  • @Autowired 与@Resource的区别
  • [ vulhub漏洞复现篇 ] GhostScript 沙箱绕过(任意命令执行)漏洞CVE-2019-6116
  • [.net]官方水晶报表的使用以演示下载
  • [2023-年度总结]凡是过往,皆为序章
  • [51nod1610]路径计数