当前位置: 首页 > news >正文

《动手学深度学习》(PyTorch版)代码注释 - 39 【Small_batch_stochastic_gradient_descent】

目录

  • 说明
  • 配置环境
  • 此节说明
  • 代码

说明

本博客代码来自开源项目:《动手学深度学习》(PyTorch版)
并且在博主学习的理解上对代码进行了大量注释,方便理解各个函数的原理和用途

配置环境

使用环境:python3.8
平台:Windows10
IDE:PyCharm

此节说明

此节对应书本上7.3节
此节功能为:小批量随机梯度下降
由于此节相对复杂,代码注释量较多

代码

# 本书链接https://tangshusen.me/Dive-into-DL-PyTorch/#/
# 7.3 小批量随机梯度下降
# 注释:黄文俊
# E-mail:hurri_cane@qq.com

from matplotlib import pyplot as plt
import numpy as np
import time
import torch
from torch import nn, optim
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l

def get_data_ch7():  # 本函数已保存在d2lzh_pytorch包中方便以后使用
    data = np.genfromtxt('F:/PyCharm/Learning_pytorch/data/airfoil_self_noise.dat', delimiter='\t')
    data = (data - data.mean(axis=0)) / data.std(axis=0)
    return torch.tensor(data[:1500, :-1], dtype=torch.float32), \
    torch.tensor(data[:1500, -1], dtype=torch.float32) # 前1500个样本(每个样本5个特征)

features, labels = get_data_ch7()
print(features.shape)

def sgd(params, states, hyperparams):
    for p in params:
        p.data -= hyperparams['lr'] * p.grad.data

# 训练函数
# 本函数已保存在d2lzh_pytorch包中方便以后使用
def train_ch7(optimizer_fn, states, hyperparams, features, labels,
              batch_size=10, num_epochs=2):
    # 初始化模型
    net, loss = d2l.linreg, d2l.squared_loss

    w = torch.nn.Parameter(torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)), dtype=torch.float32),
                           requires_grad=True)
    b = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)

    def eval_loss():
        return loss(net(features, w, b), labels).mean().item()

    ls = [eval_loss()]
    data_iter = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)

    for _ in range(num_epochs):
        start = time.time()
        for batch_i, (X, y) in enumerate(data_iter):
            l = loss(net(X, w, b), y).mean()  # 使用平均损失

            # 梯度清零
            if w.grad is not None:
                w.grad.data.zero_()
                b.grad.data.zero_()

            l.backward()
            optimizer_fn([w, b], states, hyperparams)  # 迭代模型参数
            if (batch_i + 1) * batch_size % 100 == 0:
                ls.append(eval_loss())  # 每100个样本记录下当前训练误差
    # 打印结果和作图
    print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))
    d2l.set_figsize()
    d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)
    d2l.plt.xlabel('epoch')
    d2l.plt.ylabel('loss')
    plt.show()

def train_sgd(lr, batch_size, num_epochs=2):
    train_ch7(sgd, None, {'lr': lr}, features, labels, batch_size, num_epochs)

# 小批量数设置为样本总数1500,迭代次数为6
train_sgd(1, 1500, 6)

# 小批量数设置为1,迭代次数为默认的2
# 虽然迭代次数只有2,但是记录误差的次数比上面小批量数设置为样本总数1500,迭代次数为6多
# 因为误差录取采用了过100个样本记一次的方法
# if (batch_i + 1) * batch_size % 100 == 0:
#     ls.append(eval_loss())  # 每100个样本记录下当前训练误差
train_sgd(0.005, 1)



# 简洁实现
# 本函数与原书不同的是这里第一个参数优化器函数而不是优化器的名字
# 例如: optimizer_fn=torch.optim.SGD, optimizer_hyperparams={"lr": 0.05}
def train_pytorch_ch7(optimizer_fn, optimizer_hyperparams, features, labels,
                    batch_size=10, num_epochs=2):
    # 初始化模型
    net = nn.Sequential(
        nn.Linear(features.shape[-1], 1)
    )
    loss = nn.MSELoss()
    optimizer = optimizer_fn(net.parameters(), **optimizer_hyperparams)

    def eval_loss():
        return loss(net(features).view(-1), labels).item() / 2

    ls = [eval_loss()]
    data_iter = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)

    for _ in range(num_epochs):
        start = time.time()
        for batch_i, (X, y) in enumerate(data_iter):
            # 除以2是为了和train_ch7保持一致, 因为squared_loss中除了2
            l = loss(net(X).view(-1), y) / 2

            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            if (batch_i + 1) * batch_size % 100 == 0:
                ls.append(eval_loss())
    # 打印结果和作图
    print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))
    d2l.set_figsize()
    d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)
    d2l.plt.xlabel('epoch')
    d2l.plt.ylabel('loss')
    plt.show()


