当前位置: 首页 > news >正文

机器翻译之多头注意力(MultiAttentionn)在Seq2Seq的应用

目录

1.多头注意力(MultiAttentionn)的理念图

2.代码实现 

2.1创建多头注意力函数 

2.2验证上述封装的代码 

2.3 创建 添加了Bahdanau的decoder 

 2.4训练

 2.5预测

3.知识点个人理解 


 

1.多头注意力(MultiAttentionn)的理念图

2.代码实现 

2.1创建多头注意力函数 

class MultiHeadAttention(nn.Module):#初始化属性和方法def __init__(self, query_size, key_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):"""query_size_size: query_size的特征数featureskey_size: key_size的特征数featuresvalue_size: value_size的特征数featuresnum_hiddens:隐藏层的神经元的数量num_heads:多头注意力的header的数量dropout: 释放模型需要计算的参数的比例bias=False:没有偏差**kwargs : 不定长度的关键字参数"""super().__init__(**kwargs)#接收参数self.num_heads = num_heads#初始化注意力,    #使用DotProductAttention时, keys与 values具有相同的长度, 经过decoder,他们长度相同self.attention = dltools.DotProductAttention(dropout)#初始化四个w模型参数self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):def transpose_qkv(X, num_heads):"""实现queries, keys, values的数据维度转化"""#输入的X的shape=(batch_size, 查询数/键值对数量, num_hiddens)#这里,不能直接用reshape,需要索引维度,防止数据不能一一对应X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)   #将原维度的num_hiddens拆分成num_heads, -1,  -1相当于num_hiddens/num_heads的数值X = X.permute(0, 2, 1, 3)  #X的shape=(batch_size, num_size, 查询数/键值对数量, num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])  #X的shape=(batch_size*num_heads, 查询数/键值对数量, num_hiddens/num_heads)def transpose_outputs(X, num_heads):"""逆转transpose_qkv的操作"""#此时数据的X的shape =(batch_size*num_heads, 查询数/键值对数量, num_hiddens/num_heads)X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])  #X的shape=(batch_size, num_heads, 查询数/键值对数量, num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)  #X的shape=(batch_size, 查询数/键值对数量, num_heads,  num_hiddens/num_heads)return X.reshape(X.shape[0], X.shape[1], -1)  #X的shape还原了=(batch_size, 查询数/键值对数, num_hiddens)#queries, keys, values,传入的shape=(batch_size, 查询数/键值对数, num_hiddens)#获取转换维度之后的queries, keys, values,queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)#若valid_len不为空,存在if valid_lens is not None:#将valid_lens重复数据self.num_heads次,在0维度上valid_lens = torch.repeat_interleave(valid_lens, repeats = self.num_heads, dim=0)#若为空,什么都不做,跳出if判断,继续执行其他代码#通过注意力函数获取输出outputs#outputs的shape = (batch_size*num_heads, 查询的个数, num_hiddens/num_heads)outputs = self.attention(queries, keys, values, valid_lens)#逆转outputs的维度outputs_concat = transpose_outputs(outputs, self.num_heads)return self.W_o(outputs_concat)

2.2验证上述封装的代码 

#假设变量
num_hiddens, num_heads, dropout = 100, 5, 0.2
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, dropout)
attention.eval()  #需要预测,加上
MultiHeadAttention((attention): DotProductAttention((dropout): Dropout(p=0.2, inplace=False))(W_q): Linear(in_features=100, out_features=100, bias=False)(W_k): Linear(in_features=100, out_features=100, bias=False)(W_v): Linear(in_features=100, out_features=100, bias=False)(W_o): Linear(in_features=100, out_features=100, bias=False)
)
#假设变量
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])X = torch.ones((batch_size, num_queries, num_hiddens))  #shape(2,4,100)
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))  #shape(2,6,100) attention(X, Y, Y, valid_lens).shape

torch.Size([2, 4, 100])

2.3 创建 添加了Bahdanau的decoder 

