当前位置: 首页 > news >正文

注意力机制的细节

文章目录

    • 注意力机制
      • 举个栗子
      • 从举例引出注意力
      • 注意力机制的构成要素
    • 自注意力机制
    • 自注意力的细节
      • 输入词嵌入
      • 查询与键值
      • 注意力评分
      • 注意力权重
      • 掩蔽注意力
      • 注意力汇聚
      • 多头注意力

我们在面对庞杂的信息时能有选择地关注特定部分从而忽略其他部分,并且在不同场合时间关注的特定部分也可以不同,因此能够分清轻重缓急,达到高效整合、提升效率等目的,只关注部分信息的能力对人类的进化也更加有意义,注意力机制在某种意义上借鉴了人类注意力的特点。

注意力机制对近几年的深度学习产生了深远影响,从谷歌的《Attention is all you need》开始,注意力机制广泛应用在计算机视觉和自然语言处理等领域。

本文首先引入注意力机制,然后从运作过程与计算逻辑等方面讲解自注意力。自注意力是注意力的特殊呈现,本质上是一个东西,它是Transformer等模型架构的核心模块,其重要性不言而喻。

注意力机制

我们先从一个例子开始,由浅入深,由表及里地介绍注意力机制的组成部件及其运作过程。

举个栗子

例:开学班级票选班长,有我、小孙、小明、小红和小花5人参与匿名投票,每人可投100张票,结束后我共获投100票。现在我想估计这些票中分别有多少来自他们4人。一个简单粗暴的方法是:从 这个主体出发,将注意力放在 审视我与每个人的亲密程度 上面,然后 以得到的亲密度作为衡量标准估计票数

现假设我得到了与每个人的亲密度如下(亲密度 a a a 的值用数字1-5表示,越大表示关系越好,可能投的票数越多):

a 我 , 小孙 = 4 、 a 我 , 小明 = 3 、 a 我 , 小红 = 2 、 a 我 , 小花 = 1 a_{我,小孙}=4 、 a_{我,小明}=3 、a_{我,小红}=2 、a_{我,小花}=1 a,小孙=4a,小明=3a,小红=2a,小花=1

亲密度不好直接量化票数,因此可对其做 s o f t m a x softmax softmax 处理以获得 与亲密度对应的票数的概率分布 ,该分布是各项非负且和为1的权重:

α 我 , 小孙 = 0.64 、 α 我 , 小明 = 0.24 、 α 我 , 小红 = 0.09 、 α 我 , 小花 = 0.03 \alpha_{我,小孙}=0.64 、\alpha_{我,小明}=0.24 、\alpha_{我,小红}=0.09 、 \alpha_{我,小花}=0.03 α,小孙=0.64α,小明=0.24α,小红=0.09α,小花=0.03

按上述权重,可以估算出100张票分别来自4人的票数,比如因为我和小孙最亲密,因此理所当然他投给我的票最多。

另外,其他人也会参与竞选获得投票,比如小明作为主体,依此方法可以估计他所获得的来自其他4人的票数。

计算过程如图1所示:


在这里插入图片描述

图1

从举例引出注意力

可以令:

主体为 q q q 我 = q 我 我=q^我 =q

亲密程度为 k k k 小孙 = k 1 、小明 = k 2 、小红 = k 3 、小花 = k 4 小孙=k_1 、 小明=k_2 、 小红=k_3 、 小花=k_4 小孙=k1、小明=k2、小红=k3、小花=k4 ,比如我与小孙的亲密度: a 我 , 小孙 = a ( q 我 , k 1 ) a_{我,小孙}=a(q^我,k_1) a,小孙=a(q,k1) ,权重为 α 我 , 小孙 = α ( q 我 , k 1 ) \alpha_{我,小孙}=\alpha(q^我,k_1) α,小孙=α(q,k1)

持有票数为 v v v 小孙 = v 1 、小明 = v 2 、小红 = v 3 、小花 = v 4 小孙=v_1 、 小明=v_2 、 小红=v_3 、 小花=v_4 小孙=v1、小明=v2、小红=v3、小花=v4

则所得票数 f f f 可表示为:

