当前位置: 首页 > news >正文

Talking-Heads Attention

paper:Talking-Heads Attention

在CaiT这篇文章中,作用采用了talking-heads attention,这里做一下解释。

在原始multi-head self-attention中,各个head的计算是独立进行的,多个head的输出最后concat到一起,然后再经过一个线性变换得到最终的输出。

本文提出了在softmax操作的前后引入跨注意力头维度的线性变换,从而使每个self-attention函数依赖于所有的key和query。

下面分别是timm中普通Attention和TalkingHeadAttention的实现

# class Attention
def forward(self, x: torch.Tensor) -> torch.Tensor:  # (1,197,192)B, N, C = x.shapeqkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)# (1,197,576)->(1,197,3,3,64)->(3,1,3,197,64), (3, batch_size, num_heads, seq_len, head_dim), 3表示qkvq, k, v = qkv.unbind(0)  # (1,3,197,64)q, k = self.q_norm(q), self.k_norm(k)if self.fused_attn:  # Falsex = F.scaled_dot_product_attention(q, k, v,dropout_p=self.attn_drop.p if self.training else 0.,)else:# attn=softmax(qk)q = q * self.scaleattn = q @ k.transpose(-2, -1)  # (1,3,197,197)attn = attn.softmax(dim=-1)attn = self.attn_drop(attn)x = attn @ v  # (1,3,197,64)x = x.transpose(1, 2).reshape(B, N, C)  # (1,197,3,64)->(1,197,192)x = self.proj(x)  # (1,197,192)x = self.proj_drop(x)return x# class TalkingHeadAttn
def forward(self, x):B, N, C = x.shape  # (1,196,384)qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (1,196,1152)->(1,196,3,8,48)->(3,1,8,196,48)q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]  # (1,8,196,48)attn = q @ k.transpose(-2, -1)  # (1,8,196,196)attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # (1,196,196,8)->(1,196,196,8)->(1,8,196,196)attn = attn.softmax(dim=-1)  # (1,8,196,196)attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # (1,196,196,8)->(1,196,8,8)->(1,8,196,196)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (1,8,196,48)->(1,196,8,48)->(1,196,384)x = self.proj(x)  # (1,196,384)x = self.proj_drop(x)return x

从下图的对比看的更加清楚,左边是普通的attention,右边是talking-heads attention。左边的输入shape为(1, 197, 192),其中197=196+1是添加了class token,192是特征维度。右边的输入shape为(1, 196, 384),特征维度为384。左边num_heads=3,右边num_heads=8。因为左边的代码来自vision transformer,右边的代码来自CaiT,选择的具体模型variant不同,所以特征维度和head数量也不一样,但不影响。

可以看到,TalkingHeadAttention在计算softmax前后分别引入了一个线性变换self.proj_lself.proj_w,定义分别为self.proj_l = nn.Linear(num_heads, num_heads)self.proj_w = nn.Linear(num_heads, num_heads)。在线性变换前先对输入进行维度变换通过.permute(0, 2, 3 ,1)将num_head维度放到最后,因此线性变换是针对num_head维度的,从而实现跨head的交互,最后再permute回去。

相关文章:

  • Kotlin 中的 infix 关键字(中缀函数)
  • C# 集合(二) —— List/Queue类
  • 马斯克的Grok-1:开源AI模型的突破与挑战
  • TrueNAS系统在ARM平台上的移植
  • 傅佩荣教授讲座视频全集,傅佩荣讲座大全,傅佩荣国学讲座全集百度网盘
  • 使用同步和异步方式更新插入MongoDB数据的性能对比
  • 使用Scala爬取安居客房产信息并存入CSV文件
  • AI时代:硬件狂欢,软件落寞 华为开发者大会2024
  • 如何在 MySQL 中创建和使用事务?
  • 一文读懂数据仓库ODS层
  • 外贸SEO工具有哪些推荐?
  • Unity URP下通过相机让部分Render不受后处理渲染
  • 前端模糊搜索关键字高亮
  • Dubbo3 服务原生支持 http 访问,兼具高性能与易用性
  • android Switch/case with R.id.XXXX in android doesn‘t work 错误: 需要常量表达式解决方案
  • 【跃迁之路】【669天】程序员高效学习方法论探索系列(实验阶段426-2018.12.13)...
  • angular2开源库收集
  • Angularjs之国际化
  • js如何打印object对象
  • learning koa2.x
  • Phpstorm怎样批量删除空行?
  • Webpack入门之遇到的那些坑,系列示例Demo
  • 开发了一款写作软件(OSX,Windows),附带Electron开发指南
  • 如何设计一个微型分布式架构?
  • 如何正确配置 Ubuntu 14.04 服务器?
  • 使用 Docker 部署 Spring Boot项目
  • 双管齐下,VMware的容器新战略
  • 吐槽Javascript系列二:数组中的splice和slice方法
  • 《码出高效》学习笔记与书中错误记录
  • ### Cause: com.mysql.jdbc.exceptions.jdbc4.MySQLTr
  • #常见电池型号介绍 常见电池尺寸是多少【详解】
  • #面试系列-腾讯后端一面
  • #数据结构 笔记三
  • $Django python中使用redis, django中使用(封装了),redis开启事务(管道)
  • (1)(1.19) TeraRanger One/EVO测距仪
  • (delphi11最新学习资料) Object Pascal 学习笔记---第13章第1节 (全局数据、栈和堆)
  • (LeetCode) T14. Longest Common Prefix
  • (二)springcloud实战之config配置中心
  • (附源码)ssm基于jsp高校选课系统 毕业设计 291627
  • (紀錄)[ASP.NET MVC][jQuery]-2 純手工打造屬於自己的 jQuery GridView (含完整程式碼下載)...
  • (蓝桥杯每日一题)love
  • (三十)Flask之wtforms库【剖析源码上篇】
  • (已更新)关于Visual Studio 2019安装时VS installer无法下载文件,进度条为0,显示网络有问题的解决办法
  • **PyTorch月学习计划 - 第一周;第6-7天: 自动梯度(Autograd)**
  • .NET CLR基本术语
  • .Net Winform开发笔记(一)
  • .NET/C#⾯试题汇总系列:集合、异常、泛型、LINQ、委托、EF!(完整版)
  • /deep/和 >>>以及 ::v-deep 三者的区别
  • [ JavaScript ] JSON方法
  • [ 云计算 | AWS ] 对比分析:Amazon SNS 与 SQS 消息服务的异同与选择
  • [15] 使用Opencv_CUDA 模块实现基本计算机视觉程序
  • [3D基础]理解计算机3D图形学中的坐标系变换
  • [Algorithm][动态规划][01背包问题][目标和][最后一块石头的重量Ⅱ]详细讲解
  • [Algorithm][综合训练][体育课测验(二)][合唱队形][宵暗的妖怪]详细讲解
  • [C#]无法获取源 https://api.nuge t.org/v3-index存储签名信息解决方法