当前位置: 首页 > news >正文

多头注意力用单元矩阵实现以及原因

一、 PyTorch 中实现这个过程。

**本质是在横向加Q1,Q2**

1. 创建单一的权重矩阵

假设我们有以下参数:

  • input_dim = 64:输入的维度。
  • num_heads = 8:注意力头的数量。
  • head_dim = 16:每个头的维度。

我们可以先创建一个大的权重矩阵 W_query,然后进行拆分。

import torch
import torch.nn as nn# 假设输入维度是 64,每个头的维度是 16,总共 8 个头
input_dim = 64
num_heads = 8
head_dim = 16# 创建一个大的 W_query 矩阵,形状为 (input_dim, num_heads * head_dim)
W_query = nn.Linear(input_dim, num_heads * head_dim)# 假设我们有一个输入 X,形状为 (batch_size, seq_len, input_dim)
batch_size = 32
seq_len = 10
X = torch.rand(batch_size, seq_len, input_dim)# 通过线性层得到 Q 矩阵,形状为 (batch_size, seq_len, num_heads * head_dim)
Q = W_query(X)  # Q 的形状为 (batch_size, seq_len, num_heads * head_dim)# 将 Q 拆分成多个头,每个头的形状为 (batch_size, seq_len, head_dim)
# 拆分后的形状为 (batch_size, seq_len, num_heads, head_dim)
Q = Q.view(batch_size, seq_len, num_heads, head_dim)# 如果需要将 Q 传递给每个头独立处理,可以对维度进行调整,形状为 (batch_size, num_heads, seq_len, head_dim)
Q = Q.permute(0, 2, 1, 3)

2. 拆分后的矩阵如何用于多头注意力机制

拆分后的 Q 矩阵可以直接用于每个注意力头的计算,其他的 KV 矩阵也可以类似地处理。

# 例如,假设我们有 K 和 V 矩阵
W_key = nn.Linear(input_dim, num_heads * head_dim)
W_value = nn.Linear(input_dim, num_heads * head_dim)K = W_key(X).view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
V = W_value(X).view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)# 计算每个头的注意力得分
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (head_dim ** 0.5)  # 形状为 (batch_size, num_heads, seq_len, seq_len)# 对注意力得分进行 softmax 操作
attention_weights = torch.softmax(attention_scores, dim=-1)# 计算每个头的注意力输出
attention_output = torch.matmul(attention_weights, V)  # 形状为 (batch_size, num_heads, seq_len, head_dim)# 最后将所有头的输出拼接在一起,形成最终的输出
attention_output = attention_output.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, num_heads * head_dim)

3、小结

  1. 创建单一权重矩阵:我们创建一个大矩阵,其输出维度是 num_heads * head_dim
  2. 拆分矩阵:通过 .view.permute 方法,将这个大矩阵的输出拆分为多个小矩阵,每个小矩阵对应一个注意力头。
  3. 计算注意力:每个头独立计算注意力得分,然后将结果合并。

你提出的问题实际上涉及到参数共享的误解。让我更详细地解释这个问题,澄清在创建独立矩阵和共享矩阵时参数量的实际情况。

二、 参数量

无论是创建独立矩阵还是共享矩阵,参数量都是相同的。每种方法的参数量计算结果都是 8192。

假设:

  • input_dim = 64:输入的维度。
  • num_heads = 8:注意力头的数量。
  • head_dim = 16:每个头的维度。
1. 独立矩阵的情况

如果为每个注意力头创建独立的权重矩阵,则每个头的权重矩阵的参数量为 input_dim * head_dim。对于 num_heads 个头,整个参数量为:

总参数量 = num_heads × input_dim × head_dim \text{总参数量} = \text{num\_heads} \times \text{input\_dim} \times \text{head\_dim} 总参数量=num_heads×input_dim×head_dim

根据假设,参数量计算如下:

总参数量 = 8 × 64 × 16 = 8192 \text{总参数量} = 8 \times 64 \times 16 = 8192 总参数量=8×64×16=8192

2. 共享矩阵的情况

