当前位置: 首页 > news >正文

自然语言处理(一):RNN

「循环神经网络」(Recurrent Neural Network,RNN)是一个非常经典的面向序列的模型,可以对自然语言句子或是其他时序信号进行建模。进一步讲,它只有一个物理RNN单元,但是这个RNN单元可以按照时间步骤进行展开,在每个时间步骤接收当前时间步的输入和上一个时间步的输出,然后进行计算得出本时间步的输出。

Why

  1. CNN 需要固定长度的输入、输出,RNN 的输入和输出可以是不定长且不等长的
  2. CNN 只有 one-to-one 一种结构,而 RNN 有多种结构,如下图:
alt

Model

  • 简单模型示例

    alt

循环神经网络的隐藏层的值s不仅仅取决于当前这次的输入x,还取决于上一次隐藏层的值s。「权重矩阵」 W就是「隐藏层」上一次的值作为这一次的输入的权重。

  • RNN时间线展开
alt

时刻的输入,不仅是 ,还应该包括上一个时刻所计算的

  • 使用公式表示
alt

示例

下面我们举个例子来讨论一下,如图所示,假设我们现在有这样一句话:”我爱人工智能”,经过分词之后变成”我,爱,人工,智能”这4个单词,RNN会根据这4个单词的时序关系进行处理,在第1个时刻处理单词”我”,第2个时刻处理单词”爱”,依次类推。

alt

从图上可以看出,RNN在每个时刻 均会接收两个输入,一个是当前时刻的单词 ,一个是来自上一个时刻的输出 ,经过计算后产生当前时刻的输出 。例如在第2个时刻,它的输入是”爱”和 ,它的输出是 ;在第3个时刻,它的输入是”人工”和 , 输出是 ,依次类推,直到处理完最后一个单词。

总结一下,RNN会从左到右逐词阅读这个句子,并不断调用一个相同的RNN Cell来处理时序信息,每阅读一个单词,RNN首先将本时刻 的单词 和这个模型内部记忆的「状态向量」 融合起来,形成一个带有最新记忆的状态向量

  • 「Tip」:当RNN读完最后一个单词后,那RNN就已经读完了整个句子,一般可认为最后一个单词输出的状态向量能够表示整个句子的语义信息,即它是整个句子的语义向量,这是一个常用的想法。

Code

  • 数据准备
import torch
import torch.nn as nn
import numpy as np

torch.manual_seed(0)  # 设置随机种子以实现可重复性

seq_length = 5
input_size = 1
hidden_size = 10
output_size = 1
batch_size = 1

time_steps = np.linspace(0, np.pi, 100)
data = np.sin(time_steps)
data.resize((len(time_steps), 1))

# Split data into sequences of length 5
x = []
y = []
for i in range(len(data)-seq_length):
    _x = data[i:i+seq_length]
    _y = data[i+seq_length]
    x.append(_x)
    y.append(_y)

x = np.array(x)
y = np.array(y)
  • Model
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        out, hidden = self.rnn(x, hidden)
        out = out.view(-1, self.hidden_size)
        out = self.fc(out)
        return out, hidden
  • Train
model = RNN(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    total_loss = 0
    hidden = None
    for i in range(len(x)):
        optimizer.zero_grad()
        input_ = torch.Tensor(x[i]).unsqueeze(0)
        target = torch.Tensor(y[i])
        output, hidden = model(input_, hidden)
        hidden = hidden.detach()
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {total_loss}')

缺点

  • 当阅读很长的序列时,网络内部的信息会逐渐变得越来越复杂,以至于超过网络的记忆能力,使得最终的输出信息变得混乱无用。

参考

  1. https://zhuanlan.zhihu.com/p/30844905
  2. https://paddlepedia.readthedocs.io/en/latest/tutorials/sequence_model/rnn.html
  3. https://saturncloud.io/blog/building-rnn-from-scratch-in-pytorch/
  4. https://pytorch.org/docs/stable/generated/torch.nn.RNN.html

本文由 mdnice 多平台发布

相关文章:

  • 【TiDB】TiDB CLuster部署
  • Adobe家里那点事儿~~~
  • Django(四、路由层)
  • 一种艺术风格的神经算法:总结与实现
  • 【KVM-6】KVM/QEMU软件栈
  • Three.js 实现简单的PCD加载器(可从本地读取pcd文件)【附完整代码】
  • 短剧软件APP开发方案
  • Django知识点
  • C#中匿名类的声明及使用
  • vuejs - - - - - 移动端设备兼容(pxtorem)
  • QT QDockWidget
  • C++语言的广泛应用领域
  • arcgis基础篇--实验
  • 数据分析实战 | K-means算法——蛋白质消费特征分析
  • 计算机网络第一章(计算机网络开篇)
  • [iOS]Core Data浅析一 -- 启用Core Data
  • Effective Java 笔记(一)
  • flutter的key在widget list的作用以及必要性
  • git 常用命令
  • input实现文字超出省略号功能
  • javascript从右向左截取指定位数字符的3种方法
  • JavaScript设计模式与开发实践系列之策略模式
  • JavaSE小实践1:Java爬取斗图网站的所有表情包
  • java多线程
  • JS函数式编程 数组部分风格 ES6版
  • MQ框架的比较
  • Objective-C 中关联引用的概念
  • SpringCloud(第 039 篇)链接Mysql数据库,通过JpaRepository编写数据库访问
  • Terraform入门 - 1. 安装Terraform
  • webpack4 一点通
  • 对JS继承的一点思考
  • 关于springcloud Gateway中的限流
  • 互联网大裁员:Java程序员失工作,焉知不能进ali?
  • 简单数学运算程序(不定期更新)
  • 前端临床手札——文件上传
  • AI算硅基生命吗,为什么?
  • linux 淘宝开源监控工具tsar
  • zabbix3.2监控linux磁盘IO
  • 积累各种好的链接
  • ​渐进式Web应用PWA的未来
  • !! 2.对十份论文和报告中的关于OpenCV和Android NDK开发的总结
  • #### go map 底层结构 ####
  • #stm32整理(一)flash读写
  • $jQuery 重写Alert样式方法
  • $refs 、$nextTic、动态组件、name的使用
  • (2)MFC+openGL单文档框架glFrame
  • (C语言)共用体union的用法举例
  • (Python) SOAP Web Service (HTTP POST)
  • (带教程)商业版SEO关键词按天计费系统:关键词排名优化、代理服务、手机自适应及搭建教程
  • (附源码)springboot猪场管理系统 毕业设计 160901
  • (南京观海微电子)——I3C协议介绍
  • (十二)devops持续集成开发——jenkins的全局工具配置之sonar qube环境安装及配置
  • (未解决)macOS matplotlib 中文是方框
  • (一)80c52学习之旅-起始篇
  • (一)认识微服务