当前位置: 首页 > news >正文

机器翻译之创建Seq2Seq的编码器、解码器

1.创建编码器、解码器的基类

1.1创建编码器的基类

from torch import nn#构建编码器的基类
class Encoder(nn.Module):   #继承父类nn.Moduledef __init__(self, **kwargs):   #**kwargs:不定常的关键字参数super().__init__(**kwargs)def forward(self, X, *args):  #*args:不定常的位置参数#若继承了Encoder这个基类,就必须实现forward(),否则就会报下这个错raise  NotImplementedError          

1.2创建解码器的基类

#创建解码器的基类
#创建解码器的基类比创建编码器的基类多一个 state的初始化
class Decoder(nn.Module):def __init__(self, **kwargs):super().__init__(**kwargs)#初始化statedef init_state(self, enc_outputs, *args):raise NotImplementedError#前向传播,解码器比编码器多传入一个statedef forward(self, X, state):raise NotImplementedError

 1.3合并编码器和解码器的基类

class EncoderDecoder(nn.Module):def __init__(self, encoder, decoder, **kwargs):super().__init__(**kwargs)self.encoder = encoderself.decoder = decoderdef forward(self, enc_X, dec_X, *args):"""enc_X:编码器需传入的数据dec_X:解码器需传入的数据"""enc_outputs = self.encoder(enc_X, *args)dec_state = self.decoder.init_state(enc_outputs, *args)return self.decoder(dec_X, dec_state)

 2.基于上述基类,正式创建Seq2Seq编码器与解码器的类

import collections
import math
import torch
import dltools

2.1创建Seq2Seq的编码器类 

class Seq2SeqEncoder(Encoder):  #继承父类Encoderdef __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)"""vocab_size:词汇表大小embed_size:嵌入层大小num_hiddens:隐藏层的神经元数量num_layers:隐藏层的层数dropout=0 : 默认所有的神经元参与计算"""#初始化嵌入层self.embedding = nn.Embedding(vocab_size, embed_size)#初始化神经网络层self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)def forward(self, X, *args):#在进行embedding之前,X的shape=(batch_size, num_steps, vocab_size)X = self.embedding(X) #X经过embedding处理,X的shape=(batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)  #经过permute调换维度之后,X的shape=(num_steps, batch_size, embed_size)#此时, pytorch 会自动完成隐藏状态的初始化,即0, 不需要手动传入stateoutputs, state = self.rnn(X)#outputs的shape=(num_steps, batch_size, num_hiddens) ,最后一维是神经元的数量#state的shape=(num_layers, batch_size, num_hiddens)return outputs, state
#测试代码
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
encoder.eval()
# batch_size=4, num_steps=7
X = torch.zeros((4, 7), dtype=torch.long)
outputs, state = encoder(X)print(outputs.shape, state.shape)
torch.Size([7, 4, 16]) torch.Size([2, 4, 16])

2.2 创建Seq2Seq的解码器类

class Seq2SeqDecoder(Decoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super().__init__(**kwargs)#初始化嵌入层self.embedding = nn.Embedding(vocab_size, embed_size)#初始化神经网络层self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)#初始化输出层self.dense = nn.Linear(num_hiddens, vocab_size)#定义函数:获取状态statedef init_state(self, enc_outputs, *args):#编码器输出的结果有两个,第二个为statereturn enc_outputs[1]#前向传播def forward(self, X, state):#X的原始shape=(batch_size, num_steps, vocab_size)X = self.embedding(X)  #X的shape=(batch_size, num_steps, embed_size)X = X.permute(1, 0, 2)  #调整数据维度, X的shape=(num_steps, batch_size, embed_size)# 把X和state拼接到一起. 方便计算. # X现在的形状(num_steps, batch_size, embed_size) , # state的形状(batch_size, num_hiddens)# 要把state的形状扩充成三维. 变成(num_steps, batch_size, num_hiddens)context = state[-1].repeat(X.shape[0], 1, 1)  #扩充X.shape[0]=num_steps次,1:所对应的维度不变X_and_context = torch.cat((X, context), 2) #按照索引为2的维度合并#此时,X_and_context的shape=(num_steps, batch_size, embed_size+num_hiddens)#神经网络层outputs, state = self.rnn(X_and_context, state)#输出层outputs = self.dense(outputs).permute(1, 0, 2) #将数据维度重新调换过来#outputs的shape=(batch_size, num_steps, vocab_size)#state的shape=(num_layers, batch_size, num_hiddens)return outputs, state
#测试
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8, num_hiddens=32, num_layers=2)
decoder.eval()
state = decoder.init_state(encoder(X))
outputs, state = decoder(X, state)
outputs.shape, state.shape
(torch.Size([4, 7, 10]), torch.Size([2, 4, 32]))

