当前位置: 首页 > news >正文

《动手学深度学习》(PyTorch版)代码注释 - 33 【RNN_with_simple_way】

目录

  • 说明
  • 配置环境
  • 此节说明
  • 代码

说明

本博客代码来自开源项目:《动手学深度学习》(PyTorch版)
并且在博主学习的理解上对代码进行了大量注释,方便理解各个函数的原理和用途

配置环境

使用环境:python3.8
平台:Windows10
IDE:PyCharm

此节说明

此节对应书本上6.5节
此节功能为:循环神经网络的简洁实现
由于次节相对容易理解,代码注释量较少

代码

# 本书链接https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.8_mlp
# 6.5 循环神经网络的简洁实现
# 注释:黄文俊
# E-mail:hurri_cane@qq.com

import time
import math
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()

num_hiddens = 256
# rnn_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens) # 已测试
rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens)


num_steps = 35
batch_size = 2
state = None
X = torch.rand(num_steps, batch_size, vocab_size)
Y, state_new = rnn_layer(X, state)
print(Y.shape, len(state_new), state_new[0].shape)


# 本类已保存在d2lzh_pytorch包中方便以后使用
class RNNModel(nn.Module):
    def __init__(self, rnn_layer, vocab_size):
        super(RNNModel, self).__init__()
        self.rnn = rnn_layer
        self.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1)
        self.vocab_size = vocab_size
        self.dense = nn.Linear(self.hidden_size, vocab_size)
        self.state = None

    def forward(self, inputs, state): # inputs: (batch, seq_len)
        # 获取one-hot向量表示
        X = d2l.to_onehot(inputs, self.vocab_size) # X是个list
        # a = torch.stack(X)
        Y, self.state = self.rnn(torch.stack(X), state)
        # 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出
        # 形状为(num_steps * batch_size, vocab_size)
        output = self.dense(Y.view(-1, Y.shape[-1]))
        return output, self.state



# 预测函数
# 本函数已保存在d2lzh_pytorch包中方便以后使用
def predict_rnn_pytorch(prefix, num_chars, model, vocab_size, device, idx_to_char,
                      char_to_idx):
    state = None
    output = [char_to_idx[prefix[0]]] # output会记录prefix加上输出
    for t in range(num_chars + len(prefix) - 1):
        X = torch.tensor([output[-1]], device=device).view(1, 1)
        if state is not None:
            if isinstance(state, tuple): # LSTM, state:(h, c)
                state = (state[0].to(device), state[1].to(device))
            else:
                state = state.to(device)

        (Y, state) = model(X, state)
        if t < len(prefix) - 1:
            output.append(char_to_idx[prefix[t + 1]])
        else:
            output.append(int(Y.argmax(dim=1).item()))
    return ''.join([idx_to_char[i] for i in output])



model = RNNModel(rnn_layer, vocab_size).to(device)
predict_res  = predict_rnn_pytorch('分开', 10, model, vocab_size, device, idx_to_char, char_to_idx)

print(predict_res)
print("*"*50)


# 实现训练函数
# 本函数已保存在d2lzh_pytorch包中方便以后使用
def train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
                                corpus_indices, idx_to_char, char_to_idx,
                                num_epochs, num_steps, lr, clipping_theta,
                                batch_size, pred_period, pred_len, prefixes):
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(device)
    state = None
    for epoch in range(num_epochs):
        l_sum, n, start = 0.0, 0, time.time()
        data_iter = d2l.data_iter_consecutive(corpus_indices, batch_size, num_steps, device) # 相邻采样
        for X, Y in data_iter:
            if state is not None:
                # 使用detach函数从计算图分离隐藏状态, 这是为了
                # 使模型参数的梯度计算只依赖一次迭代读取的小批量序列(防止梯度计算开销太大)
                if isinstance (state, tuple): # LSTM, state:(h, c)
                    state = (state[0].detach(), state[1].detach())
                else:
                    state = state.detach()

            (output, state) = model(X, state) # output: 形状为(num_steps * batch_size, vocab_size)

            # Y的形状是(batch_size, num_steps),转置后再变成长度为
            # batch * num_steps 的向量,这样跟输出的行一一对应
            y = torch.transpose(Y, 0, 1).contiguous().view(-1)
            l = loss(output, y.long())

            optimizer.zero_grad()
            l.backward()
            # 梯度裁剪
            d2l.grad_clipping(model.parameters(), clipping_theta, device)
            optimizer.step()
            l_sum += l.item() * y.shape[0]
            n += y.shape[0]

        try:
            perplexity = math.exp(l_sum / n)
        except OverflowError:
            perplexity = float('inf')
        if (epoch + 1) % pred_period == 0:
            print('epoch %d, perplexity %f, time %.2f sec' % (
                epoch + 1, perplexity, time.time() - start))
            for prefix in prefixes:
                print(' -', predict_rnn_pytorch(
                    prefix, pred_len, model, vocab_size, device, idx_to_char,
                    char_to_idx))




