当前位置: 首页 > news >正文

基于Python的自然语言处理系列(14):TorchText + biGRU + Attention + Teacher Forcing

        在前几篇文章中,我们探索了序列到序列(seq2seq)模型的基础,并通过使用双向GRU和上下文向量改进了模型的表现。然而,模型仍然依赖一个固定的上下文向量,这意味着它必须从整个源句中压缩信息,导致在长句子的翻译中可能出现问题。

        在本篇文章中,我们将引入注意力机制来解决这个问题。注意力机制允许解码器在每一步解码时不仅仅依赖一个固定的上下文向量,而是能够动态地访问源句中的所有信息。这样,模型可以在解码过程中“关注”到最相关的词,从而提升翻译的准确性,尤其是长句子。

1. 背景

        在传统的seq2seq模型中,解码器仅依赖编码器生成的一个上下文向量。尽管我们通过双向GRU改进了模型,但上下文向量仍然需要压缩整个源句的信息,限制了模型的表现。

        为了解决这个问题,注意力机制通过计算源句中每个词的权重,让解码器能够动态地关注源句中的不同部分,而不仅仅是依赖一个固定的上下文向量。这不仅提升了模型对长句子的处理能力,还提高了翻译的准确性。

2. 数据加载与预处理

        我们继续使用TorchText加载Multi30k数据集,并使用spacy进行标记化处理。数据加载的流程与之前文章中的内容相似。

from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizerSRC_LANGUAGE = 'en'
TRG_LANGUAGE = 'de'train = Multi30k(split=('train'), language_pair=(SRC_LANGUAGE, TRG_LANGUAGE))token_transform = {}
token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')
token_transform[TRG_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')

        与前面相同,我们将数据集分为训练集、验证集和测试集,并将文本进行数值化处理。

3. 模型设计

        在这个模型中,我们将实现一个结合了双向GRU注意力机制的seq2seq模型。模型结构包括以下几个部分:

3.1 编码器(Encoder)

        首先,我们将构建编码器。这里我们使用双向GRU,将输入序列从左到右和从右到左进行编码。编码器输出的隐状态将作为解码器的初始隐状态,并传递给注意力机制。

class Encoder(nn.Module):def __init__(self, input_dim, emb_dim, hid_dim, dropout):super().__init__()self.embedding = nn.Embedding(input_dim, emb_dim)self.rnn = nn.GRU(emb_dim, hid_dim, bidirectional=True)self.fc = nn.Linear(hid_dim * 2, hid_dim)self.dropout = nn.Dropout(dropout)def forward(self, src):embedded = self.dropout(self.embedding(src))outputs, hidden = self.rnn(embedded)hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)))return outputs, hidden

3.2 注意力机制(Attention)

        注意力机制的作用是计算解码器当前隐状态与源句每个词的隐状态之间的权重,帮助解码器决定应该关注源句的哪些部分。权重越大,说明该词对当前解码步骤越重要。

class Attention(nn.Module):def __init__(self, hid_dim):super().__init__()self.v = nn.Linear(hid_dim, 1, bias=False)self.W = nn.Linear(hid_dim, hid_dim)self.U = nn.Linear(hid_dim * 2, hid_dim)def forward(self, hidden, encoder_outputs):batch_size = encoder_outputs.shape[1]src_len = encoder_outputs.shape[0]hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)encoder_outputs = encoder_outputs.permute(1, 0, 2)energy = torch.tanh(self.W(hidden) + self.U(encoder_outputs))attention = self.v(energy).squeeze(2)return F.softmax(attention, dim=1)

3.3 解码器(Decoder)

        解码器在每个时刻生成一个新的目标词。它使用注意力机制得到的加权源句信息,以及解码器自身的隐状态来生成目标词的预测。

class Decoder(nn.Module):def __init__(self, output_dim, emb_dim, hid_dim, dropout, attention):super().__init__()self.output_dim = output_dimself.attention = attentionself.embedding = nn.Embedding(output_dim, emb_dim)self.gru = nn.GRU((hid_dim * 2) + emb_dim, hid_dim)self.fc_out = nn.Linear((hid_dim * 2) + hid_dim + emb_dim, output_dim)self.dropout = nn.Dropout(dropout)def forward(self, input, hidden, encoder_outputs):input = input.unsqueeze(0)embedded = self.dropout(self.embedding(input))a = self.attention(hidden, encoder_outputs).unsqueeze(1)encoder_outputs = encoder_outputs.permute(1, 0, 2)weighted = torch.bmm(a, encoder_outputs).permute(1, 0, 2)rnn_input = torch.cat((embedded, weighted), dim=2)output, hidden = self.gru(rnn_input, hidden.unsqueeze(0))embedded = embedded.squeeze(0)output = output.squeeze(0)weighted = weighted.squeeze(0)prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))return prediction, hidden.squeeze(0)

3.4 Seq2Seq模型

        将编码器、解码器和注意力机制组合起来,我们构建了完整的seq2seq模型。

class Seq2SeqAttention(nn.Module):def __init__(self, encoder, decoder, device):super().__init__()self.encoder = encoderself.decoder = decoderself.device = devicedef forward(self, src, trg, teacher_forcing_ratio=0.5):batch_size = src.shape[1]trg_len = trg.shape[0]trg_vocab_size = self.decoder.output_dimoutputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)encoder_outputs, hidden = self.encoder(src)input_ = trg[0, :]for t in range(1, trg_len):output, hidden = self.decoder(input_, hidden, encoder_outputs)outputs[t] = outputtop1 = output.argmax(1)input_ = trg[t] if random.random() < teacher_forcing_ratio else top1return outputs

