生成模型的中Attention Mask说明
生成模型中的Attention Mask说明
最近在做文本生成任务,例如诗歌生成,问题生成,摘要生成等,使用了Bart模型,CPT模型,mt5模型,t5模型等。生成模型是基于Seq-to-Seq(Encoder-Decoder)结构,输入的文本经过Encoder编码得到向量再输入到Decoder解码生成一个文本。Encoder和Decoder会使用多层transformer 以及self-attention,需要注意Encoder和Decoder中的attention mask使用,在Decoder中的self-attention当前时刻t的词只能关注到时刻0到t-1的词,无法关注到时刻t+1的词,是因为使用了在计算self-attention的时候加入了矩阵MASK M M M的右上角被mask为 − ∞ -\infty −∞,从而使得Decoder进行更好的生成。Decoder中的MASK M M M矩阵是由UNIfied pre-trained Language Model(UniLM)提出来的,对应于UniLM中的Seq-to-Seq LM。下面介绍Seq-to-Seq LM原理以及在bart中的使用。
Seq-to-Seq LM 介绍
Seq-to-Seq LM是UniLM提出来的,其中UniLM中有三种LM方式分别为Unidirectional LM,Bidirectional LM和Sequence-to-Sequence LM。如下图1所示。Unidirectional LM包括left-to-right and right-to-left LM,例如left-to-right LM当前词计算attention的时候只能使用当前词前面的词,后面的词无法使用到。
在 Seq-to-Seq LM中的self-attention计算中加入了一个MASK 矩阵
M
M
M,这个MASK 矩阵
M
M
M的右上角的元素是
−
∞
-\infty
−∞,左下角的元素为0。
1. Multi-Layer Transformer
输入向量
{
x
i
}
i
=
1
∣
x
∣
\{x_{i}\}_{i=1}^{|x|}
{xi}i=1∣x∣,第0层的transformer的输出状态
H
0
H^{0}
H0记为
H
0
=
[
x
1
,
x
2
,
…
,
x
∣
x
∣
]
H^{0} = [x_{1}, x_{2}, \dots, x_{|x|}]
H0=[x1,x2,…,x∣x∣]。
L
L
L层的Transformer的结果记为
H
l
=
T
r
a
n
s
f
o
r
m
e
r
l
H
l
−
1
H^{l} = Transformer_{l} H^{l-1}
Hl=TransformerlHl−1,
l
∈
[
1
,
L
]
l\in[1, L]
l∈[1,L]。
l
l
l层的self-attention计算过程如下:
Q
=
H
l
−
1
W
l
Q
K
=
H
l
−
1
W
l
K
V
=
H
l
−
1
W
l
V
M
i
j
=
{
0
,
a
l
l
o
w
t
o
a
t
t
e
n
d
−
∞
,
p
r
e
v
e
n
t
f
r
o
m
a
t
t
e
n
d
i
n
g
A
l
=
s
o
f
t
m
a
x
(
Q
K
T
d
k
+
M
)
V
Q = H^{l -1}W_{l}^{Q} \\ K = H^{l -1}W_{l}^{K} \\ V = H^{l -1}W_{l}^{V} \\ M_{ij}=\left\{ \begin{aligned} 0, & & allow to attend \\ -\infty, & & prevent from attending \end{aligned} \right. \\ A_{l} = softmax(\frac{QK^{T}}{\sqrt{d_{k}}} + M)V
Q=Hl−1WlQK=Hl−1WlKV=Hl−1WlVMij={0,−∞,allowtoattendpreventfromattendingAl=softmax(dkQKT+M)V
其中
H
l
−
1
∈
R
∣
x
∣
×
d
h
H^{l-1}\in R^{|x|\times d_{h}}
Hl−1∈R∣x∣×dh,
W
l
Q
,
W
l
K
,
W
l
V
∈
R
d
h
×
d
k
W_{l}^{Q},W_{l}^{K},W_{l}^{V}\in R^{d_{h}\times d_{k}}
WlQ,WlK,WlV∈Rdh×dk,
M
∈
R
∣
x
∣
×
∣
x
∣
M\in R^{|x|\times|x|}
M∈R∣x∣×∣x∣
已UniLM的输入句子
S
1
S_{1}
S1为[SOS] t1 t2 [EOS], 输出的句子
S
2
S_{2}
S2为t3 t4 t5 [EOS],将
S
1
S_{1}
S1和
S
2
S_{2}
S2拼接后得到句子[SOS] t1 t2 [EOS] t3 t4 t5 [EOS]输入模型,MASK
M
M
M如下:
上图中矩阵MASK
M
M
M右上角阴部部分的元素均为
−
∞
-\infty
−∞,在self-attention中的softmax的时候加上MASK
M
M
M使得self-attention关注不到
S
2
S_{2}
S2的句子,使得模型有更好的生成能力。下面介绍MASK
M
M
M 在bart中的使用。
Bart中Encoder和Decoder中的MASK M M M介绍
1. Bart中的self-attention
Bart中的self-attention计算和上面介绍的UniLM中的attention计算方式一样,在attention计算softmax的时候加入了MASK
M
M
M矩阵,在代码中用attention mask代替,代码如下:
2. 在Bart中的Encoder中的MASK
M
M
M
在Encoder中的MASK
M
M
M 是根据attention mask 计算得到的,attention mask大小为
[
b
a
t
c
h
s
i
z
e
,
s
e
q
l
e
n
g
t
h
]
[batch_size, seq_length]
[batchsize,seqlength]有1和0构成。在encoder中的attention mask矩阵进行了扩维变为
[
b
a
t
c
h
s
i
z
e
,
1
,
s
e
q
l
e
n
g
t
h
,
s
e
q
l
e
n
g
t
h
]
[batch_size, 1, seq_length, seq_length]
[batchsize,1,seqlength,seqlength],把attention mask中的元素为0的变为一个很大的负数,元素原来是1的位置处的元素为0,避免attention mask中元素为0的位置处的padding token对句子间的词之间的相关联程度的影响,相关代码如下:
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
结果如下:
3. Bart中的Decoder中的MASK
M
M
M
在Encoder中的MASK
M
M
M 是根据attention mask 计算得到的,attention mask大小为
[
b
a
t
c
h
s
i
z
e
,
s
e
q
l
e
n
g
t
h
]
[batch_size, seq_length]
[batchsize,seqlength]有1和0构成。在encoder中的attention mask矩阵进行了扩维变为
[
b
a
t
c
h
s
i
z
e
,
1
,
s
e
q
l
e
n
g
t
h
,
s
e
q
l
e
n
g
t
h
]
[batch_size, 1, seq_length, seq_length]
[batchsize,1,seqlength,seqlength],在把attention mask中的元素为0设置为
∞
\infty
∞或者为一个负无穷大的数,相关代码如下
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf"))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
decoder中 MASK
M
M
M 结果如下:
在生成模型中decoder中的self-attention计算加入MASK
M
M
M矩阵,使得模型具有更好的生成能力。如有错误,欢迎指证。