f ( q 我 ) = f ( q 我 , ( k 1 , v 1 ) , . . . , ( k n , v n ) ) = α ( q 我 , k 1 ) v 1 + α ( q 我 , k 2 ) v 2 + α ( q 我 , k 3 ) v 3 + α ( q 我 , k 4 ) v 4 = ∑ i = 1 n α ( q 我 , k i ) v i ( 1 ) f(q^我)=f(q^我,(k_1,v_1),...,(k_n,v_n))=\alpha(q^我,k_1)v_1+\alpha(q^我,k_2)v_2+\alpha(q^我,k_3)v_3+\alpha(q^我,k_4)v_4=\sum_{i=1}^{n} \alpha(q^我,k_i)v_i \ \ \ \ \ (1) f(q)=f(q,(k1,v1),...,(kn,vn))=α(q,k1)v1+α(q,k2)v2+α(q,k3)v3+α(q,k4)v4=i=1nα(q,ki)vi     (1)

其中 n = 4 n=4 n=4 α ( q 我 , k i ) = s o f t m a x ( a ( q 我 , k i ) ) = e x p ( a ( q 我 , k i ) ) ∑ j = 1 n e x p ( a ( q 我 , k j ) ) ( 2 ) \alpha(q^我,k_i)=softmax(a(q^我,k_i))={exp(a(q^我,k_i)) \over{\sum_{j=1}^n} exp(a(q^我,k_j))} \ \ \ \ \ (2) α(q,ki)=softmax(a(q,ki))=j=1nexp(a(q,kj))exp(a(q,ki))     (2)


在这里插入图片描述

图2

注意力机制的构成要素

从上式 ( 1 ) 、 ( 2 ) (1)、(2) (1)(2) 可抽象出 ( 3 ) 、 ( 4 ) (3)、(4) (3)(4)

f ( q ) = f ( q , ( k 1 , v 1 ) , ( k 2 , v 2 ) , . . . , ( k n , v n ) ) = ∑ i = 1 n α ( q , k i ) v i ( 3 ) f(q)=f(q,(k_1,v_1),(k_2,v_2),...,(k_n,v_n))=\sum_{i=1}^n \alpha(q,k_i)v_i \ \ \ \ \ (3) f(q)=f(q,(k1,v1),(k2,v2),...,(kn,vn))=i=1nα(q,ki)vi     (3)

( 3 ) (3) (3) 等号最右边其实是一个加权平均,引入注意力机制后称为 注意力汇聚 。 其中 q q q 称为 查询 ( k , v ) (k,v) (k,v) 称为 键值对 ,一般成对出现。

α ( q , k i ) = s o f t m a x ( a ( q , k i ) ) = e x p ( a ( q , k i ) ) ∑ j = 1 n e x p ( a ( q , k j ) ) ( 4 ) \alpha(q,k_i)=softmax(a(q,k_i))={exp(a(q,k_i)) \over{\sum_{j=1}^n} exp(a(q,k_j))} \ \ \ \ \ (4) α(q,ki)=softmax(a(q,ki))=j=1nexp(a(q,kj))exp(a(q,ki))     (4)

( 4 ) (4) (4) α ( q , k i ) \alpha(q,k_i) α(q,ki) 是某个查询 q q q 与某个键 k k k注意力权重 ,一个查询 q q q 与一组键 k k k 的注意力权重是各项非负且和为1的概率分布,从数学上看, α ( q , k i ) \alpha(q,k_i) α(q,ki) 仅仅是 a ( q , k i ) a(q,k_i) a(q,ki) 通过 s o f t m a x softmax softmax 转换而来的,故而重点在于如何得到 a ( q , k i ) a(q,k_i) a(q,ki)

a ( q , k i ) a(q,k_i) a(q,ki) 称为 注意力评分函数 ,可以看作查询与键的某种相似度(比如上例中的亲密度)。它决定了如何对查询 q q q 与键 k k k 进行计算以抽取两者的相关性并得到注意力权重。

若以自然语言模型为例,输入样本通常是包含多个词元的序列,其中每个词元用一个经过词嵌入的向量表示,查询与键值对就是取自词元;当注意力评分函数 a ( q , k ) a(q,k) a(q,k) 选的是 缩放点积注意力(scaled dot-product attention) 时,就是对两个向量 q 、 k q、k qk 执行点积运算,结果得到的是一个标量;再经 s o f t m a x softmax softmax 转换成注意力权重;使用注意力权重对值 v v v 进行加权平均,得到的就是考虑了所有输入的输出,一般是一个与输入同维度的向量。