如果我们使用一个大的共享矩阵,并将其拆分为多个头,则共享矩阵的形状为 (input_dim, num_heads * head_dim)。因此,整个矩阵的参数量为:

总参数量 = input_dim × ( num_heads × head_dim ) \text{总参数量} = \text{input\_dim} \times (\text{num\_heads} \times \text{head\_dim}) 总参数量=input_dim×(num_heads×head_dim)

根据假设,参数量计算如下:

总参数量 = 64 × ( 8 × 16 ) = 64 × 128 = 8192 \text{总参数量} = 64 \times (8 \times 16) = 64 \times 128 = 8192 总参数量=64×(8×16)=64×128=8192

三、 为什么使用共享矩阵

  1. 代码简洁性:使用共享矩阵可以简化代码,实现统一的矩阵运算,减少了手动管理多个独立矩阵的复杂性。

  2. 计算效率:共享矩阵可以利用并行计算的优势,使得在进行矩阵运算时更加高效,因为只需一次大的矩阵乘法操作,而不需要为每个头分别计算。

  3. 实现上的一致性:共享矩阵的实现方式与多头注意力机制的逻辑一致(即所有头都是从同一个大的线性变换中分离出来的),更符合框架的设计和优化原则。

为了帮助你更直观地理解共享矩阵与独立矩阵的关系,我会举一个具体的数值例子,说明共享矩阵在进行一次性矩阵乘法后分解成多个独立矩阵的结果与独立矩阵逐一计算的结果是相同的。

四 、数值示例直观说明是结果是一致的

设定参数

假设我们有以下参数:

  • input_dim = 4:输入的维度。
  • num_heads = 2:注意力头的数量。
  • head_dim = 2:每个头的维度。
  • 输入矩阵 X 的形状为 (batch_size=1, seq_len=3, input_dim=4)
输入矩阵 X

X = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] X = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} X= 159261037114812

独立矩阵的权重

我们为每个头单独创建 W_query_1W_query_2 矩阵:
W _ q u e r y _ 1 = [ 1 0 0 1 1 0 0 1 ] , W _ q u e r y _ 2 = [ 2 1 1 2 2 1 1 2 ] W\_query\_1 = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix}, \quad W\_query\_2 = \begin{bmatrix} 2 & 1 \\ 1 & 2 \\ 2 & 1 \\ 1 & 2 \end{bmatrix} W_query_1= 10100101 ,W_query_2= 21211212

1. 独立矩阵逐一计算

我们对每个头的 Q 进行计算:

  • 对第一个头计算 Q_1
    Q 1 = X × W _ q u e r y _ 1 = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] × [ 1 0 0 1 1 0 0 1 ] = [ 4 6 12 14 20 22 ] Q_1 = X \times W\_query\_1 = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} \times \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix} = \begin{bmatrix} 4 & 6 \\ 12 & 14 \\ 20 & 22 \end{bmatrix} Q1=X×W_query_1= 159261037114812 × 10100101 = 4122061422

  • 对第二个头计算 Q_2
    Q 2 = X × W _ q u e r y _ 2 = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] × [ 2 1 1 2 2 1 1 2 ] = [ 10 16 30 40 50 64 ] Q_2 = X \times W\_query\_2 = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} \times \begin{bmatrix} 2 & 1 \\ 1 & 2 \\ 2 & 1 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 10 & 16 \\ 30 & 40 \\ 50 & 64 \end{bmatrix} Q2=X×W_query_2= 159261037114812 × 21211212 = 103050164064

2. 共享矩阵一次性计算

我们将上述两个矩阵组合为一个大的共享矩阵 W_query,其形状为 (input_dim=4, num_heads * head_dim=4)
W _ q u e r y = [ 1 0 2 1 0 1 1 2 1 0 2 1 0 1 1 2 ] W\_query = \begin{bmatrix} 1 & 0 & 2 & 1 \\ 0 & 1 & 1 & 2 \\ 1 & 0 & 2 & 1 \\ 0 & 1 & 1 & 2 \end{bmatrix} W_query= 1010010121211212