4. 训练与评估

        我们使用与前几篇文章相同的训练和评估函数。为了防止梯度爆炸,我们在训练过程中应用梯度裁剪。

def train(model, iterator, optimizer, criterion, clip):model.train()epoch_loss = 0for i, (src, trg) in enumerate(iterator):src, trg = src.to(device), trg.to(device)optimizer.zero_grad()output = model(src, trg)output_dim = output.shape[-1]output = output[1:].view(-1, output_dim)trg = trg[1:].view(-1)loss = criterion(output, trg)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), clip)optimizer.step()epoch_loss += loss.item()return epoch_loss / len(iterator)def evaluate(model, iterator, criterion):model.eval()epoch_loss = 0with torch.no_grad():for i, (src, trg) in enumerate(iterator):src, trg = src.to(device), trg.to(device)output = model(src, trg, 0)output_dim = output.shape[-1]output = output[1:].view(-1, output_dim)trg = trg[1:].view(-1)loss = criterion(output, trg)epoch_loss += loss.item()return epoch_loss / len(iterator)

训练模型:

for epoch in range(10):train_loss = train(model, train_loader, optimizer, criterion, 1)val_loss = evaluate(model, valid_loader, criterion)print(f'Epoch {epoch+1} | Train Loss: {train_loss:.3f}, Val Loss: {val_loss:.3f}')

结语

        通过在seq2seq模型中引入注意力机制,我们成功提升了模型对长句子的处理能力。在每一个解码步骤中,解码器能够灵活地访问源句中的每个词,而不再依赖一个固定的上下文向量。这大大减少了信息压缩问题,使得模型在翻译复杂句子时更加精准。

        尽管注意力机制为模型带来了显著的改进,但训练时间相对增加。尤其在处理长句子时,模型需要计算源句中每个词的注意力权重,增加了计算复杂度。

        在下一篇文章中,我们将结合双向GRU注意力机制以及Packed Padded SequencesMasking技术,进一步优化模型的训练过程。同时,我们将展示如何通过这些技术处理不同长度的输入序列,并通过可视化注意力权重,更深入地理解模型在解码时关注哪些词。敬请期待!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 深入理解Go语言的方法定义与使用
  • sqli-lab靶场学习(二)——Less8-10(盲注、时间盲注)
  • 前端开发之迭代器模式
  • 从数据仓库到数据中台再到数据飞轮:我了解的数据技术进化史
  • 代码管理-使用TortoiseGit同步项目到Github/Gitee
  • 运行npm install 时,卡在sill idealTree buildDeps没有反应
  • SCRM电商管理后台Axure高保真原型 源文件
  • 电脑提示丢失mfc140u.dll的详细解决方案,mfc140u.dll文件是什么
  • C++初阶:STL详解(五)——vector的模拟实现
  • 初中生物--7.生物圈中的绿色植物(二)
  • java项目之在线考试与学习交流网页平台源码(springboot)
  • QT 串口上位机读卡显示
  • 枚举(not二分)
  • TCP 和 UDP 协议的区别?
  • MySQL之约束
  • 9月CHINA-PUB-OPENDAY技术沙龙——IPHONE
  • Android交互
  • android图片蒙层
  • Angular6错误 Service: No provider for Renderer2
  • C# 免费离线人脸识别 2.0 Demo
  • Git学习与使用心得(1)—— 初始化
  • JavaScript/HTML5图表开发工具JavaScript Charts v3.19.6发布【附下载】
  • java中的hashCode
  • Laravel Telescope:优雅的应用调试工具
  • Spark RDD学习: aggregate函数
  • Vue 动态创建 component
  • 程序员该如何有效的找工作?
  • 对象引论
  • 买一台 iPhone X,还是创建一家未来的独角兽?
  • 爬虫进阶 -- 神级程序员:让你的爬虫就像人类的用户行为!
  • 嵌入式文件系统
  • 一个SAP顾问在美国的这些年
  • 如何用纯 CSS 创作一个菱形 loader 动画
  • ​ ​Redis(五)主从复制:主从模式介绍、配置、拓扑(一主一从结构、一主多从结构、树形主从结构)、原理(复制过程、​​​​​​​数据同步psync)、总结
  • #Datawhale X 李宏毅苹果书 AI夏令营#3.13.2局部极小值与鞍点批量和动量
  • #设计模式#4.6 Flyweight(享元) 对象结构型模式
  • $.each()与$(selector).each()
  • (145)光线追踪距离场柔和阴影
  • (2)从源码角度聊聊Jetpack Navigator的工作流程
  • (2024最新)CentOS 7上在线安装MySQL 5.7|喂饭级教程
  • (C语言)输入自定义个数的整数,打印出最大值和最小值
  • (done) NLP “bag-of-words“ 方法 (带有二元分类和多元分类两个例子)词袋模型、BoW
  • (补充)IDEA项目结构
  • (附源码)计算机毕业设计SSM教师教学质量评价系统
  • (几何:六边形面积)编写程序,提示用户输入六边形的边长,然后显示它的面积。
  • (简单有案例)前端实现主题切换、动态换肤的两种简单方式
  • (六)软件测试分工
  • (一)Docker基本介绍
  • (自用)learnOpenGL学习总结-高级OpenGL-抗锯齿
  • (最全解法)输入一个整数,输出该数二进制表示中1的个数。
  • *算法训练(leetcode)第四十五天 | 101. 孤岛的总面积、102. 沉没孤岛、103. 水流问题、104. 建造最大岛屿
  • ./configure,make,make install的作用
  • .NET Core 2.1路线图
  • .Net调用Java编写的WebServices返回值为Null的解决方法(SoapUI工具测试有返回值)
  • .NET开源的一个小而快并且功能强大的 Windows 动态桌面软件 - DreamScene2