以上就是注意力机制的运作过程。

自注意力机制

注意力机制中的查询、键值对可以来自相同或不同的对象,以中-英NMT为例,假设:

源句为 ‘深度学习’ ,词元化(忽略bos等特殊符号)后为: ‘深’ ‘度’ ‘学’ ‘习’

目标句为 ‘deep learning’ ,词元化后为: ‘deep’ ‘learning’

若以目标句作为查询 q q q 、源句作为键值对 k 、 v k、v kv ,则查询、键值对就是来自不同的对象,可以得到 2×4=8 个注意力权重(目标句2个词元对应2个q、源句4个词元对应4个键值对k,v,每个q与每个k会得到一个注意力权重),该注意力权重蕴含着源句每个词元对翻译结果生成的目标句每个词元的贡献或影响程度信息。这也是transformer中解码器与编码器之间注意力的计算方式。

这里顺便提一下,上面提到注意力评分函数有多种选择,其实也是要看查询 q q q 与键值对 k 、 v k、v kv 的维度是否相同,对于查询、键值对维度相同的,使用缩放点积注意力评分函数简单又高效(因为点积运算要求维度相同);但对于查询、键值对来自不同对象这种情况,很可能存在查询与键值对维度不同的情况,这时就不能简单使用缩放点积注意力评分函数了,而要寻求其他注意力评分函数的帮助,如 加性注意力评分函数

查询与键来自不同对象的计算图如下:


在这里插入图片描述

图3

若源句同时作为查询 q q q 和键值对 k 、 v k、v kv ,则查询、键值对就是来自相同的对象,这就是自注意力!这也是transformer中解码器、编码器内部各自注意力的计算方式。

自注意力不仅是指序列级的,词元也会和其自身做注意力计算,比如 α ( q 1 , k 1 ) \alpha(q_1,k_1) α(q1,k1)

上述源句的自注意力可以得到 4×4=16 个注意力权重(源句4个词元对应4个q与4个键值对k,v)。自注意力的查询 q q q 与键值对 k 、 v k、v kv 的维度是相同的,因此一般常用 缩放点积注意力评分函数

自注意力计算图如下:


在这里插入图片描述

图4

自注意力的细节

本节探讨的虽是自注意力的细节,其实与一般注意力大同小异。

下图是自注意力机制的整个运作过程:从词元向量组成的序列作为注意力输入开始,通过注意力参数矩阵转换得到 q 、 k 、 v q、k、v qkv ,接着使用注意力评分函数得到词元互相间的注意力权重,最后得到注意力汇聚后的输出。


在这里插入图片描述

图5

这里先指定一些对象名称,方便在下面各步表示数据形状,深度学习中一些关键节点计算前后的数据形状是很重要的,需要重点关注。

名称变量名
词嵌入维度d
小批量尺寸batch_size
查询/键值对数量num_steps
注意力参数维度query_size=key_size=value_size=num_hiddens=d
多头注意力头数num_heads

输入词嵌入

首先明确一点,自注意力中的 q 、 k 、 v q、k、v qkv 并不是原始的输入,事实上,原始输入序列经某种词元化方法后拆分成多个词元,再经词嵌入将词元转换成 d d d 维向量表示:


在这里插入图片描述

图6

用代码实现词嵌入。

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import torch
from torch import nn
from torch.nn import functional as F
import math
from matplotlib import pyplot as plt
from matplotlib_inline import backend_inline#词嵌入
X=torch.tensor([0,1,2,3]) #输入序列(词元化后序列中的每个词元用数字索引表示),这里 0:深 1:度 2:学 3:习 ,词元数或步数为num_steps
print('原始输入序列形状:',X.shape) #这里输入序列只有一个样本,因此加上的批量batch_size维度为1
print('原始输入序列:',X)
vocab_size=X.shape[1] #词表大小
d=6 #词嵌入维度embedding = nn.Embedding(vocab_size, d) #词嵌入接口函数
X=embedding(X) #词嵌入后的词元向量表示(上图里的输入I)print('词嵌入后序列形状:',X.shape) #形状:(batch_size,steps,d)
print('词嵌入后序列:',X)