然后一次性计算出所有头的 Q 矩阵:
Q = X × W _ q u e r y = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] × [ 1 0 2 1 0 1 1 2 1 0 2 1 0 1 1 2 ] = [ 4 6 10 16 12 14 30 40 20 22 50 64 ] Q = X \times W\_query = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} \times \begin{bmatrix} 1 & 0 & 2 & 1 \\ 0 & 1 & 1 & 2 \\ 1 & 0 & 2 & 1 \\ 0 & 1 & 1 & 2 \end{bmatrix} = \begin{bmatrix} 4 & 6 & 10 & 16 \\ 12 & 14 & 30 & 40 \\ 20 & 22 & 50 & 64 \end{bmatrix} Q=X×W_query= 159261037114812 × 1010010121211212 = 4122061422103050164064

3. 将共享矩阵结果拆分为独立矩阵

最后,我们将 Q 矩阵按列拆分成两个子矩阵,每个子矩阵对应一个注意力头:

  • 第一个头的结果 Q_1
    Q 1 = [ 4 6 12 14 20 22 ] Q_1 = \begin{bmatrix} 4 & 6 \\ 12 & 14 \\ 20 & 22 \end{bmatrix} Q1= 4122061422

  • 第二个头的结果 Q_2
    Q 2 = [ 10 16 30 40 50 64 ] Q_2 = \begin{bmatrix} 10 & 16 \\ 30 & 40 \\ 50 & 64 \end{bmatrix} Q2= 103050164064

4. 对比结果

从上述步骤可以看到,使用共享矩阵一次性计算并拆分后的结果与使用独立矩阵逐一计算的结果完全相同。具体来说:

  • Q_1 的结果在独立和共享情况下都为:
    [ 4 6 12 14 20 22 ] \begin{bmatrix} 4 & 6 \\ 12 & 14 \\ 20 & 22 \end{bmatrix} 4122061422

  • Q_2 的结果在独立和共享情况下都为:
    [ 10 16 30 40 50 64 ] \begin{bmatrix} 10 & 16 \\ 30 & 40 \\ 50 & 64 \end{bmatrix} 103050164064

计算 QK 的点积,即 Q * K^T,以展示共享矩阵和独立矩阵在这一步的等价性。

假设

继续沿用之前的参数和矩阵:

  • 输入矩阵 X 的形状为 (batch_size=1, seq_len=3, input_dim=4)
  • W_queryW_key 的设置方式与之前相同,分为独立矩阵和共享矩阵两种情况。

1. 独立矩阵的计算

假设我们有独立的 W_key_1W_key_2

W _ k e y _ 1 = W _ q u e r y _ 1 = [ 1 0 0 1 1 0 0 1 ] , W _ k e y _ 2 = W _ q u e r y _ 2 = [ 2 1 1 2 2 1 1 2 ] W\_key\_1 = W\_query\_1 = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix}, \quad W\_key\_2 = W\_query\_2 = \begin{bmatrix} 2 & 1 \\ 1 & 2 \\ 2 & 1 \\ 1 & 2 \end{bmatrix} W_key_1=W_query_1= 10100101 ,W_key_2=W_query_2= 21211212

首先计算独立的 K_1K_2

  • 对第一个头计算 K_1
    K 1 = X × W _ k e y _ 1 = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] × [ 1 0 0 1 1 0 0 1 ] = [ 4 6 12 14 20 22 ] K_1 = X \times W\_key\_1 = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} \times \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix} = \begin{bmatrix} 4 & 6 \\ 12 & 14 \\ 20 & 22 \end{bmatrix} K1=X×W_key_1= 159261037114812 × 10100101 = 4122061422

  • 对第二个头计算 K_2
    K 2 = X × W _ k e y _ 2 = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] × [ 2 1 1 2 2 1 1 2 ] = [ 10 16 30 40 50 64 ] K_2 = X \times W\_key\_2 = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} \times \begin{bmatrix} 2 & 1 \\ 1 & 2 \\ 2 & 1 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 10 & 16 \\ 30 & 40 \\ 50 & 64 \end{bmatrix} K2=X×W_key_2= 159261037114812 × 21211212 = 103050164064

