当前位置: 首页 > news >正文

生成模型的中Attention Mask说明

生成模型中的Attention Mask说明

最近在做文本生成任务,例如诗歌生成,问题生成,摘要生成等,使用了Bart模型,CPT模型,mt5模型,t5模型等。生成模型是基于Seq-to-Seq(Encoder-Decoder)结构,输入的文本经过Encoder编码得到向量再输入到Decoder解码生成一个文本。Encoder和Decoder会使用多层transformer 以及self-attention,需要注意Encoder和Decoder中的attention mask使用,在Decoder中的self-attention当前时刻t的词只能关注到时刻0到t-1的词,无法关注到时刻t+1的词,是因为使用了在计算self-attention的时候加入了矩阵MASK M M M的右上角被mask为 − ∞ -\infty ,从而使得Decoder进行更好的生成。Decoder中的MASK M M M矩阵是由UNIfied pre-trained Language Model(UniLM)提出来的,对应于UniLM中的Seq-to-Seq LM。下面介绍Seq-to-Seq LM原理以及在bart中的使用。

Seq-to-Seq LM 介绍

Seq-to-Seq LM是UniLM提出来的,其中UniLM中有三种LM方式分别为Unidirectional LM,Bidirectional LM和Sequence-to-Sequence LM。如下图1所示。Unidirectional LM包括left-to-right and right-to-left LM,例如left-to-right LM当前词计算attention的时候只能使用当前词前面的词,后面的词无法使用到。
unilm mask

在 Seq-to-Seq LM中的self-attention计算中加入了一个MASK 矩阵 M M M,这个MASK 矩阵 M M M的右上角的元素是 − ∞ -\infty ,左下角的元素为0。
1. Multi-Layer Transformer
输入向量 { x i } i = 1 ∣ x ∣ \{x_{i}\}_{i=1}^{|x|} {xi}i=1x,第0层的transformer的输出状态 H 0 H^{0} H0记为 H 0 = [ x 1 , x 2 , … , x ∣ x ∣ ] H^{0} = [x_{1}, x_{2}, \dots, x_{|x|}] H0=[x1,x2,,xx] L L L层的Transformer的结果记为 H l = T r a n s f o r m e r l H l − 1 H^{l} = Transformer_{l} H^{l-1} Hl=TransformerlHl1 l ∈ [ 1 , L ] l\in[1, L] l[1,L] l l l层的self-attention计算过程如下:
Q = H l − 1 W l Q K = H l − 1 W l K V = H l − 1 W l V M i j = { 0 , a l l o w t o a t t e n d − ∞ , p r e v e n t f r o m a t t e n d i n g A l = s o f t m a x ( Q K T d k + M ) V Q = H^{l -1}W_{l}^{Q} \\ K = H^{l -1}W_{l}^{K} \\ V = H^{l -1}W_{l}^{V} \\ M_{ij}=\left\{ \begin{aligned} 0, & & allow to attend \\ -\infty, & & prevent from attending \end{aligned} \right. \\ A_{l} = softmax(\frac{QK^{T}}{\sqrt{d_{k}}} + M)V Q=Hl1WlQK=Hl1WlKV=Hl1WlVMij={0,,allowtoattendpreventfromattendingAl=softmax(dk QKT+M)V
其中 H l − 1 ∈ R ∣ x ∣ × d h H^{l-1}\in R^{|x|\times d_{h}} Hl1Rx×dh W l Q , W l K , W l V ∈ R d h × d k W_{l}^{Q},W_{l}^{K},W_{l}^{V}\in R^{d_{h}\times d_{k}} WlQ,WlK,WlVRdh×dk M ∈ R ∣ x ∣ × ∣ x ∣ M\in R^{|x|\times|x|} MRx×x
已UniLM的输入句子 S 1 S_{1} S1为[SOS] t1 t2 [EOS], 输出的句子 S 2 S_{2} S2为t3 t4 t5 [EOS],将 S 1 S_{1} S1 S 2 S_{2} S2拼接后得到句子[SOS] t1 t2 [EOS] t3 t4 t5 [EOS]输入模型,MASK M M M如下:
在这里插入图片描述
上图中矩阵MASK M M M右上角阴部部分的元素均为 − ∞ -\infty ,在self-attention中的softmax的时候加上MASK M M M使得self-attention关注不到 S 2 S_{2} S2的句子,使得模型有更好的生成能力。下面介绍MASK M M M 在bart中的使用。

Bart中Encoder和Decoder中的MASK M M M介绍

1. Bart中的self-attention
Bart中的self-attention计算和上面介绍的UniLM中的attention计算方式一样,在attention计算softmax的时候加入了MASK M M M矩阵,在代码中用attention mask代替,代码如下:
在这里插入图片描述

2. 在Bart中的Encoder中的MASK M M M
在Encoder中的MASK M M M 是根据attention mask 计算得到的,attention mask大小为 [ b a t c h s i z e , s e q l e n g t h ] [batch_size, seq_length] [batchsize,seqlength]有1和0构成。在encoder中的attention mask矩阵进行了扩维变为 [ b a t c h s i z e , 1 , s e q l e n g t h , s e q l e n g t h ] [batch_size, 1, seq_length, seq_length] [batchsize,1,seqlength,seqlength],把attention mask中的元素为0的变为一个很大的负数,元素原来是1的位置处的元素为0,避免attention mask中元素为0的位置处的padding token对句子间的词之间的相关联程度的影响,相关代码如下:

def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)

结果如下:
在这里插入图片描述
3. Bart中的Decoder中的MASK M M M
在Encoder中的MASK M M M 是根据attention mask 计算得到的,attention mask大小为 [ b a t c h s i z e , s e q l e n g t h ] [batch_size, seq_length] [batchsize,seqlength]有1和0构成。在encoder中的attention mask矩阵进行了扩维变为 [ b a t c h s i z e , 1 , s e q l e n g t h , s e q l e n g t h ] [batch_size, 1, seq_length, seq_length] [batchsize,1,seqlength,seqlength],在把attention mask中的元素为0设置为 ∞ \infty 或者为一个负无穷大的数,相关代码如下

def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), float("-inf"))
    mask_cond = torch.arange(mask.size(-1))
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
       # create causal mask
       # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
       combined_attention_mask = None
       if input_shape[-1] > 1:
           combined_attention_mask = _make_causal_mask(
               input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
           ).to(self.device)

       if attention_mask is not None:
           # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
           expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
           combined_attention_mask = (
               expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
           )

       return combined_attention_mask