上面代码简要实现了词嵌入过程,索引0-4分别表示 ‘深’ ‘度’ ‘学’ ‘习’ 四个词元,在词嵌入后(注意批量维度为1)每个词元用一个初始化的词向量表示(这些也是可学习的参数),这就是自注意力的输入,本质上就是一个张量,形状为 (batch_size-小批量,num_steps-词元数,d-词嵌入维度)


在这里插入图片描述

图7

查询与键值

直接用词嵌入后的词元向量做 q 、 k 、 v q、k、v qkv 是不能表达出语义相关性的。词嵌入虽能通过学习使得词元自身具有语义,但词元之间的相关性同样也是需要学习的,因此引入注意力参数 W q 、 W k 、 W v W^q、W^k、W^v WqWkWv 来生成 q 、 k 、 v q、k、v qkv ,并通过学习来表示词元之间的相关性。

W q 、 W k 、 W v W^q、W^k、W^v WqWkWv 是序列词元共享的一组参数,它们的形状分别为 (num_hiddens, query_size)、(num_hiddens, key_size)、(num_hiddens, value_size) ,前面说过,在transformer模型架构的自注意力中,输入经注意力操作后的输出维度不变,所以 d=num_hiddens=query_size=key_size=value_size

已知输入序列形状为 (batch_size,num_steps,d) ,所以与 W q 、 W k 、 W v W^q、W^k、W^v WqWkWv 计算后:

q q q 的形状为 (batch_size,num_steps-查询数 ,query_size)

k k k 的形状为 (batch_size,num_steps-键值对数 ,key_size)

v v v 的形状为 (batch_size,num_steps-键值对数 ,value_size)

用nn.Linear生成注意力参数 W q 、 W k 、 W v W^q、W^k、W^v WqWkWv

#注意力参数
query_size, key_size, value_size, num_hiddens = d, d, d, d
W_q = nn.Linear(query_size, num_hiddens, bias=False)
W_k = nn.Linear(key_size, num_hiddens, bias=False)
W_v = nn.Linear(value_size, num_hiddens, bias=False)print('W_q形状:',W_q.weight.shape)
print('W_q:',W_q.weight.T)
print('W_k形状:',W_k.weight.shape)
print('W_k:',W_k.weight.T)
print('W_v形状:',W_v.weight.shape)
print('W_v:',W_v.weight.T)

在这里插入图片描述

图8

有了输入与注意力参数,下面代码生成查询与键值对。

#生成查询与键值对
queries = W_q(X) #查询q
keys = W_k(X) #键k
values = W_v(X) #值v#自注意力中查询、键、值的形状与X相同
print('q的形状:',queries.shape) #q的形状(batch_size,num_steps-查询数,query_size)
print('k的形状:',keys.shape) #k的形状(batch_size,num_steps-键值对数,key_size)
print('v的形状:',values.shape) #v的形状(batch_size,num_steps-键值对数,values_size)queries #查看查询的值

在这里插入图片描述

图9

现在来看看 q q q 的计算细节。以词元 ‘深’ 为例,其查询 q q q X 、 W q X、W^q XWq 矩阵乘法生成结果的其中一行,如下图所示。

可以看出 q q q 的最后一维取决于 W q W^q Wq 中的 query_size k 、 v k、v kv 的计算依此类推。


在这里插入图片描述

图10

注意力评分

对于 q 、 k 、 v q、k、v qkv ,可以使用特定的注意力评分函数计算注意力分数,Transformer的自注意力一般选择缩放点积注意力评分函数,数学公式为:

a ( q , k ) = q k T / d k a(q,k)=qk^T/ \sqrt{d_k} a(q,k)=qkT/dk

其实就是 q q q k k k 的转置这两个向量的点积,点积某种程度上可以表示两个向量的相似度,比如余弦相似度就是两个向量点积再除以它们的模,这里略去了模,影响不大。这里所有词元的 q 、 k q、k qk 将并行做点积运算,缩放点积注意力评分函数的优势就是简单高效。除以 d \sqrt{d} d 是为了减少向量长度对点积方差的影响。

每个词元的 q q q 都要和每个词元的 k k k 做点积得到一个权重,上面已知:

q q q 的形状为:(batch_size, num_steps-查询数, query_size);

k k k 的形状为:(batch_size, num_steps-键值对数, key_size);

因此可以预期最后得到的注意力权重形状为:(batch_size, num_steps–4个查询, num_steps–4个键值对)。

在与查询 q q q 计算前,键 k k k 需要转置一下。

#k的转置
keys.transpose(1,2).shape
keys.transpose(1,2) #转置最后两个维度,转置后形状为(batch_size, key_size, num_steps-键值对数)

在这里插入图片描述

图11

计算注意力评分。

#计算注意力评分
attention_scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
attention_scores.shape
attention_scores

在这里插入图片描述

图12

以词元 ‘深’ 的查询 q q q (下图queries红框里第一行)为例,其与所有词元的键 k k k 的注意力分数的计算如下图所示,这里因为 query_size=key_size,所以可以直接点积。


在这里插入图片描述

图13

注意力权重

用注意力分数经 s o f t m a x softmax softmax 转换即可得到注意力权重,4个查询、4个键,共得到16个注意力权重值。

#注意力权重
attention_weights=nn.functional.softmax(attention_scores, dim=-1)  
attention_weights.shape
attention_weights

在这里插入图片描述

图14

画出注意力权重的热力图,注意当前注意力参数是初始化未经训练的,表现不出有用的相关性。

#注意力热力图
def attention_heatmap(weights):backend_inline.set_matplotlib_formats('svg')rows, cols = weights.shape[0], weights.shape[1]fig, axes = plt.subplots(rows, cols, figsize=(3.0, 3.0), sharex=True, sharey=True, squeeze=False)for i, (row_axes, row_matrices) in enumerate(zip(axes, weights)):for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):pcm = ax.imshow(matrix.detach().numpy(), cmap='Reds')if i == rows - 1:ax.set_xlabel('Keys')if j == 0:ax.set_ylabel('Queries')fig.colorbar(pcm, ax=axes, shrink=0.6)attention_heatmap(attention_weights.reshape((1, 1, 4, 4)))

在这里插入图片描述

图15

掩蔽注意力

对于前文得到的注意力分数,有时可能不是每个分数都应当参与注意力权重的计算。

比如在语言模型中,输入序列的长度不一(不同序列的词元数量可能不同),但为了方便小批量计算,序列会被长切短补到相同长度(相同数量的词元),序列长度不够时会在最后面以特殊词元 pad 填补。

然而这些填补的词元没有实际意义,自然不该在注意力权重中分得一杯羹,为了达到这个目的,需将填补位置的注意力分数替换成一个很大的负值,比如-1e6,使对应位置的注意力权重为0(因为很大的负值经softmax后输出为0),从而实现遮蔽注意力权重的效用。

下图是上文生成的注意力分数,我们现在假设原来的 ‘深’ ‘度’ ‘学’ ‘习’ 四个词元中的 ‘习’ 是用 pad 填充的,这时每个查询(查询将仍是4个,包含补充词元)对应的有实际意义的键都只有三个,因此每个查询只应使用红色虚线框里面的三个值来计算注意力权重,灰色虚线框里的值在经softmax前需替换成-1e6。


在这里插入图片描述

图16

下面用代码实现掩蔽注意力。

valid_lens=torch.tensor([3,3,3,3])
attention_shape=attention_scores.shape
attention_scores=attention_scores.reshape(-1, attention_shape[-1])
mask = torch.arange((attention_shape[1]), dtype=torch.float32)[None, :] < valid_lens[:, None] #要遮蔽的位置
attention_scores[~mask] = -1e6 #遮蔽位置替换为-1e6print('遮蔽后的注意力分数:\n')
attention_scores
print('遮蔽后的注意力权重:\n')
attention_weights=nn.functional.softmax(attention_scores.reshape(attention_shape), dim=-1)
attention_weights.shape
attention_weights

下图可以看出,查询 q q q 与填补的键 k k k 间的注意力权重为0。


在这里插入图片描述

图17

注意力汇聚

注意力汇聚就是注意力权重对值 v v v 的加权平均,与文章开头所举例子里计算票数的方式一致。已知:

注意力权重形状:(batch_size, num_steps-查询数, num_steps-键值对数);

v v v 形状:(batch_size, num_steps-键值对数, value_size)。