然后计算 Q_1 * K_1^TQ_2 * K_2^T

  • Q_1 * K_1^T
    Q 1 × K 1 T = [ 4 6 12 14 20 22 ] × [ 4 12 20 6 14 22 ] = [ 52 120 188 120 292 464 188 464 740 ] Q_1 \times K_1^T = \begin{bmatrix} 4 & 6 \\ 12 & 14 \\ 20 & 22 \end{bmatrix} \times \begin{bmatrix} 4 & 12 & 20 \\ 6 & 14 & 22 \end{bmatrix} = \begin{bmatrix} 52 & 120 & 188 \\ 120 & 292 & 464 \\ 188 & 464 & 740 \end{bmatrix} Q1×K1T= 4122061422 ×[4612142022]= 52120188120292464188464740

  • Q_2 * K_2^T
    Q 2 × K 2 T = [ 10 16 30 40 50 64 ] × [ 10 30 50 16 40 64 ] = [ 356 820 1284 820 1960 3100 1284 3100 4916 ] Q_2 \times K_2^T = \begin{bmatrix} 10 & 16 \\ 30 & 40 \\ 50 & 64 \end{bmatrix} \times \begin{bmatrix} 10 & 30 & 50 \\ 16 & 40 & 64 \end{bmatrix} = \begin{bmatrix} 356 & 820 & 1284 \\ 820 & 1960 & 3100 \\ 1284 & 3100 & 4916 \end{bmatrix} Q2×K2T= 103050164064 ×[101630405064]= 356820128482019603100128431004916

2. 共享矩阵的计算

使用共享矩阵时,先计算出 QK

Q = X × W _ q u e r y = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] × [ 1 0 2 1 0 1 1 2 1 0 2 1 0 1 1 2 ] = [ 4 6 10 16 12 14 30 40 20 22 50 64 ] Q = X \times W\_query = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} \times \begin{bmatrix} 1 & 0 & 2 & 1 \\ 0 & 1 & 1 & 2 \\ 1 & 0 & 2 & 1 \\ 0 & 1 & 1 & 2 \end{bmatrix} = \begin{bmatrix} 4 & 6 & 10 & 16 \\ 12 & 14 & 30 & 40 \\ 20 & 22 & 50 & 64 \end{bmatrix} Q=X×W_query= 159261037114812 × 1010010121211212 = 4122061422103050164064

K 的计算与 Q 相同,因为 W_key = W_query

接下来计算 Q * K^T

Q × K T = [ 4 6 10 16 12 14 30 40 20 22 50 64 ] × [ 4 12 20 6 14 22 10 30 50 16 40 64 ] = [ 52 120 188 356 820 1284 120 292 464 820 1960 3100 188 464 740 1284 3100 4916 ] Q \times K^T = \begin{bmatrix} 4 & 6 & 10 & 16 \\ 12 & 14 & 30 & 40 \\ 20 & 22 & 50 & 64 \end{bmatrix} \times \begin{bmatrix} 4 & 12 & 20 \\ 6 & 14 & 22 \\ 10 & 30 & 50 \\ 16 & 40 & 64 \end{bmatrix} = \begin{bmatrix} 52 & 120 & 188 & 356 & 820 & 1284 \\ 120 & 292 & 464 & 820 & 1960 & 3100 \\ 188 & 464 & 740 & 1284 & 3100 & 4916 \end{bmatrix} Q×KT= 4122061422103050164064 × 4610161214304020225064 = 52120188120292464188464740356820128482019603100128431004916

3. 将共享矩阵结果拆分为独立矩阵结果

我们可以将结果矩阵按头的数量拆分成两个部分:

  • 前两列对应 Q_1 * K_1^T
    [ 52 120 188 120 292 464 188 464 740 ] \begin{bmatrix} 52 & 120 & 188 \\ 120 & 292 & 464 \\ 188 & 464 & 740 \end{bmatrix} 52120188120292464188464740

  • 后四列对应 Q_2 * K_2^T
    [ 356 820 1284 820 1960 3100 1284 3100 4916 ] \begin{bmatrix} 356 & 820 & 1284 \\ 820 & 1960 & 3100 \\ 1284 & 3100 & 4916 \end{bmatrix} 356820128482019603100128431004916

