当前位置: 首页 > news >正文

PyTorch快速入门教程五(rnn)

为什么80%的码农都做不了架构师?>>>   hot3.png

上一讲讲了cnn以及如何使用pytorch实现简单的多层卷积神经网络,下面我们将进入rnn,关于rnn将分成三个部分,

  1. 介绍rnn的基本结构以及在pytorch里面api的各个参数所表示的含义,
  2. 介绍rnn如何在MNIST数据集上做分类,
  3. 涉及一点点自然语言处理的东西。

RNN

首先介绍一下什么是rnnrnn特别擅长处理序列类型的数据,因为他是一个循环的结构

pytorch Rnn

一个序列的数据依次进入网络A,网络A循环的往后传递。

这就是RNN的基本结构类型。而最早的RNN模型,序列依次进入网络中,之前进入序列的数据会保存信息而对后面的数据产生影响,所以RNN有着记忆的特性,而同时越前面的数据进入序列的时间越早,所以对后面的数据的影响也就越弱,简而言之就是一个数据会更大程度受到其临近数据的影响。但是我们很有可能需要更长时间之前的信息,而这个能力传统的RNN特别弱,于是有了LSTM这个变体。

LSTM

pytorch LSTM

这就是LSTM的模型结构,也是一个向后传递的链式模型,而现在广泛使用的RNN其实就是LSTM,序列中每个数据传入LSTM可以得到两个输出,而这两个输出和序列中下一个数据一起又作为传入LSTM的输入,然后不断地循环向后,直到序列结束。

下面结合pytorch一步一步来看数据传入LSTM是怎么运算的

首先需要定义好LSTM网络,需要nn.LSTM(),首先介绍一下这个函数里面的参数

  1. input_size 表示的是输入的数据维数

  2. hidden_size 表示的是输出维数

  3. num_layers 表示堆叠几层的LSTM,默认是1

  4. bias True 或者 False,决定是否使用bias

  5. batch_first True 或者 False,因为nn.lstm()接受的数据输入是(序列长度,batch,输入维数),这和我们cnn输入的方式不太一致,所以使用batch_first,我们可以将输入变成(batch,序列长度,输入维数)

  6. dropout 表示除了最后一层之外都引入一个dropout

  7. bidirectional 表示双向LSTM,也就是序列从左往右算一次,从右往左又算一次,这样就可以两倍的输出

是网络的输出维数,比如M,因为输出的维度是M,权重w的维数就是(M, M)和(M, K),b的维数就是(M, 1)和(M, 1),最后经过sigmoid激活函数,得到的f的维数是(M, 1)。 pytorch bidirectional

对于第一个数据,需要定义初始的h_0和c_0,所以nn.lstm()的输入Inputs:input, (h_0, c_0),表示输入的数据以及h_0和c_0,这个可以自己定义,如果不定义,默认就是0

pytorch

第二步也是差不多的操作,只不多是另外两个权重加上不同的激活函数,一个使用的是sigmoid,一个使用的是tanh,得到的输出i_t和\tilde{C}_t都是(M, 1)。

pytorch

接着这个乘法是矩阵每个位置对应相乘,然后将两个矩阵加起来,得到的输出C_t是(M, 1)。

pytorch 最后一步得到的o_t也是(M, 1),然后C_t经过激活函数tanh,再和o_t每个位置相乘,得到的输出h_t也是(M, 1)。

最后得到的输出就是h_t和C_t,维数分别都是(M, 1),而输入x_t 维数都是(K, 1)。

lstm = nn.LSTM(10, 30, batch_first=True)

可以通过这样定义一个一层的LSTM输入是10,输出是30

lstm.weight_hh_l0.size()
lstm.weight_ih_l0.size()
lstm.bias_hh_l0.size()
lstm.bias__ih_l0.size()

可以分别得到权重的维数,注意之前我们定义的4个weights被整合到了一起,比如这个lstm,输入是10维,输出是30维,相对应的weight就是30x10,这样的权重有4个,然后pytorch将这4个组合在了一起,方便表示,也就是lstm.weight_ih_l0,所以它的维数就是120x10

我们定义一个输入

x = Variable(torch.randn((50, 100, 10)))
h0 = Variable(torch.randn(1, 50, 30))
c0 = Variable(torch.randn(1, 50 ,30))

