当前位置: 首页 > news >正文

深度学习(PyTorch)——长短期记忆神经网络(LSTM)

一、LSTM网络

ng short term memory,即我们所称呼的LSTM,是为了解决长期以来问题而专门设计出来的,所有的RNN都具有一种重复神经网络模块的链式形式。在标准RNN中,这个重复的结构模块只有一个非常简单的结构,例如一个tanh层

LSTM也有与RNN相似的循环结构,但是循环模块中不再是简单的网络,而是比较复杂的网络单
元。LSTM的循环模块主要有4个单元,以比较复杂的方式进行连接。
 

 先熟悉以下标记:

在上图中,每条线都承载着整个矢量,从一个节点的输出到另一个节点的输入。 粉色圆圈表示按点操作,如矢量加法,而黄色框表示学习的神经网络层。 合并的行表示串联,而分叉的行表示要复制的内容,并且副本到达不同的位置。

二、LSTM核心

每个LSTM的重复结构称之为一个细胞(cell),在LSTM中最关键的就是细胞的状态,下图中贯穿
的那条横线所表示的就是细胞状态。这条线的意思就是Ct-1先乘以一个系数,再线性叠加后从右侧输出。

门可以实现选择性地让信息通过,主要是通过一个 sigmoid 的神经层 和一个逐点相乘的操作来实现的。

sigmoid层输出的是0-1之间的数字,表示着每个成分能够通过门的比例,对应位数字为0表示不通过,数字1表示全通过。比如一个信息表示为向量[1, 2, 3, 4],sigmoid层的输出为[0.3, 0.5, 0.2,,0.4],那么信息通过此门后执行点乘操作,结果为[1, 2, 3, 4] .* [0.3, 0.5, 0.2, 0.4] = [0.3, 1.0, 0.6, 1.6]。

LSTM共有3种门,通过这3种门控制与保护细胞状态。

2.1、遗忘门

第一步: 通过遗忘门过滤掉不想要的信息;

遗忘门决定遗忘哪些信息,它的作用就是遗忘掉老的不用的旧的信息,遗忘门接收上一时刻输出信息h t − 1和当前时刻的输入x t ,然后输出遗忘矩阵f t 决定上一时刻细胞状态C t − 1 的通过状态。
让我们回到语言模型的示例,该模型试图根据所有先前的单词来预测下一个单词。 在这样的问题中,细胞状态可能包括当前受试者的性别,从而可以使用正确的代词。 看到新主语时,我们想忘记旧主语的性别。

左侧的ht-1和下面输入的xt经过了连接操作,再通过一个线性单元,经过一个o也就是sigmoid函数
生成一个0到1之间的数字作为系数输出,表达式如上,Wf和bf作为待定系数是要进行训练学习的。

2.2、输入门

第二步: 决定从新的信息中存储哪些信息到细胞状态中去。即产生要更新的信息。

 包含两个小的神经网络层,一个是熟悉的sigmoid部分:

第三步: 更新细胞状态

2.3、输出门 

第四步: 基于细胞状态,确定输出信息

首先利用输出门(sigmoid层)产生一个输出矩阵Ot,决定输出当前状态Ct的哪些部分。接着状态
Ct通过tanh层之后与Ot相乘,成为输出的内容ht。
一个输出到同层下一个单元,一个输出到下一层的单元上,首先,我们运行一个sigmoid层来确定
细胞状态的哪个部分将输出出去。

接着,我们把细胞状态通过 tanh 进行处理(得到一个在 -1 到 1 之间的值)并将它和 sigmoid 门的输出相乘,最终我们仅仅会输出我们确定输出的那部分。

在语言模型中,这种影响是可以影响前后词之间词形的相关性的,例如前面输入的是一个代词或名词,后面跟随的动词会学到是否使用“三单形式”或根据前面输入的名词数量来决定输出的是单数形式还是复数形式。

三、案例,代码如下

import torch
from torch import nn

num_class = 4
input_size = 4
hidden_size = 8
embedding_size = 10
num_layers = 2
batch_size = 1
seq_len = 5

idx2char = ['e', 'h', 'l', 'o']
x_data = [[1, 0, 2, 2, 3]]  # hello
y_data = [3, 1, 2, 3, 2]  # ohlol

inputs = torch.LongTensor(x_data)
labels = torch.LongTensor(y_data)

