当前位置: 首页 > news >正文

机器翻译之Bahdanau注意力机制在Seq2Seq中的应用

目录

1.创建 添加了Bahdanau的decoder 

2. 训练

 3.定义评估函数BLEU

 4.预测

 5.知识点个人理解


1.创建 添加了Bahdanau的decoder 

import torch
from torch import nn
import dltools#定义注意力解码器基类
class AttentionDecoder(dltools.Decoder):  #继承dltools.Decoder写注意力编码器的基类def __init__(self, **kwargs):super().__init__(**kwargs)@property    #装饰器, 定义的函数方法可以像类的属性一样被调用def attention_weights(self):#raise用于引发(或抛出)异常raise NotImplementedError  #通常用于抽象基类中,作为占位符,提醒子类必须实现这个方法。 #创建 添加了Bahdanau的decoder
#继承AttentionDecoder这个基类创建Seq2SeqAttentionDecoder子类, 子类必须实现父类中NotImplementedError占位的方法
class Seq2SeqAttentionDecoder(AttentionDecoder):  #初始化属性和方法def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):"""vocab_size:此表大小,  相当于输入数据的特征数features,  也是输出数据的特征数embed_size:嵌入层的大小:将输入数据处理成小批量的数据num_hiddens:隐藏层神经元的数量num_layers:循环网络的层数dropout=0:不释放模型的参数(比如:神经元)"""super().__init__(**kwargs)#初始化注意力机制的评分函数方法self.attention = dltools.AdditiveAttention(key_size=num_hiddens,query_size=num_hiddens, num_hiddens=num_hiddens,dropout=dropout)#初始化嵌入层:将输入的数据处理成小批量的tensor数据   (文本--->数值的映射转化)self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)#初始化循环网络self.rnn = nn.GRU(embed_size+num_hiddens, num_hiddens, num_layers, dropout=dropout)#初始化线性层  (输出层)self.dense = nn.Linear(num_hiddens, vocab_size)#初始化隐藏层的状态state   (计算state,需要编码器的输出结果、序列的有效长度)def init_state(self, enc_outputs, enc_valid_lens, *args):#enc_outputs是一个元组(输出结果,隐藏状态)#outputs的shape=(batch_size, num_steps, num_hiddens)#hidden_state的shape=(num_layers, batch_size, num_hiddens)outputs, hidden_state = enc_outputs#返回一个元组(,),可以用一个变量接收#outputs.permute(1, 0, 2)转换数据的维度是因为rnn循环神经网络的输入要求是先num_steps,再batch_size,return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)#定义前向传播   (输入数据X,state)def forward(self, X, state):#变量赋值:接收编码器encoder的输出结果、隐藏状态、序列有效长度#enc_outputs的shape=(batch_size, num_steps, num_hiddens)#hidden_state的shape=(num_layers, batch_size, num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state#X的shape=(batch_size, num_steps, vocab_size)X = self.embedding(X)   #将X输入embedding嵌入层后, X的shape=(batch_size, num_steps, embed_size)#调换X的0维度和1维度数据X = X.permute(1, 0, 2)   #X的shape=(num_steps, batch_size, embed_size)outputs, self._attention_weights = [], []  #创建空列表,用于存储数据for x in X:  #遍历每一批数据#获取query#hidden_state[-1]表示最后一层循环网络的隐藏层状态  (有两层循环网络)#hidden_state[-1]的shape=(batch_size, num_hiddens)    #dim=1表示在原索引1的维度增加一个维度query = torch.unsqueeze(hidden_state[-1], dim=1)  
#             print('query的shape:', query.shape)   #query的shape=(batch_size, 1, num_hiddens)#通过注意力机制获取上下文序列context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
#             print('context的shape:', context.shape)  #context的shape=(batch_size, 1, num_hiddens)#用最后一个维度 拼接context, x 数据x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
#             print('x的shape:', x.shape)   #x的shape=(batch_size, 1, num_hiddens+embed_size)#将x和hidden_state输入循环神经网络中,获取输出结果和新的hidden_stateout, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
#             print('out的shape:', out.shape)   #out的shape=(1, batch_size, num_hiddens)
#             print('hidden_state的shape:', hidden_state.shape) #两层循环层:hidden_state的shape=(2, batch_size, num_hiddens)#将输出结果添加到列表中outputs.append(out)self._attention_weights.append(self.attention_weights)outputs = self.dense(torch.cat(outputs, dim=0))
#         print('outputs的shape:', outputs.shape)  #outputs的shape=(num_steps, batch_size, vocab_size)return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights#测试代码
#创建编码器对象
encoder = dltools.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
#需要预测, 要加encoder.eval()
encoder.eval()
#创建解码器对象
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()#假设数据
batch_size, num_steps = 4, 7
X = torch.zeros((4, 7), dtype = torch.long)
#初始化状态state
state = decoder.init_state(encoder(X), None)
outputs, state = decoder(X, state)
#state包含三个东西(enc_outputs, hidden_state, enc_valid_lens)
#state[0]是 enc_outputs
#state[1]是 hidden_state, 两层循环层,就会有两个hidden_state, state[1][0]是第一层的hidden_state
outputs.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
outputs的shape: torch.Size([7, 4, 10])

