当前位置: 首页 > news >正文

nn.LSTM个人记录

简介

 

nn.LSTM参数

torch.nn.lstm(input_size,   "输入的嵌入向量维度,例如每个单词用50维向量表示,input_size就是50"hidden_size,  "隐藏层节点数量,也是输出的嵌入向量维度"num_layers,   "lstm 隐层的层数,默认为1"bias,         "隐层是否带 bias,默认为 true"batch_first,  "True 或者 False,如果是 True,则 input 为(batchsize, len, input_size),默认值为:False(len, batchsize, input_size)"dropout,      "除最后一层,每一层的输出都进行dropout,默认值0"bidirectional "如果设置为 True, 则表示双向 LSTM,默认为 False")

维度

batch_first=True,输入维度(batchsize,len,input_size)

batch_first=False,输入维度(len,batchsize, input_size)

batch_first=False,输出维度(len,batchsize,hidden_size)

举例嵌入向量维度为1

 假如输入x为(batchsize,len)的序列,即嵌入向量维度为1,进行一个回归预测。

如果将嵌入向量维度维度设为1就不太合理,因为如果len非常长例如几w,那么经过几w的时间步得到的得到的h维度为(batchsize,1),序列太长丢失很多信息,再输入全连接层预测效果不好。并且lstm实际上将嵌入向量维度从input_size规约到hidden_size。

所以在这里我们将len作为input_size,嵌入向量维度1作为len(即对调了一下)

添加一个维度:

x = x.unsqueeze(0)

x维度变为(1,batchsize,len),相当于设置数据的长度为1,嵌入向量维度为len,通过nn.LSTM输入到网络中。

#lstm为定义的网络
#h[-1]为最后输入到全连接层的嵌入矩阵 但是由于此问题中len为1,所以x等于h[-1]
x, (h, c) = lstm(x)

x维度变为(1,batchsize,hidden_size)

h为每层lstm最后一个时间步的输出一般可以输入到后续的全连接层),维度为(num_layers,batchsize,hidden_size)

c为最后一个时间步 LSTM cell 的状态(记忆单元,一般用不到),维度为(num_layers,batchsize,hidden_size)

移除张量中所有尺寸为 1 的维度,即将第一个维度移除掉:

lstm_out = x.squeeze(0)

x维度变为(batchsize,hidden_size) ,输入到全连接层(线性层,维度(hidden_size,num_class))中,最终输出维度(batchsize,num_class)

参考:

Pytorch — LSTM (nn.LSTM & nn.LSTMCell)-CSDN博客

相关文章:

  • Mybatis 日志
  • 【Bootstrap学习 day1】
  • 【交叉编译环境】安装arm-linux交叉编译环境到虚拟机教程(简洁版本)
  • 实现文字超过显示宽度每间隔1s自动向左滚动显示(原生JS和vue两种实现方式)
  • SLF4J: Class path contains multiple SLF4J bindings.解决
  • SpringBoot整合Mybatis遇到的常见问题及解决方案
  • 【Midjourney】Midjourney根据prompt提示词生成人物图片
  • 【Linux】修复 Linux 错误 - 文件过大
  • java freemarker 动态生成excel文件
  • 【leetcode150】逆波兰表达式求值Java代码讲解
  • vue大屏-列表自动滚动vue-seamless-scroll
  • mysql二进制对应ef中实体表字段类型
  • git 学习 之一个规范的 commit 如何写
  • 构建创新学习体验:企业培训系统技术深度解析
  • 【Java EE初阶四】锁及synchronized关键字
  • Akka系列(七):Actor持久化之Akka persistence
  • - C#编程大幅提高OUTLOOK的邮件搜索能力!
  • Consul Config 使用Git做版本控制的实现
  • Cumulo 的 ClojureScript 模块已经成型
  • JavaScript函数式编程(一)
  • JDK9: 集成 Jshell 和 Maven 项目.
  • leetcode-27. Remove Element
  • mongodb--安装和初步使用教程
  • react-native 安卓真机环境搭建
  • SAP云平台运行环境Cloud Foundry和Neo的区别
  • Stream流与Lambda表达式(三) 静态工厂类Collectors
  • 半理解系列--Promise的进化史
  • 道格拉斯-普克 抽稀算法 附javascript实现
  • 浮现式设计
  • 函数式编程与面向对象编程[4]:Scala的类型关联Type Alias
  • 猴子数据域名防封接口降低小说被封的风险
  • 基于Dubbo+ZooKeeper的分布式服务的实现
  • 基于遗传算法的优化问题求解
  • 极限编程 (Extreme Programming) - 发布计划 (Release Planning)
  • 聊聊redis的数据结构的应用
  • 原创:新手布局福音!微信小程序使用flex的一些基础样式属性(一)
  • ​ ​Redis(五)主从复制:主从模式介绍、配置、拓扑(一主一从结构、一主多从结构、树形主从结构)、原理(复制过程、​​​​​​​数据同步psync)、总结
  • (12)Linux 常见的三种进程状态
  • (HAL库版)freeRTOS移植STMF103
  • (Redis使用系列) Springboot 使用redis实现接口Api限流 十
  • (ResultSet.TYPE_SCROLL_INSENSITIVE,ResultSet.CONCUR_READ_ONLY)讲解
  • (附源码)php投票系统 毕业设计 121500
  • (附源码)ssm基于web技术的医务志愿者管理系统 毕业设计 100910
  • (附源码)ssm学生管理系统 毕业设计 141543
  • (转)Mysql的优化设置
  • (转)母版页和相对路径
  • .net 4.0 A potentially dangerous Request.Form value was detected from the client 的解决方案
  • .net websocket 获取http登录的用户_如何解密浏览器的登录密码?获取浏览器内用户信息?...
  • .net 无限分类
  • .net之微信企业号开发(一) 所使用的环境与工具以及准备工作
  • @EventListener注解使用说明
  • @param注解什么意思_9000字,通俗易懂的讲解下Java注解
  • @Transaction注解失效的几种场景(附有示例代码)
  • [AIGC] Kong:一个强大的 API 网关和服务平台
  • [BUUCTF NewStarCTF 2023 公开赛道] week3 crypto/pwn