train_pytorch_ch7(optim.SGD, {"lr": 0.05}, features, labels, 10)


print("*"*50)

相关文章:

  • 《动手学深度学习》(PyTorch版)代码注释 - 40 【Momentum_method】
  • 《动手学深度学习》(PyTorch版)代码注释 - 41 【AdaGrad_algorithm】
  • 《动手学深度学习》(PyTorch版)代码注释 - 42 【RMSProp_algorithm】
  • 《动手学深度学习》(PyTorch版)代码注释 - 43 【AdaDelta_algorithm】
  • OpenGL Fractal Hill
  • 《动手学深度学习》(PyTorch版)代码注释 - 44 【Adam_algorithm】
  • Cg Browser下载
  • 《动手学深度学习》(PyTorch版)代码注释 - 45 【Image_augmentation】
  • 接天莲叶无穷碧,映日荷花别样红
  • 《动手学深度学习》(PyTorch版)代码注释 - 46 【Transfer_learning】
  • Java疑惑点解析(一)
  • 《动手学深度学习》(PyTorch版)代码注释 - 47 【Image_augmentation】
  • 《动手学深度学习》(PyTorch版)代码注释 - 48 【Multi-scale_target_detection】
  • Mathematica分形源码
  • 《动手学深度学习》(PyTorch版)代码注释 - 49 【Target_detection_data-set (Pikachu)】
  • android 一些 utils
  • JS笔记四:作用域、变量(函数)提升
  • JS创建对象模式及其对象原型链探究(一):Object模式
  • JS专题之继承
  • MQ框架的比较
  • Redis在Web项目中的应用与实践
  • SQLServer之创建数据库快照
  • 初识 webpack
  • 代理模式
  • 前端js -- this指向总结。
  • 一些css基础学习笔记
  • 怎么将电脑中的声音录制成WAV格式
  • ​​​​​​​​​​​​​​汽车网络信息安全分析方法论
  • ​TypeScript都不会用,也敢说会前端?
  • # C++之functional库用法整理
  • $.ajax,axios,fetch三种ajax请求的区别
  • (30)数组元素和与数字和的绝对差
  • (day6) 319. 灯泡开关
  • (附源码)spring boot球鞋文化交流论坛 毕业设计 141436
  • (附源码)springboot社区居家养老互助服务管理平台 毕业设计 062027
  • (附源码)ssm基于jsp的在线点餐系统 毕业设计 111016
  • (简单有案例)前端实现主题切换、动态换肤的两种简单方式
  • (力扣记录)1448. 统计二叉树中好节点的数目
  • (万字长文)Spring的核心知识尽揽其中
  • (一)认识微服务
  • (原創) 如何動態建立二維陣列(多維陣列)? (.NET) (C#)
  • (转)JAVA中的堆栈
  • ./include/caffe/util/cudnn.hpp: In function ‘const char* cudnnGetErrorString(cudnnStatus_t)’: ./incl
  • .net core webapi 大文件上传到wwwroot文件夹
  • .NET Core 控制台程序读 appsettings.json 、注依赖、配日志、设 IOptions
  • .NET6 开发一个检查某些状态持续多长时间的类
  • @serverendpoint注解_SpringBoot 使用WebSocket打造在线聊天室(基于注解)
  • [ Linux Audio 篇 ] 音频开发入门基础知识
  • [Angular] 笔记 20:NgContent
  • [ArcPy百科]第三节: Geometry信息中的空间参考解析
  • [BZOJ]4817: [Sdoi2017]树点涂色
  • [C++]类和对象(中)
  • [CTO札记]盛大文学公司名称对联
  • [ffmpeg] x264 配置参数解析
  • [flink总结]什么是flink背压 ,有什么危害? 如何解决flink背压?flink如何保证端到端一致性?