num_epochs, batch_size, lr, clipping_theta = 250, 32, 1e-3, 1e-2 # 注意这里的学习率设置
pred_period, pred_len, prefixes = 50, 50, ['分开', '不分开']
train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
                            corpus_indices, idx_to_char, char_to_idx,
                            num_epochs, num_steps, lr, clipping_theta,
                            batch_size, pred_period, pred_len, prefixes)







print("*"*50)

相关文章:

  • 《动手学深度学习》(PyTorch版)代码注释 - 34 【GRU_with_zero】
  • MEDC2007北京游记 - WindowsMobile Ophone
  • 《动手学深度学习》(PyTorch版)代码注释 - 35 【GRU_with_simple_way】
  • 祝贺CICI拿到VISA
  • 《动手学深度学习》(PyTorch版)代码注释 - 36 【LSTM_with_zero】
  • WPF/E去了,Silverlight来了
  • iPhone - 少一点自恋,多一点现实 !
  • 《动手学深度学习》(PyTorch版)代码注释 - 37 【LSTM_with_simple_way】
  • FerryMan Fractal的: 23312506
  • 《动手学深度学习》(PyTorch版)代码注释 - 38 【Gradient_descent_Learning】
  • 鲁迅先生
  • 《动手学深度学习》(PyTorch版)代码注释 - 39 【Small_batch_stochastic_gradient_descent】
  • 《动手学深度学习》(PyTorch版)代码注释 - 40 【Momentum_method】
  • 《动手学深度学习》(PyTorch版)代码注释 - 41 【AdaGrad_algorithm】
  • 《动手学深度学习》(PyTorch版)代码注释 - 42 【RMSProp_algorithm】
  • 【知识碎片】第三方登录弹窗效果
  • iOS 颜色设置看我就够了
  • Java反射-动态类加载和重新加载
  • Object.assign方法不能实现深复制
  • PAT A1050
  • React-redux的原理以及使用
  • weex踩坑之旅第一弹 ~ 搭建具有入口文件的weex脚手架
  • 初识 beanstalkd
  • 高程读书笔记 第六章 面向对象程序设计
  • 扑朔迷离的属性和特性【彻底弄清】
  • 浅谈Golang中select的用法
  • 项目管理碎碎念系列之一:干系人管理
  • 一份游戏开发学习路线
  • nb
  • !!【OpenCV学习】计算两幅图像的重叠区域
  • #、%和$符号在OGNL表达式中经常出现
  • #if和#ifdef区别
  • #Linux(帮助手册)
  • ()、[]、{}、(())、[[]]命令替换
  • (floyd+补集) poj 3275
  • (轉貼) UML中文FAQ (OO) (UML)
  • ****Linux下Mysql的安装和配置
  • .NET Entity FrameWork 总结 ,在项目中用处个人感觉不大。适合初级用用,不涉及到与数据库通信。
  • .net解析传过来的xml_DOM4J解析XML文件
  • @Bean, @Component, @Configuration简析
  • @Builder用法
  • [【JSON2WEB】 13 基于REST2SQL 和 Amis 的 SQL 查询分析器
  • [14]内置对象
  • [AIGC] MySQL存储引擎详解
  • [BUUCTF 2018]Online Tool(特详解)
  • [BZOJ3757] 苹果树
  • [CC2642r1] ble5 stacks 蓝牙协议栈 介绍和理解
  • [ffmpeg] 定制滤波器
  • [Geek Challenge 2023] web题解
  • [HTML]Web前端开发技术6(HTML5、CSS3、JavaScript )DIV与SPAN,盒模型,Overflow——喵喵画网页
  • [java/jdbc]插入数据时获取自增长主键的值
  • [Jquery] 实现鼠标移到某个对象,在旁边显示层。
  • [LeetCode 127] - 单词梯(Word Ladder)
  • [P4V]Perforce(P4V)使用教程
  • [POI2006] OKR-Periods of Words——最大周期长度(扩展最小周期长度)