当前位置: 首页 > news >正文

【动手学深度学习-pytorch】9.2长短期记忆网络(LSTM)

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。 解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM) (Hochreiter and Schmidhuber, 1997)。 它有许多与门控循环单元( 9.1节)一样的属性。 有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年.

门控记忆元 cell

  • 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)
  • 为了控制记忆元,我们需要许多门。输入门 输出门 遗忘门
  • 其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。 另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。 我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理, 这种设计的动机与门控循环单元相同, 能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。 让我们看看这在实践中是如何运作的。

输入门、忘记门和输出门

就如在门控循环单元中一样, 当前时间步的输入和前一个时间步的隐状态 作为数据送入长短期记忆网络的门中, 如 图9.2.1所示。 它们由三个具有sigmoid激活函数的全连接层处理, 以计算输入门、遗忘门和输出门的值。 因此,这三个门的值都在
的范围内。
在这里插入图片描述
在这里插入图片描述

候选记忆元

在这里插入图片描述

记忆元

在这里插入图片描述

隐状态

在这里插入图片描述

只有隐状态会传递到输出层,而记忆元完全属于内部信息

从零开始实现

现在,我们从零开始实现长短期记忆网络。 与 8.5节中的实验相同, 我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化模型参数

如前所述,超参数num_hiddens定义隐藏单元的数量。 我们按照标准差
的高斯分布初始化权重,并将偏置项设为0.

def get_lstm_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)return params

定义模型

在初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。 请注意,只有隐状态才会传递到输出层, 而记忆元不直接参与输出计算。

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)

训练和预测

让我们通过实例化 8.5节中 引入的RNNModelScratch类来训练一个长短期记忆网络, 就如我们在 9.1节中所做的一样。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

简洁实现

使用高级API,我们可以直接实例化LSTM模型。 高级API封装了前文介绍的所有配置细节。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

总结

  • 长短期记忆网络,包含三个门:输入门、忘记门和遗忘门。其中遗忘门用于重置单元的内容,通过专用的机制决定什么时候记忆或者忽略状态中的输入。

  • 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。

  • 长短期记忆网络可以缓解梯度消失和梯度爆炸。

相关文章:

  • K8S的mountPath和subPath
  • LeetCode 206.反转链表
  • 如何在智能交通系统中使用物联网技术提高道路安全和效率
  • 怎么让ChatGPT批量写作原创文章
  • Springboot+MybatisPlus+EasyExcel实现文件导入数据
  • Mysql中的那些锁
  • 【跟着CHATGPT学习硬件外设 | 04】ADC
  • SVG XML 格式定义图形入门介绍
  • 【AI模型-机器学习工具部署】远程服务器配置Jupyter notebook或jupyter lab服务
  • kubernetes-k9s一个基于Linux 终端的集群管理工具
  • 微信小程序布局中的单位及使用
  • EXCEL 通过FILES函数获取指定路径中的所有文件名
  • Docker Desktop 在 Windows 上的安装和使用
  • 从TCP/IP协议到socket编程详解
  • 接口自动化框架搭建(四):pytest的使用
  • C++类中的特殊成员函数
  • Date型的使用
  • JavaScript 事件——“事件类型”中“HTML5事件”的注意要点
  • js 实现textarea输入字数提示
  • leetcode-27. Remove Element
  • oldjun 检测网站的经验
  • Web设计流程优化:网页效果图设计新思路
  • -- 查询加强-- 使用如何where子句进行筛选,% _ like的使用
  • 关于extract.autodesk.io的一些说明
  • 那些年我们用过的显示性能指标
  • 普通函数和构造函数的区别
  • 山寨一个 Promise
  • 深入体验bash on windows,在windows上搭建原生的linux开发环境,酷!
  • 优秀架构师必须掌握的架构思维
  • 阿里云IoT边缘计算助力企业零改造实现远程运维 ...
  • 好程序员大数据教程Hadoop全分布安装(非HA)
  • $jQuery 重写Alert样式方法
  • (03)光刻——半导体电路的绘制
  • (6)设计一个TimeMap
  • (力扣)1314.矩阵区域和
  • (学习日记)2024.03.12:UCOSIII第十四节:时基列表
  • (转)h264中avc和flv数据的解析
  • .bat文件调用java类的main方法
  • .net 4.0发布后不能正常显示图片问题
  • .Net 8.0 新的变化
  • .NET CORE 3.1 集成JWT鉴权和授权2
  • .NET 简介:跨平台、开源、高性能的开发平台
  • .net解析传过来的xml_DOM4J解析XML文件
  • @Autowired和@Resource的区别
  • @data注解_一枚 架构师 也不会用的Lombok注解,相见恨晚
  • @拔赤:Web前端开发十日谈
  • @软考考生,这份软考高分攻略你须知道
  • [ 渗透测试面试篇 ] 渗透测试面试题大集合(详解)(十)RCE (远程代码/命令执行漏洞)相关面试题
  • []FET-430SIM508 研究日志 11.3.31
  • []使用 Tortoise SVN 创建 Externals 外部引用目录
  • [14]内置对象
  • [AutoSar]BSW_OS 02 Autosar OS_STACK
  • [C# WPF] DataGrid选中行或选中单元格的背景和字体颜色修改
  • [C++][数据结构][算法]单链式结构的深拷贝
  • [EFI]Dell Latitude-7400电脑 Hackintosh 黑苹果efi引导文件