x的三个数字分别表示batch_size为50,序列长度为100,每个数据维数为10

h0的第二个参数表示batch_size为50,输出维数为30,第一个参数取决于网络层数和是否是双向的,如果双向需要乘2,如果是多层,就需要乘以网络层数

c0的三个参数和h0是一致的

out, (h_out, c_out) = lstm(x, (h0, c0))

这样就可以得到网络的输出了,和上面讲的一致,另外如果不传入h0和c0,默认的会传入相同维数的0矩阵

这就是我们如何在pytorch上使用RNN的基本操作了,了解完最基本的参数我们才能够使用其来做应用。

在这里,我整理发布了Pytorch中文文档,方便大家查询使用,同时也准备了中文论坛,欢迎大家学习交流!

Pytorch中文文档

Pytorch中文论坛

Pytorch中文文档已经发布,完美翻译,更加方便大家浏览:

Pytorch中文网:https://ptorch.com/

Pytorch中文文档:https://ptorch.com/docs/1/

转载于:https://my.oschina.net/earnp/blog/1113894

相关文章:

  • 故障排查
  • 腾讯云服务器 安装监控组件
  • CRM系统客户形成需求和认知的五大因素
  • 【leetcode】55. Jump Game
  • node.js 学习(二)
  • 内华达州PUC特准3.2万光伏用户优惠太阳能补贴费率
  • 文件读,写,拷贝,删除
  • 神州数码网真解决方案助山西电力信息高速化
  • 大数据正在改变企业决策方式
  • Centos 7 配置tomcat服务器
  • 常用软件测试工具的分析
  • 让git更高效--文末有福利
  • 力争大数据及关联产业规模2020年达300亿元
  • python 操作asdl
  • 美国国防部推出微型RFID芯片,助力电子零件防伪
  • 《Javascript数据结构和算法》笔记-「字典和散列表」
  • 【108天】Java——《Head First Java》笔记(第1-4章)
  • JavaScript DOM 10 - 滚动
  • Javascript 原型链
  • Java编程基础24——递归练习
  • JAVA之继承和多态
  • Spring Cloud Alibaba迁移指南(一):一行代码从 Hystrix 迁移到 Sentinel
  • SQL 难点解决:记录的引用
  • ubuntu 下nginx安装 并支持https协议
  • 闭包--闭包作用之保存(一)
  • 规范化安全开发 KOA 手脚架
  • 精益 React 学习指南 (Lean React)- 1.5 React 与 DOM
  • 前端_面试
  • 前端性能优化--懒加载和预加载
  • 使用 @font-face
  • 吴恩达Deep Learning课程练习题参考答案——R语言版
  • 新版博客前端前瞻
  • ​ArcGIS Pro 如何批量删除字段
  • ​Distil-Whisper:比Whisper快6倍,体积小50%的语音识别模型
  • ​软考-高级-系统架构设计师教程(清华第2版)【第1章-绪论-思维导图】​
  • (1)SpringCloud 整合Python
  • (3)选择元素——(14)接触DOM元素(Accessing DOM elements)
  • (Redis使用系列) SpringBoot中Redis的RedisConfig 二
  • (附源码)springboot人体健康检测微信小程序 毕业设计 012142
  • (生成器)yield与(迭代器)generator
  • (一) storm的集群安装与配置
  • (转)ORM
  • (转)Spring4.2.5+Hibernate4.3.11+Struts1.3.8集成方案一
  • .bat批处理(九):替换带有等号=的字符串的子串
  • .md即markdown文件的基本常用编写语法
  • .NET Core/Framework 创建委托以大幅度提高反射调用的性能
  • .NET Framework 的 bug?try-catch-when 中如果 when 语句抛出异常,程序将彻底崩溃
  • .net mvc 获取url中controller和action
  • .net Signalr 使用笔记
  • .NET 中 GetHashCode 的哈希值有多大概率会相同(哈希碰撞)
  • .NetCore实践篇:分布式监控Zipkin持久化之殇
  • .Net各种迷惑命名解释
  • .net快速开发框架源码分享
  • @reference注解_Dubbo配置参考手册之dubbo:reference
  • @RequestParam @RequestBody @PathVariable 等参数绑定注解详解