print('注意力权重:\n')
attention_weights.shape
attention_weights
print('值v:\n')
values.shape
valuesprint('输出:\n')
output=torch.bmm(attention_weights, values)
output.shape
output

在这里插入图片描述

图18

计算细节见下图:


在这里插入图片描述

图19

汇聚后输出 I ′ I' I 的形状为:(batch_size, num_steps-查询数, value_size),最后的维度是值 v a l u e s values values 中的 value_size

自注意力初始输入 I I I 的形状是(batch_size,num_steps,d),因为 value_size=d ,因此初始输入经过自注意力计算后的输出的形状不变(词元也是一一对应),这样设计的初衷是为了方便计算。比如在transformer中,自注意力存在于编码器与解码器中,通常transformer模型里会堆叠多个编码器/解码器块,输入输出形状保持一致可以使上一层的输出继续作为下一层的输入,使得模型的体量可以灵活控制。

多头注意力

上述可以说只是单头注意力的计算过程,因为只生成了一组 q 、 k 、 v q、k、v qkv,查询 q q q 从一组键 k k k 上仅能抽取一种相关性。这也许并不太够,我们希望能抽取更多的查询与键的相关性以使注意力包含更丰富的语义信息,有点类似卷积神经网络中的卷积核,不同卷积核能识别不同的模式。

既然一组 q 、 k 、 v q、k、v qkv 能找到一种相关性,如果有多组不同的 q 、 k 、 v q、k、v qkv ,自然就可以找到多种相关性,这就是多头注意力机制。

那么不同的 q 、 k 、 v q、k、v qkv 如何得到?简单来说,对同一个输入 I I I ,若想得到不同的 q 、 k 、 v q、k、v qkv ,只需使用多组注意力参数 W q 、 W k 、 W v W^q、W^k、W^v WqWkWv 即可。

下图演示了具有2个头(num_heads)的多头注意力计算过程,与上面不同点在于2个头分别对应一组注意力参数(不同颜色标示),除了注意力参数不同外,其他的计算过程和单头注意力完全相同,只是对于每个头最后的输出 I ′ I' I ,需要将它们连结起来作为最终输出。注意连结后输出的维度要和开始时的输入 I I I 一样是 d d d ,因此在此例中,每个头自身输出的维度都是 d 2 d \over 2 2d ,事实上,如果有 n n n 个头,则每个头输出的维度将是 d n d \over n nd


在这里插入图片描述

图20

上图是为了方便理解多头注意力而画的计算过程,实际上从效率方面考虑,并不如此,为了并行计算,会将多个头使用的多组注意力参数 $W^q、W^k、W^v$ 拼接起来,只做一次计算,如下图所示:

在这里插入图片描述

图21

来看下具体的计算过程。输入 I I I 是不变的,注意力参数 W q 、 W k 、 W v W^q、W^k、W^v WqWkWv 也保持不变,只是它们现在是由 W 1 q 、 W 1 k 、 W 1 v W_1^q、W_1^k、W_1^v W1qW1kW1v W 2 q 、 W 2 k 、 W 2 v W_2^q、W_2^k、W_2^v W2qW2kW2v 这两组注意力参数拼接而成。


在这里插入图片描述

图22


q u e r i e s 、 k e y s 、 v a l u e s queries、keys、values querieskeysvalues 的计算也与之前一样,只是为了方便多头注意力的并行计算,需要转换一下形状。

print('queries初始形状',queries.shape) #(batch_size, num_steps, query_size)
queriesnum_heads=2
#q形状转换
queries=queries.reshape(queries.shape[0], queries.shape[1], num_heads, -1)
queries=queries.permute(0, 2, 1, 3)
queries=queries.reshape(-1, queries.shape[2], queries.shape[3])
#k形状转换
keys=keys.reshape(keys.shape[0], keys.shape[1], num_heads, -1)
keys=keys.permute(0, 2, 1, 3)
keys=keys.reshape(-1, keys.shape[2], keys.shape[3])
#v形状转换
values=values.reshape(values.shape[0], values.shape[1], num_heads, -1)
values=values.permute(0, 2, 1, 3)
values=values.reshape(-1, values.shape[2], values.shape[3])print('queries被reshape后的形状',queries.shape) #(batch_size*num_heads, num_steps, query_size/num_heads)
queriesprint('keys被reshape后的形状',keys.shape) #(batch_size*num_heads, num_steps, query_size/num_heads)
keys