# 添加Bahdanau的decoder
class Seq2SeqMultiHeadAttentionDecoder(dltools.AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_heads, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)self.attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):# outputs : (batch_size, num_steps, num_hiddens)# hidden_state: (num_layers, batch_size, num_hiddens)outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):# enc_outputs (batch_size, num_steps, num_hiddens)# hidden_state: (num_layers, batch_size, num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state# X : (batch_size, num_steps, vocab_size)X = self.embedding(X) # X : (batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)outputs, self._attention_weights = [], []for x in X:query = torch.unsqueeze(hidden_state[-1], dim=1) # batch_size, 1, num_hiddenscontext = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)outputs.append(out)self._attention_weights.append(self.attention_weights)outputs = self.dense(torch.cat(outputs, dim=0))return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights

 2.4训练

# 训练
embed_size, num_hiddens, num_layers, dropout = 32, 100, 2, 0.1
batch_size, num_steps, num_heads = 64, 10, 5
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)decoder = Seq2SeqMultiHeadAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_heads, num_layers, dropout)net = dltools.EncoderDecoder(encoder, decoder)dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

 2.5预测

engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')

go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('trouvez tom .', []), bleu 0.000
i'm home . => ('je suis chez moi .', []), bleu 1.000

3.知识点个人理解 

 

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 音视频入门基础:AAC专题(4)——ADTS格式的AAC裸流实例分析
  • 大健康裂变分销小程序开发
  • linux 之0号进程、1号进程、2号进程
  • 9月18日
  • 前端univer创建、编辑excel
  • Docker 以外置数据库方式部署禅道
  • .config、Kconfig、***_defconfig之间的关系和工作原理
  • Hadoop里面MapReduce的序列化与Java序列化比较
  • Java知识点小结3:内存回收
  • 关于c#中异步async和await的理解
  • PyTorch 图像分割模型教程
  • csdn漏洞测试
  • 大数据处理技术:HBase的安装与基本操作
  • 二级C语言2023-9易错题
  • 数据结构与算法-Trie树添加与搜索
  • (ckeditor+ckfinder用法)Jquery,js获取ckeditor值
  • [数据结构]链表的实现在PHP中
  • 【Leetcode】104. 二叉树的最大深度
  • 4. 路由到控制器 - Laravel从零开始教程
  • axios请求、和返回数据拦截,统一请求报错提示_012
  • Hibernate【inverse和cascade属性】知识要点
  • Java程序员幽默爆笑锦集
  • MYSQL 的 IF 函数
  • PHP 的 SAPI 是个什么东西
  • SpiderData 2019年2月23日 DApp数据排行榜
  • Unix命令
  • VuePress 静态网站生成
  • vuex 笔记整理
  • 和 || 运算
  • 基于 Babel 的 npm 包最小化设置
  • 力扣(LeetCode)965
  • 聊聊sentinel的DegradeSlot
  • 普通函数和构造函数的区别
  • 前端路由实现-history
  • 如何使用 JavaScript 解析 URL
  • 如何学习JavaEE,项目又该如何做?
  • 数据科学 第 3 章 11 字符串处理
  • 小程序01:wepy框架整合iview webapp UI
  • 延迟脚本的方式
  • # Panda3d 碰撞检测系统介绍
  • #APPINVENTOR学习记录
  • $.ajax()参数及用法
  • (20)目标检测算法之YOLOv5计算预选框、详解anchor计算
  • (Matlab)遗传算法优化的BP神经网络实现回归预测
  • (笔记)M1使用hombrew安装qemu
  • (二)斐波那契Fabonacci函数
  • (非本人原创)史记·柴静列传(r4笔记第65天)
  • (九十四)函数和二维数组
  • (三维重建学习)已有位姿放入colmap和3D Gaussian Splatting训练
  • (实测可用)(3)Git的使用——RT Thread Stdio添加的软件包,github与gitee冲突造成无法上传文件到gitee
  • (转)负载均衡,回话保持,cookie
  • .NET Core WebAPI中使用Log4net 日志级别分类并记录到数据库
  • .NET 的程序集加载上下文
  • .NET 反射 Reflect
  • .NET开发不可不知、不可不用的辅助类(三)(报表导出---终结版)