当前位置: 首页 > news >正文

Transformer的PyTorch实现之若干问题探讨(一)

《Transformer的PyTorch实现》这篇博文以一个机器翻译任务非常优雅简介的阐述了Transformer结构。在阅读时存在一些小困惑,此处权当一个记录。

1.自定义数据中enc_input、dec_input及dec_output的区别

博文中给出了两对德语翻译成英语的例子:

# S: decoding input 的起始符
# E: decoding output 的结束符
# P:意为padding,如果当前句子短于本batch的最长句子,那么用这个符号填补缺失的单词
sentence = [# enc_input   dec_input    dec_output['ich mochte ein bier P','S i want a beer .', 'i want a beer . E'],['ich mochte ein cola P','S i want a coke .', 'i want a coke . E'],
]

初看会对这其中的enc_input、dec_input及dec_output三个句子的作用不太理解,此处作详细解释:
-enc_input是模型需要翻译的输入句子,
-dec_input是用于指导模型开始翻译过程的信号
-dec_output是模型训练时的目标输出,模型的目标是使其产生的输出尽可能接近dec_output,即为翻译真实标签。他们在transformer block中的位置如下:
在这里插入图片描述

在使用Transformer进行翻译的时候,需要在Encoder端输入enc_input编码的向量,在decoder端最初只输入起始符S,然后让Transformer网络预测下一个token。

我们知道Transformer架构在进行预测时,每次推理时会获得下一个token,因此推理不是并行的,需要输出多少个token,理论上就要推理多少次。那么,在训练阶段,也需要像预测那样根据之前的输出预测下一个token,然而再所引出dec_output中对应的token做损失吗?实际并不是这样,如果真是这样做,就没有办法并行训练了。

实际我认为Transformer的并行应该是有两个层次:
(1)不同batch在训练和推理时是否可以实现并行?
(2)一个batch是否能并行得把所有的token推理出来?
Tranformer在训练时实现了上述的(1)(2),而推理时(1)(2)都没有实现。Transformer的推理似乎很难实现并行,原因是如果一次性推理两句话,那么如何保证这两句话一样长?难道有一句已经结束了,另一句没有结束,需要不断的把结束符E送入继续预测下一个结束符吗?此外,Transformer在预测下一个token时必须前面的token已经预测出来了,如果第i-1个token都没有,是无法得到第i个token。因此推理的时候都是逐句话预测,逐token预测。这儿实际也是我认为是transformer结构需要改进的地方。这样才可以提高transformer的推理效率。

2.Transformer的训练流程

此处给出博文中附带的非常简洁的Transformer训练代码:

from torch import optim
from model import *model = Transformer().cuda()
model.train()
# 损失函数,忽略为0的类别不对其计算loss(因为是padding无意义)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)# 训练开始
for epoch in range(1000):for enc_inputs, dec_inputs, dec_outputs in loader:'''enc_inputs: [batch_size, src_len] [2,5]dec_inputs: [batch_size, tgt_len] [2,6]dec_outputs: [batch_size, tgt_len] [2,6]'''enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda() # [2, 6], [2, 6], [2, 6]outputs = model(enc_inputs, dec_inputs) # outputs: [batch_size * tgt_len, tgt_vocab_size]loss = criterion(outputs, dec_outputs.view(-1))  # 将dec_outputs展平成一维张量# 更新权重optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch + 1}/1000], Loss: {loss.item()}')
torch.save(model, f'MyTransformer_temp.pth')

这段代码非常简洁,可以看到输入的是batch为2的样本,送入Transformer网络中直接logits算损失。Transformer在训练时实际上使用了一个策略叫teacher forcing。要解释这个策略的意义,以本博文给出的样本为例,对于输入的样本:

ich mochte ein bier

在进行训练时,当我们给出起始符S,接下来应该预测出:

I

那训练时,有了SI后,则应该预测出

want

那么问题来了,如I就预测错了,假如预测成了a,那么在预测want时,还应该使用Sa来预测吗?当然不是,即使预测错了,也应该用对应位置正确的tokenSI去预测下一个token,这就是teacher forcing。

那么transformer是如何实现这样一个teacher forcing的机制的呢?且听下回分解。

相关文章:

  • CSP-动态规划-最长公共子序列(LCS)
  • Zig、C、Rust的Pk1
  • C++ //练习 6.30 编译第200页的str_subrange函数,看看你的编译器是如何处理函数中的错误的。
  • λ-矩阵的多项式展开
  • Socket.D 开源输传协议 v2.4.0 发布
  • 基金分类
  • vLLM vs Text Generation Interface:大型语言模型服务框架的比较
  • C#使用密封类密封用户信息
  • 2023全球云计算市场份额排名
  • Oracle中怎么设置时区和系统时间
  • Bitcoin Bridge:治愈还是诅咒?
  • tsgctf-2021-lkgit-无锁竞争-userfaultfd
  • 电路设计(15)——篮球赛24秒违例倒计时报警器的proteus仿真
  • Flink从入门到实践(二):Flink DataStream API
  • 【深度学习】S2 数学基础 P1 线性代数(上)
  • 【跃迁之路】【463天】刻意练习系列222(2018.05.14)
  • Android Volley源码解析
  • E-HPC支持多队列管理和自动伸缩
  • JavaScript创建对象的四种方式
  • SSH 免密登录
  • TypeScript实现数据结构(一)栈,队列,链表
  • 包装类对象
  • 精彩代码 vue.js
  • 前端设计模式
  • 使用Gradle第一次构建Java程序
  • media数据库操作,可以进行增删改查,实现回收站,隐私照片功能 SharedPreferences存储地址:
  • 继 XDL 之后,阿里妈妈开源大规模分布式图表征学习框架 Euler ...
  • ​人工智能之父图灵诞辰纪念日,一起来看最受读者欢迎的AI技术好书
  • #DBA杂记1
  • #传输# #传输数据判断#
  • #我与Java虚拟机的故事#连载13:有这本书就够了
  • (Bean工厂的后处理器入门)学习Spring的第七天
  • (C++17) optional的使用
  • (C语言)字符分类函数
  • (floyd+补集) poj 3275
  • (附源码)小程序 交通违法举报系统 毕业设计 242045
  • (算法二)滑动窗口
  • (一)Linux+Windows下安装ffmpeg
  • (转)http协议
  • (转)linux自定义开机启动服务和chkconfig使用方法
  • **Java有哪些悲观锁的实现_乐观锁、悲观锁、Redis分布式锁和Zookeeper分布式锁的实现以及流程原理...
  • **PHP分步表单提交思路(分页表单提交)
  • .Net - 类的介绍
  • .net core Swagger 过滤部分Api
  • .NET CORE使用Redis分布式锁续命(续期)问题
  • .NET 的静态构造函数是否线程安全?答案是肯定的!
  • .NET 应用架构指导 V2 学习笔记(一) 软件架构的关键原则
  • .NET版Word处理控件Aspose.words功能演示:在ASP.NET MVC中创建MS Word编辑器
  • @converter 只能用mysql吗_python-MySQLConverter对象没有mysql-connector属性’...
  • [ IO.File ] FileSystemWatcher
  • [ 网络基础篇 ] MAP 迈普交换机常用命令详解
  • [20190401]关于semtimedop函数调用.txt
  • [acwing周赛复盘] 第 94 场周赛20230311
  • [Java]快速入门优先队列(堆)手撕相关面试题
  • [Leetcode] Permutations II