下图展示了形状转换过程:


在这里插入图片描述

图23

剩余部分的计算与上面单头完全一样,只是最后每个头都有一组注意力权重,且需要连结起来。连结后的多头注意力结果与单头的结果并不相同,因为在计算注意力分数时,多个头的 q 、 k q、k qk 的维度已经被多头均分,所以得到的注意力分数已经不同了,不过本身多头注意力就是为了抽取 q 、 k q、k qk 之间的不同相关性。

#计算注意力评分
attention_scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)#注意力权重
attention_weights=nn.functional.softmax(attention_scores, dim=-1)  #输出
output=torch.bmm(attention_weights, values)#多头注意力连结
output = output.reshape(-1, num_heads, output.shape[1], output.shape[2])
output = output.permute(0, 2, 1, 3)
output = output.reshape(output.shape[0], output.shape[1], -1)
output.shape
outputattention_heatmap(attention_weights.reshape((1, 2, 4, 4)))

在这里插入图片描述

图24

查看两组注意力权重的热力图,分布并不相同。


在这里插入图片描述

图25

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • redis群集的三种模式
  • Jenkins 通过 Version Number Plugin 自动生成和管理构建的版本号
  • crm如何做私域运营?
  • Harmony Next 文件命令操作(发送、读取、媒体文件查询)
  • 【Hot100】LeetCode—215. 数组中的第K个最大元素
  • Qt常用控件——QLineEdit
  • uts+uniapp踩坑记录(vue3项目
  • 美团面试题:生成字符串的不同方式
  • 期权有哪些开户免50万元验资的方法?怎么操作?
  • 《C++位域:在复杂数据结构中的精准驾驭与风险规避》
  • uniapp微信小程序开发踩坑日记:Pinia持久化报错Cannot read property ‘localStorage‘ of undefined
  • map与set
  • 基于SpringBoot的医院挂号预约管理系统
  • vulnhub靶机:Holynix: v1
  • Capital许可管理最佳实践
  • IE9 : DOM Exception: INVALID_CHARACTER_ERR (5)
  • @jsonView过滤属性
  • Android交互
  • ES10 特性的完整指南
  • IndexedDB
  • JavaScript 一些 DOM 的知识点
  • Just for fun——迅速写完快速排序
  • ReactNativeweexDeviceOne对比
  • session共享问题解决方案
  • Spring技术内幕笔记(2):Spring MVC 与 Web
  • Spring思维导图,让Spring不再难懂(mvc篇)
  • Vue 2.3、2.4 知识点小结
  • vue-router的history模式发布配置
  • 观察者模式实现非直接耦合
  • 后端_ThinkPHP5
  • 聊聊spring cloud的LoadBalancerAutoConfiguration
  • 模型微调
  • 前端攻城师
  • 我这样减少了26.5M Java内存!
  • 以太坊客户端Geth命令参数详解
  • 中文输入法与React文本输入框的问题与解决方案
  • 最近的计划
  • ​Redis 实现计数器和限速器的
  • ​虚拟化系列介绍(十)
  • #QT(串口助手-界面)
  • #ubuntu# #git# repository git config --global --add safe.directory
  • (13):Silverlight 2 数据与通信之WebRequest
  • (1综述)从零开始的嵌入式图像图像处理(PI+QT+OpenCV)实战演练
  • (CVPRW,2024)可学习的提示:遥感领域小样本语义分割
  • (delphi11最新学习资料) Object Pascal 学习笔记---第7章第3节(封装和窗体)
  • (附源码)springboot美食分享系统 毕业设计 612231
  • (三)centos7案例实战—vmware虚拟机硬盘挂载与卸载
  • (学习总结16)C++模版2
  • (一)C语言之入门:使用Visual Studio Community 2022运行hello world
  • (转)程序员疫苗:代码注入
  • .aanva
  • .NET 8.0 中有哪些新的变化?
  • .NET CF命令行调试器MDbg入门(二) 设备模拟器
  • .NET CORE Aws S3 使用
  • .Net Core 微服务之Consul(二)-集群搭建