当前位置: 首页 > news >正文

pytorch笔记 GRUCELL

1 介绍

GRU的一个单元

2 基本使用方法

torch.nn.GRUCell(input_size, hidden_size, bias=True, device=None, dtype=None)

输入:(batch,input_size) 

输出和隐藏层:(batch,hidden_size)

3 举例

import torch.nn as nnrnn = nn.GRUCell(input_size=5,hidden_size=10)input_x = torch.randn(3, 5)
#batch,input_sizeh0 = torch.randn(3, 10)
#batch,hidden_sizeoutput= rnn(input_x, h0)
output.shape, output
'''
(torch.Size([3, 10]),tensor([[-0.4414,  1.0060,  0.3346, -0.2446, -0.4170, -0.6201, -1.0049,  0.1765,0.2238, -2.0249],[ 0.2764,  0.6327,  0.1682, -0.0433,  1.2226, -1.0959,  0.0345, -0.6375,-1.4599, -0.3670],[ 0.9447, -0.0849,  0.3983, -0.4078,  0.9805, -0.1826,  0.2151,  0.3382,-0.1147, -0.2307]], grad_fn=<AddBackward0>))
'''

4 和GRU的异同

功能性
  • GRU: 它是一个完整的循环层,可以处理整个序列的输入,并一次性返回整个序列的输出。
  • GRUCell: 它处理单个时间步长的输入,并返回单个时间步长的输出。它更为基础,通常在你想自定义循环过程时使用。
输入:
  • GRU: 期望的输入形状为 (seq_len, batch, input_size)(如果 batch_first=True,则为 (batch, seq_len, input_size))。
  • GRUCell: 期望的输入形状为 (batch, input_size)
输出:
  • GRU: 它返回两个输出 —— 整个序列的输出和最后一个时间步长的隐藏状态。输出的形状为 (seq_len, batch, hidden_size)(num_layers * num_directions, batch, hidden_size)
  • GRUCell: 它只返回下一个时间步长的隐藏状态,其形状为 (batch, hidden_size)
用法:
  • 使用 GRU 时,你可以一次性将整个序列传入,而不需要自己编写循环。
  • 使用 GRUCell 时,你需要手动编写循环,以一个时间步长为单位处理输入。
应用场景:
  • GRU: 当你想使用标准的循环过程处理整个序列时,通常使用GRU。
  • GRUCell: 当你想自定义循环过程或有特定的需求时使用,例如混合不同类型的RNN单元或在循环中执行特定操作。

5 一个GRU由几个GRUcell组成?

一个具有 seq_lenbidirectional=True 和指定的 num_layers 的 GRU 对应的 GRUCell 的数量为:

  1. seq_len:对于长度为 seq_len 的输入序列,GRU 在内部会进行 seq_len 次循环操作,每次循环处理序列中的一个时间步长。所以这部分会贡献 seq_len 个 GRUCell。

  2. bidirectional=True:当 GRU 是双向的,即 bidirectional=True,那么对于每一个时间步长,都会有两个 GRUCell 被调用:一个是正向的,另一个是反向的。因此,双向性将 GRUCell 的数量增加一倍。

  3. num_layers:这表示你要堆叠多少层的 GRU。每一层都会为每个时间步调用其自己的 GRUCell(考虑到双向性,这可能是两个)。所以如果你有 num_layers 层,那么你需要乘以这个数字。

综上所述,总的 GRUCell 的数量为: Total GRUCells=seq_len×(2 if bidirectional else 1)×num_layers

相关文章:

  • mediasoup-cluster横向扩容机制
  • mac flutter pb解析报错:protoc-gen-dart: program not found or is not executable
  • 蓝桥杯官网练习题(正则问题)
  • openGauss学习笔记-114 openGauss 数据库管理-设置安全策略-设置帐号有效期
  • gcc: __linux__
  • [SSD综述 1.4] SSD固态硬盘的架构和功能导论
  • Julia文件读写函数:write和read
  • 无mac电脑获取app的公钥的方法
  • IOC容器中的Bean是线程安全的吗?
  • 【jvm】虚拟机栈
  • 好物周刊#29:项目管理软件
  • vector类模拟实现(c++)(学习笔记)
  • 【C语言】【数据结构】【顺序表】
  • 二维码智慧门牌管理系统升级:一键报错解决三大问题
  • 电源管理(PMIC)MAX20428ATIA/VY、MAX20428ATIC/VY、MAX20428ATIE/VY适合汽车ADAS应用的开关稳压器
  • 【面试系列】之二:关于js原型
  • 2017-08-04 前端日报
  • 30天自制操作系统-2
  • es6(二):字符串的扩展
  • golang中接口赋值与方法集
  • input的行数自动增减
  • Intervention/image 图片处理扩展包的安装和使用
  • Javascript编码规范
  • js学习笔记
  • linux安装openssl、swoole等扩展的具体步骤
  • Making An Indicator With Pure CSS
  • python 装饰器(一)
  • RxJS: 简单入门
  • spring + angular 实现导出excel
  • ucore操作系统实验笔记 - 重新理解中断
  • ViewService——一种保证客户端与服务端同步的方法
  • 关于Android中设置闹钟的相对比较完善的解决方案
  • 基于axios的vue插件,让http请求更简单
  • 理解 C# 泛型接口中的协变与逆变(抗变)
  • 如何解决微信端直接跳WAP端
  • 使用API自动生成工具优化前端工作流
  • 使用Envoy 作Sidecar Proxy的微服务模式-4.Prometheus的指标收集
  • 手机app有了短信验证码还有没必要有图片验证码?
  • 突破自己的技术思维
  • 问:在指定的JSON数据中(最外层是数组)根据指定条件拿到匹配到的结果
  • 好程序员web前端教程分享CSS不同元素margin的计算 ...
  • # centos7下FFmpeg环境部署记录
  • ( 10 )MySQL中的外键
  • (分布式缓存)Redis分片集群
  • (附源码)node.js知识分享网站 毕业设计 202038
  • (详细版)Vary: Scaling up the Vision Vocabulary for Large Vision-Language Models
  • (转)大型网站架构演变和知识体系
  • .NET 3.0 Framework已经被添加到WindowUpdate
  • .Net Framework 4.x 程序到底运行在哪个 CLR 版本之上
  • .NET 设计模式初探
  • .Net6支持的操作系统版本(.net8已来,你还在用.netframework4.5吗)
  • .NET企业级应用架构设计系列之结尾篇
  • /bin/bash^M: bad interpreter: No such file ordirectory
  • ::
  • @EnableAsync和@Async开始异步任务支持