3.编码器 、解码器理论图

 

4.知识点个人理解

 

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • C++11——function与bind
  • Vue3 : Pinia的性质与作用
  • react jsx
  • Java基础(中)
  • 进阶版水仙花数水是指一个n位数,各个位数字的n次方之和等于该数字本身
  • 《华为三台交换机堆叠的详细命令行配置示例》
  • flink自定义process,使用状态求历史总和(scala)
  • OpenSSH从7.4升级到9.8的过程 亲测--图文详解
  • 安卓13设置动态显示隐藏第一页的某一项 动态显示隐藏无障碍 android13设置动态显示隐藏第一页的某一项
  • 4款音频转文字在线转换工具帮你解锁新的记录模式。
  • RabbitMQ 高级特性——发送方确认
  • 力扣239 滑动窗口最大值 Java版本
  • C++ 新特性
  • Ceph官方文档_02_Ceph初学者指南
  • 基于php的小说阅读系统
  • python3.6+scrapy+mysql 爬虫实战
  • 230. Kth Smallest Element in a BST
  • android图片蒙层
  • CSS居中完全指南——构建CSS居中决策树
  • iBatis和MyBatis在使用ResultMap对应关系时的区别
  • nfs客户端进程变D,延伸linux的lock
  • Python3爬取英雄联盟英雄皮肤大图
  • thinkphp5.1 easywechat4 微信第三方开放平台
  • vue 个人积累(使用工具,组件)
  • vue从创建到完整的饿了么(11)组件的使用(svg图标及watch的简单使用)
  • vue从入门到进阶:计算属性computed与侦听器watch(三)
  • weex踩坑之旅第一弹 ~ 搭建具有入口文件的weex脚手架
  • 动态魔术使用DBMS_SQL
  • 欢迎参加第二届中国游戏开发者大会
  • 开源中国专访:Chameleon原理首发,其它跨多端统一框架都是假的?
  • 判断客户端类型,Android,iOS,PC
  • 前嗅ForeSpider教程:创建模板
  • 入口文件开始,分析Vue源码实现
  • 深入 Nginx 之配置篇
  • 腾讯视频格式如何转换成mp4 将下载的qlv文件转换成mp4的方法
  • 我这样减少了26.5M Java内存!
  • 线性表及其算法(java实现)
  • 进程与线程(三)——进程/线程间通信
  • 你学不懂C语言,是因为不懂编写C程序的7个步骤 ...
  • ## 1.3.Git命令
  • #vue3 实现前端下载excel文件模板功能
  • #我与Java虚拟机的故事#连载14:挑战高薪面试必看
  • (1)(1.19) TeraRanger One/EVO测距仪
  • (2024,Vision-LSTM,ViL,xLSTM,ViT,ViM,双向扫描)xLSTM 作为通用视觉骨干
  • (C语言)输入自定义个数的整数,打印出最大值和最小值
  • (env: Windows,mp,1.06.2308310; lib: 3.2.4) uniapp微信小程序
  • (pojstep1.1.2)2654(直叙式模拟)
  • (pycharm)安装python库函数Matplotlib步骤
  • (WSI分类)WSI分类文献小综述 2024
  • (二)hibernate配置管理
  • (附源码)springboot宠物管理系统 毕业设计 121654
  • (附源码)小程序 交通违法举报系统 毕业设计 242045
  • (三)centos7案例实战—vmware虚拟机硬盘挂载与卸载
  • (转) RFS+AutoItLibrary测试web对话框
  • (转)linux下的时间函数使用