transformer中的build_attention_mask
build_attention_mask
方法的作用是构建一个因果注意力掩码,用于屏蔽 Transformer 模型中的未来位置。
因果注意力掩码的工作原理
因果注意力掩码通过将未来位置的注意力权重设置为负无穷大,从而确保这些位置的注意力得分在 softmax 计算中接近于零。具体来说,这个掩码矩阵是一个上三角矩阵,其中上三角部分(不包括对角线)被设置为负无穷大。这样,当计算第 𝑖个位置的注意力分数时,只会考虑位置
0 到 𝑖的内容,而忽略位置 𝑖+1及之后的位置,这对于生成任务(如语言模型)非常重要。下面是这个方法的实现和详细解释:
详细解释
-
掩码矩阵的构建:
- 掩码矩阵
mask
的形状为 ([context_length, context_length])。 mask.fill_(float("-inf"))
将矩阵的所有元素初始化为负无穷大。mask.triu_(1)
将上三角部分(不包括对角线)设置为零。
- 掩码矩阵
-
应用掩码:
- 在计算注意力分数时,这个掩码会被添加到注意力分数矩阵中。
- 由于被掩盖的部分被设置为负无穷大,它们在 softmax 计算中会得到接近零的权重,从而 effectively 被忽略。
示例
假设 context_length
为 5,那么生成的掩码 mask
将如下所示:
tensor([[ 0., -inf, -inf, -inf, -inf],[ 0., 0., -inf, -inf, -inf],[ 0., 0., 0., -inf, -inf],[ 0., 0., 0., 0., -inf],[ 0., 0., 0., 0., 0.]])
计算注意力分数时的效果
在计算注意力分数时,假设我们有以下示例:
- 输入序列:[x_0, x_1, x_2, x_3, x_4]
- 注意力权重矩阵形状:[context_length, context_length]
计算第 ( i ) 个位置的注意力分数时,将使用掩码矩阵对注意力权重进行修正:
未加掩码时的注意力权重矩阵:
[[a00, a01, a02, a03, a04],[a10, a11, a12, a13, a14],[a20, a21, a22, a23, a24],[a30, a31, a32, a33, a34],[a40, a41, a42, a43, a44]]加上掩码后的注意力权重矩阵:
[[a00, -inf, -inf, -inf, -inf],[a10, a11, -inf, -inf, -inf],[a20, a21, a22, -inf, -inf],[a30, a31, a32, a33, -inf],[a40, a41, a42, a43, a44]]
为什么可以确保位置 ( i ) 只能关注位置 ( 0 ) 到 ( i )
- 对于位置 ( 0 ),只有 ( a00 ) 会被保留,其他的都被设置为负无穷大。
- 对于位置 ( 1 ),只有 ( a10 ) 和 ( a11 ) 会被保留,其他的都被设置为负无穷大。
- 对于位置 ( 2 ),只有 ( a20 ), ( a21 ), ( a22 ) 会被保留,其他的都被设置为负无穷大。
- 对于位置 ( 3 ),只有 ( a30 ), ( a31 ), ( a32 ), ( a33 ) 会被保留,其他的都被设置为负无穷大。
- 对于位置 ( 4 ),所有的 ( a40, a41, a42, a43, a44 ) 都会被保留,因为它已经是最后一个位置。
这种掩码方式确保了模型在生成第 ( i ) 个位置的输出时,不会看到第 ( i+1 ) 及之后的位置的输入。
代码示例
以下是一个简单的代码示例,展示了如何使用 build_attention_mask
生成掩码并应用到注意力机制中:
import torchclass ExampleModel:def __init__(self, context_length):self.context_length = context_lengthdef build_attention_mask(self):mask = torch.empty(self.context_length, self.context_length)mask.fill_(float("-inf"))mask.triu_(1)return mask# 假设 context_length 为 5
context_length = 5
model = ExampleModel(context_length)
attention_mask = model.build_attention_mask()print(attention_mask)# 示例的注意力权重矩阵
attention_weights = torch.randn(context_length, context_length)# 加上掩码后的注意力权重矩阵
masked_attention_weights = attention_weights + attention_mask
print(masked_attention_weights)
总结
build_attention_mask
方法通过生成一个上三角掩码矩阵,确保了每个位置 ( i ) 只能关注位置 ( 0 ) 到 ( i ),而不能关注位置 ( i+1 ) 及之后的位置。这个机制通过在注意力分数计算中设置负无穷大,使得这些位置在 softmax 计算中得到接近零的权重,从而 effectively 被忽略。这对于生成任务(如语言模型)非常重要,确保模型在生成时只依赖已生成的部分,而不会看到未来的输入。