4. 对比结果

从上述结果可以看到,使用共享矩阵一次性计算并拆分后的 Q * K^T 结果与使用独立矩阵逐一计算的结果完全一致:

  • Q_1 * K_1^T 在共享和独立情况下的结果一致。
  • Q_2 * K_2^T 在共享和独立情况下的结果一致。

总结

验证了共享矩阵和独立矩阵在计算 Q * K^T 时的等价性。虽然操作顺序和形式不同,但它们最终得到的结果是完全相同的。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • SO_REUSEADDR 和 SO_REUSEPORT 的区别
  • WEB渗透-未授权访问篇
  • 图像处理案例03
  • 深度学习与图像修复:ADetailer插件在Stable Diffusion中的应用
  • 【JavaEE初阶】JUC(java.uitl.concurrent)的常见类
  • Java Server-Sent Event 服务端发送事件
  • [FSCTF 2023]细狗2.0
  • 力扣热题100_二叉树_94_二叉树的中序遍历
  • C语言中常用的函数
  • python自动化笔记:excel文件处理及日志收集
  • 列式数据库(HBase) 中实现表与表的关联
  • 区块链(Blockchain)
  • 【代码随想录】长度最小的子数组——滑动窗口
  • 第二十一节、敌人追击状态的转换
  • 【K8S】K8S架构及相关组件
  • 【407天】跃迁之路——程序员高效学习方法论探索系列(实验阶段164-2018.03.19)...
  • CAP理论的例子讲解
  • dva中组件的懒加载
  • gf框架之分页模块(五) - 自定义分页
  • github从入门到放弃(1)
  • Joomla 2.x, 3.x useful code cheatsheet
  • LintCode 31. partitionArray 数组划分
  • opencv python Meanshift 和 Camshift
  • Python3爬取英雄联盟英雄皮肤大图
  • React 快速上手 - 07 前端路由 react-router
  • RxJS 实现摩斯密码(Morse) 【内附脑图】
  • 阿里中间件开源组件:Sentinel 0.2.0正式发布
  • 力扣(LeetCode)357
  • 你真的知道 == 和 equals 的区别吗?
  • 嵌入式文件系统
  • 深度学习入门:10门免费线上课程推荐
  • 跳前端坑前,先看看这个!!
  • 问题之ssh中Host key verification failed的解决
  • 我从编程教室毕业
  • 移动端 h5开发相关内容总结(三)
  • 《天龙八部3D》Unity技术方案揭秘
  • 测评:对于写作的人来说,Markdown是你最好的朋友 ...
  • ​ ​Redis(五)主从复制:主从模式介绍、配置、拓扑(一主一从结构、一主多从结构、树形主从结构)、原理(复制过程、​​​​​​​数据同步psync)、总结
  • ​低代码平台的核心价值与优势
  • #Ubuntu(修改root信息)
  • %3cli%3e连接html页面,html+canvas实现屏幕截取
  • ( )的作用是将计算机中的信息传送给用户,计算机应用基础 吉大15春学期《计算机应用基础》在线作业二及答案...
  • (02)Unity使用在线AI大模型(调用Python)
  • (1)(1.13) SiK无线电高级配置(六)
  • (145)光线追踪距离场柔和阴影
  • (173)FPGA约束:单周期时序分析或默认时序分析
  • (26)4.7 字符函数和字符串函数
  • (39)STM32——FLASH闪存
  • (windows2012共享文件夹和防火墙设置
  • (亲测有效)推荐2024最新的免费漫画软件app,无广告,聚合全网资源!
  • (三)Honghu Cloud云架构一定时调度平台
  • (原創) 系統分析和系統設計有什麼差別? (OO)
  • .NET Windows:删除文件夹后立即判断,有可能依然存在
  • .NET/C#⾯试题汇总系列:集合、异常、泛型、LINQ、委托、EF!(完整版)
  • .NET导入Excel数据