# class LSTM(nn.Module):
#     def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
#         super().__init__()
#         self.input_size = input_size
#         self.hidden_size = hidden_size
#         self.num_layers = num_layers
#         self.output_size = output_size
#         self.num_directions = 1 # 单向LSTM
#         self.batch_size = batch_size
#         self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
#         self.linear = nn.Linear(self.hidden_size, self.output_size)
#
#     def forward(self, input_seq):
#         batch_size, seq_len = input_seq.shape[0], input_seq.shape[1]
#         h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
#         c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
#         # output(batch_size, seq_len, num_directions * hidden_size)
#         output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
#         pred = self.linear(output)  # (5, 30, 1)
#         pred = pred[:, -1, :]  # (5, 1)
#         return pred

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.num_directions = 1 # 单向LSTM
        self.emb = torch.nn.Embedding(input_size, embedding_size)  # 嵌入层
        self.lstm=torch.nn.LSTM(input_size=embedding_size,
                                hidden_size=hidden_size,
                                num_layers=num_layers,
                                batch_first=True)
        # self.rnn = torch.nn.RNN(input_size=embedding_size,
        #                         hidden_size=hidden_size,
        #                         num_layers=num_layers,
        #                         batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, num_class)

    def forward(self, x):
        h_0 = torch.zeros(self.num_directions*num_layers, x.size(0), hidden_size)  # 构造h0
        c_0 = torch.zeros(self.num_directions * num_layers, x.size(0), hidden_size)
        x = self.emb(x)  # 把长整型转变成嵌入层稠密的向量模式
        x, _ = self.lstm(x, (h_0, c_0))
        x = self.fc(x)
        return x.view(-1, num_class)


net = Model()

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)

for epoch in range(15):
    optimizer.zero_grad()  # 优化器归零
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()  # 反向传播
    optimizer.step()  # 优化器更新

    _, idx = outputs.max(dim=1)
    idx = idx.data.numpy()
    print('Predicted: ', ''.join([idx2char[x] for x in idx]), end='')
    print(', Epoch [%d/15] loss=%.3f ' % (epoch + 1, loss.item()))

运行结果如下:

参考文献:

https://blog.csdn.net/two_apples/article/details/105150848?ops_request_misc=&request_id=&biz_id=102&utm_term=lstm&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-5-105150848.142^v42^new_blog_pos_by_title,185^v2^tag_show&spm=1018.2226.3001.4187

相关文章:

  • 外贸怎么在谷歌搜索客户?
  • L73.linux命令每日一练 -- 第十章 Linux网络管理命令 -- dig和host
  • 用MicroPython开发ESP32-用TFT-LCD(ST7735S)显示图像
  • off-by-one+overlapped chunk
  • Debian/Ubuntu/Kali 如何安装 Spotify 音乐白嫖神器
  • Vue-Vue实例
  • JVM外部调试工具:JMXTerm
  • super和this的区别
  • 为什么软件工程项目普遍不重视可行性分析?
  • 亚马逊云购买和配置苹果MacOs系统的云主机
  • springboot++vue+elementui网上零食购物商城网站系统带统计投诉java
  • SSM学生惩奖系统的设计与实现毕业设计-附源码201520
  • MySQL性能优化Buffer Pool详细介绍
  • 前端年终总结
  • 如何图片批量重命名编号不要汉字?
  • mac修复ab及siege安装
  • MySQL用户中的%到底包不包括localhost?
  • WordPress 获取当前文章下的所有附件/获取指定ID文章的附件(图片、文件、视频)...
  • Yeoman_Bower_Grunt
  • 阿里云购买磁盘后挂载
  • 记录一下第一次使用npm
  • 记一次用 NodeJs 实现模拟登录的思路
  • 开源地图数据可视化库——mapnik
  • 类orAPI - 收藏集 - 掘金
  • 网络应用优化——时延与带宽
  • 小程序上传图片到七牛云(支持多张上传,预览,删除)
  • raise 与 raise ... from 的区别
  • 阿里云重庆大学大数据训练营落地分享
  • ​七周四次课(5月9日)iptables filter表案例、iptables nat表应用
  • # Pytorch 中可以直接调用的Loss Functions总结:
  • # 日期待t_最值得等的SUV奥迪Q9:空间比MPV还大,或搭4.0T,香
  • #{}和${}的区别?
  • $.ajax()方法详解
  • (3)STL算法之搜索
  • (大众金融)SQL server面试题(1)-总销售量最少的3个型号的车及其总销售量
  • (官网安装) 基于CentOS 7安装MangoDB和MangoDB Shell
  • (免费领源码)python#django#mysql公交线路查询系统85021- 计算机毕业设计项目选题推荐
  • (三) prometheus + grafana + alertmanager 配置Redis监控
  • (三)Pytorch快速搭建卷积神经网络模型实现手写数字识别(代码+详细注解)
  • (实战篇)如何缓存数据
  • (已解决)报错:Could not load the Qt platform plugin “xcb“
  • (转)3D模板阴影原理
  • .NET 编写一个可以异步等待循环中任何一个部分的 Awaiter
  • .NET 应用启用与禁用自动生成绑定重定向 (bindingRedirect),解决不同版本 dll 的依赖问题
  • .NET企业级应用架构设计系列之结尾篇
  • @GetMapping和@RequestMapping的区别
  • [ vulhub漏洞复现篇 ] Apache APISIX 默认密钥漏洞 CVE-2020-13945
  • [ 第一章] JavaScript 简史
  • []串口通信 零星笔记
  • [BZOJ4566][HAOI2016]找相同字符(SAM)
  • [C#7] 1.Tuples(元组)
  • [C++基础]-初识模板
  • [CSS]浮动
  • [Go WebSocket] 多房间的聊天室(五)用多个小锁代替大锁,提高效率
  • [HackMyVM]靶场Boxing