decoder中 MASK M M M 结果如下:
在这里插入图片描述
在生成模型中decoder中的self-attention计算加入MASK M M M矩阵,使得模型具有更好的生成能力。如有错误,欢迎指证。

相关文章:

  • java毕业设计企业固定资产管理系统源码+lw文档+mybatis+系统+mysql数据库+调试
  • Java---Java Web---JSP
  • opencv 机器学习-人脸识别
  • JavaScript的函数
  • java基于springboot+vue基本微信小程序的乒乓球课程管理系统 uniapp小程序
  • 安装数据库中间件——Mycat
  • 爬虫之Scrapy框架
  • 哈工大李治军老师操作系统笔记【23】:内存换出(Learning OS Concepts By Coding Them !)
  • Ubuntu 20.04 设置开机自启脚本
  • Vue2封装评论组件详细讲解
  • java-php-python-springboot校园新闻趣事计算机毕业设计
  • 使用Docker Compose搭建WordPress博客
  • 【Linux篇】第十一篇——动静态库(动静态库的介绍+动静态库的打包与使用)
  • 多任务学习(MTL)--学习笔记
  • 前端性能优化方法与实战01 体系总览:性能优化体系及关键指标设定
  • 【JavaScript】通过闭包创建具有私有属性的实例对象
  • 3.7、@ResponseBody 和 @RestController
  • Codepen 每日精选(2018-3-25)
  • Django 博客开发教程 8 - 博客文章详情页
  • Java,console输出实时的转向GUI textbox
  • mysql innodb 索引使用指南
  • Travix是如何部署应用程序到Kubernetes上的
  • Vue 2.3、2.4 知识点小结
  • 聚类分析——Kmeans
  • 理解 C# 泛型接口中的协变与逆变(抗变)
  • 聊聊hikari连接池的leakDetectionThreshold
  • 使用agvtool更改app version/build
  • 视频flv转mp4最快的几种方法(就是不用格式工厂)
  • 探索 JS 中的模块化
  • 异步
  • TPG领衔财团投资轻奢珠宝品牌APM Monaco
  • ​LeetCode解法汇总1276. 不浪费原料的汉堡制作方案
  • !$boo在php中什么意思,php前戏
  • #pragma once
  • #我与Java虚拟机的故事#连载19:等我技术变强了,我会去看你的 ​
  • (C语言)求出1,2,5三个数不同个数组合为100的组合个数
  • (C语言)字符分类函数
  • (zhuan) 一些RL的文献(及笔记)
  • (动态规划)5. 最长回文子串 java解决
  • (附源码)ssm考生评分系统 毕业设计 071114
  • (附源码)计算机毕业设计SSM疫情居家隔离服务系统
  • (论文阅读40-45)图像描述1
  • (五)Python 垃圾回收机制
  • (原創) 如何讓IE7按第二次Ctrl + Tab時,回到原來的索引標籤? (Web) (IE) (OS) (Windows)...
  • (转)shell调试方法
  • ****Linux下Mysql的安装和配置
  • ***测试-HTTP方法
  • .Net CF下精确的计时器
  • .NET Core 2.1路线图
  • .Net Core/.Net6/.Net8 ,启动配置/Program.cs 配置
  • .net websocket 获取http登录的用户_如何解密浏览器的登录密码?获取浏览器内用户信息?...
  • .NET 使用 JustAssembly 比较两个不同版本程序集的 API 变化
  • .Net程序猿乐Android发展---(10)框架布局FrameLayout
  • .Net高阶异常处理第二篇~~ dump进阶之MiniDumpWriter
  • .NET开发不可不知、不可不用的辅助类(一)