Out[11]:

(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

2. 训练

#声明变量
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()#加载数据
train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)#创建编辑器对象
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
#创建编辑器对象
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)#创建网络模型
net = dltools.EncoderDecoder(encoder, decoder)#模型训练
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

 

 3.定义评估函数BLEU

def bleu(pred_seq, label_seq, k):print('pred_seq:', pred_seq)print('label_seq:', label_seq)#将pred_seq, label_seq分别进行空格分隔pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')#获取pred_seq, label_seq的长度len_pred, len_label = len(pred_seq), len(label_seq)score = math.exp(min(0, 1 - (len_label / len_pred)))for n in range(1, k+1): #n的取值范围,  range()左闭右开num_matches, label_subs = 0, collections.defaultdict(int)for i in range(len_label - n + 1):label_subs[' '.join(label_tokens[i: i+n])] += 1for i in range(len_pred - n + 1):if label_subs[' '.join(pred_tokens[i: i+n])] > 0:num_matches += 1label_subs[' '.join(pred_tokens[i: i+n])] -=1score *= math.pow(num_matches / (len_pred -n + 1), math.pow(0.5, n))return score

 4.预测

import math
import collectionsengs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')

go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est bon .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000

 5.知识点个人理解

 

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Linux操作系统面试题记录
  • 国内可以使用的ChatGPT服务【9月持续更新】
  • 机器之心 | 阿里云Qwen2.5发布!再登开源大模型王座,Qwen-Max性能逼近GPT-4o
  • 【研发日记】嵌入式处理器技能解锁(六)——ARM的Cortex-M4内核
  • 新书速览|NestJS全栈开发解析:快速上手与实践
  • 数据结构之二叉树查询
  • JetLinks物联网学习(前后端项目启动)
  • HarmonyOS开发者基础认证考试试题
  • 生信初学者教程(七):数据库
  • 【pytorch】RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED 报错
  • 【Unity踩坑】UI Image的fillAmount不起作用
  • Spring中的Web Service消费者集成(应该被淘汰的技术)
  • 【学习笔记】STM32F407探索者HAL库开发(五)F407时钟系统配置
  • MySQL高阶1907-按分类统计薪水
  • Codeforces Round 784 (Div. 4) Kotlin
  • 9月CHINA-PUB-OPENDAY技术沙龙——IPHONE
  • [LeetCode] Wiggle Sort
  • 【每日笔记】【Go学习笔记】2019-01-10 codis proxy处理流程
  • flask接收请求并推入栈
  • JavaWeb(学习笔记二)
  • JAVA之继承和多态
  • jquery cookie
  • JS函数式编程 数组部分风格 ES6版
  • laravel with 查询列表限制条数
  • 分享一份非常强势的Android面试题
  • 给初学者:JavaScript 中数组操作注意点
  • 关于Flux,Vuex,Redux的思考
  • 开发基于以太坊智能合约的DApp
  • 前端每日实战:61# 视频演示如何用纯 CSS 创作一只咖啡壶
  • 微信端页面使用-webkit-box和绝对定位时,元素上移的问题
  • 哈罗单车融资几十亿元,蚂蚁金服与春华资本加持 ...
  • ​DB-Engines 11月数据库排名:PostgreSQL坐稳同期涨幅榜冠军宝座
  • ​DB-Engines 12月数据库排名: PostgreSQL有望获得「2020年度数据库」荣誉?
  • ​软考-高级-信息系统项目管理师教程 第四版【第14章-项目沟通管理-思维导图】​
  • #{} 和 ${}区别
  • #java学习笔记(面向对象)----(未完结)
  • (160)时序收敛--->(10)时序收敛十
  • (SpringBoot)第二章:Spring创建和使用
  • (web自动化测试+python)1
  • (zt)基于Facebook和Flash平台的应用架构解析
  • (二十五)admin-boot项目之集成消息队列Rabbitmq
  • (附源码)springboot炼糖厂地磅全自动控制系统 毕业设计 341357
  • (四)c52学习之旅-流水LED灯
  • (万字长文)Spring的核心知识尽揽其中
  • (转)Android学习系列(31)--App自动化之使用Ant编译项目多渠道打包
  • (转)大型网站架构演变和知识体系
  • (转)平衡树
  • (轉貼) 寄發紅帖基本原則(教育部禮儀司頒布) (雜項)
  • *1 计算机基础和操作系统基础及几大协议
  • ..thread“main“ com.fasterxml.jackson.databind.JsonMappingException: Jackson version is too old 2.3.1
  • .libPaths()设置包加载目录
  • .net 4.0 A potentially dangerous Request.Form value was detected from the client 的解决方案
  • .NET 的静态构造函数是否线程安全?答案是肯定的!
  • .net8.0与halcon编程环境构建
  • //usr/lib/libgdal.so.20:对‘sqlite3_column_table_name’未定义的引用