当前位置: 首页 > news >正文

《动手学深度学习》(PyTorch版)代码注释 - 34 【GRU_with_zero】

目录

  • 说明
  • 配置环境
  • 此节说明
  • 代码

说明

本博客代码来自开源项目:《动手学深度学习》(PyTorch版)
并且在博主学习的理解上对代码进行了大量注释,方便理解各个函数的原理和用途

配置环境

使用环境:python3.8
平台:Windows10
IDE:PyCharm

此节说明

此节对应书本上6.7节
此节功能为:门控循环单元(GRU)的从零实现
由于次节相对容易理解,代码注释量较少

代码

# 本书链接https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.8_mlp
# 6.7 门控循环单元(GRU)
# 注释:黄文俊
# E-mail:hurri_cane@qq.com


import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()


num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
print('will use', device)

def get_params():
    def _one(shape):
        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)
        return torch.nn.Parameter(ts, requires_grad=True)
    def _three():
        return (_one((num_inputs, num_hiddens)),
                _one((num_hiddens, num_hiddens)),
                torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))

    W_xz, W_hz, b_z = _three()  # 更新门参数
    W_xr, W_hr, b_r = _three()  # 重置门参数
    W_xh, W_hh, b_h = _three()  # 候选隐藏状态参数

    # 输出层参数
    W_hq = _one((num_hiddens, num_outputs))
    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)
    return nn.ParameterList([W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q])


def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device), )



# 重置门有助于捕捉时间序列里短期的依赖关系。
# 更新门有助于捕捉时间序列里长期的依赖关系。

# 据门控循环单元的计算表达式定义模型
def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid(torch.matmul(X, W_xz) + torch.matmul(H, W_hz) + b_z)
        R = torch.sigmoid(torch.matmul(X, W_xr) + torch.matmul(H, W_hr) + b_r)
        H_tilda = torch.tanh(torch.matmul(X, W_xh) + torch.matmul(R * H, W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = torch.matmul(H, W_hq) + b_q
        outputs.append(Y)
    return outputs, (H,)




num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']


d2l.train_and_predict_rnn(gru, get_params, init_gru_state, num_hiddens,
                          vocab_size, device, corpus_indices, idx_to_char,
                          char_to_idx, False, num_epochs, num_steps, lr,
                          clipping_theta, batch_size, pred_period, pred_len,
                          prefixes)





print("*"*50)

相关文章:

  • MEDC2007北京游记 - WindowsMobile Ophone
  • 《动手学深度学习》(PyTorch版)代码注释 - 35 【GRU_with_simple_way】
  • 祝贺CICI拿到VISA
  • 《动手学深度学习》(PyTorch版)代码注释 - 36 【LSTM_with_zero】
  • WPF/E去了,Silverlight来了
  • iPhone - 少一点自恋,多一点现实 !
  • 《动手学深度学习》(PyTorch版)代码注释 - 37 【LSTM_with_simple_way】
  • FerryMan Fractal的: 23312506
  • 《动手学深度学习》(PyTorch版)代码注释 - 38 【Gradient_descent_Learning】
  • 鲁迅先生
  • 《动手学深度学习》(PyTorch版)代码注释 - 39 【Small_batch_stochastic_gradient_descent】
  • 《动手学深度学习》(PyTorch版)代码注释 - 40 【Momentum_method】
  • 《动手学深度学习》(PyTorch版)代码注释 - 41 【AdaGrad_algorithm】
  • 《动手学深度学习》(PyTorch版)代码注释 - 42 【RMSProp_algorithm】
  • 《动手学深度学习》(PyTorch版)代码注释 - 43 【AdaDelta_algorithm】
  • js正则,这点儿就够用了
  • MySQL主从复制读写分离及奇怪的问题
  • Node项目之评分系统(二)- 数据库设计
  • node学习系列之简单文件上传
  • Spring Cloud中负载均衡器概览
  • TCP拥塞控制
  • Web标准制定过程
  • win10下安装mysql5.7
  • XForms - 更强大的Form
  • 聊聊flink的BlobWriter
  • 微信小程序设置上一页数据
  • 赢得Docker挑战最佳实践
  • 掌握面试——弹出框的实现(一道题中包含布局/js设计模式)
  • Oracle Portal 11g Diagnostics using Remote Diagnostic Agent (RDA) [ID 1059805.
  • const的用法,特别是用在函数前面与后面的区别
  • ​configparser --- 配置文件解析器​
  • ​渐进式Web应用PWA的未来
  • #基础#使用Jupyter进行Notebook的转换 .ipynb文件导出为.md文件
  • #我与Java虚拟机的故事#连载16:打开Java世界大门的钥匙
  • (补)B+树一些思想
  • (二)hibernate配置管理
  • (附源码)ssm高校实验室 毕业设计 800008
  • (简单有案例)前端实现主题切换、动态换肤的两种简单方式
  • (三)Pytorch快速搭建卷积神经网络模型实现手写数字识别(代码+详细注解)
  • (十一)JAVA springboot ssm b2b2c多用户商城系统源码:服务网关Zuul高级篇
  • (原创)Stanford Machine Learning (by Andrew NG) --- (week 9) Anomaly DetectionRecommender Systems...
  • (转)h264中avc和flv数据的解析
  • (转)Oracle 9i 数据库设计指引全集(1)
  • (轉貼) 寄發紅帖基本原則(教育部禮儀司頒布) (雜項)
  • .NET DevOps 接入指南 | 1. GitLab 安装
  • .Net Memory Profiler的使用举例
  • .NET Standard 的管理策略
  • .net 桌面开发 运行一阵子就自动关闭_聊城旋转门家用价格大约是多少,全自动旋转门,期待合作...
  • .NET/C# 推荐一个我设计的缓存类型(适合缓存反射等耗性能的操作,附用法)
  • .NET/C# 在代码中测量代码执行耗时的建议(比较系统性能计数器和系统时间)
  • .Net环境下的缓存技术介绍
  • .NET简谈互操作(五:基础知识之Dynamic平台调用)
  • .Net下的签名与混淆
  • .net中调用windows performance记录性能信息
  • ::