当前位置: 首页 > news >正文

《动手学深度学习》(PyTorch版)代码注释 - 27 【Batch_normalization_with_zero】

目录

  • 说明
  • 配置环境
  • 此节说明
  • 代码

说明

本博客代码来自开源项目:《动手学深度学习》(PyTorch版)
并且在博主学习的理解上对代码进行了大量注释,方便理解各个函数的原理和用途

配置环境

使用环境:python3.8
平台:Windows10
IDE:PyCharm

此节说明

此节对应书本上5.10节
此节功能为:批量归一化从零实现
由于次节相对容易理解,代码注释量较少

代码

# 本书链接https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.8_mlp
# 5.10 批量归一化
# 注释:黄文俊
# E-mail:hurri_cane@qq.com


import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 判断当前模式是训练模式还是预测模式
    if not is_training:
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            # 样本均值
            var = ((X - mean) ** 2).mean(dim=0)
            # 样本均方差
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。这里我们需要保持
            # X的形状以便后面可以做广播运算
            mean = X.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
            var = ((X - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        # 训练模式下用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 拉伸和偏移
    return Y, moving_mean, moving_var


class BatchNorm(nn.Module):
    def __init__(self, num_features, num_dims):
        super(BatchNorm, self).__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 不参与求梯度和迭代的变量,全在内存上初始化成0
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.zeros(shape)

    def forward(self, X):
        # 如果X不在内存上,将moving_mean和moving_var复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var, Module实例的traning属性默认为true, 调用.eval()后设成false
        Y, self.moving_mean, self.moving_var = batch_norm(self.training,
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y


net = nn.Sequential(
            nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
            BatchNorm(6, num_dims=4),
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2), # kernel_size, stride
            nn.Conv2d(6, 16, 5),
            BatchNorm(16, num_dims=4),
            nn.Sigmoid(),
            nn.MaxPool2d(2, 2),
            d2l.FlattenLayer(),
            nn.Linear(16*4*4, 120),
            BatchNorm(120, num_dims=2),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            BatchNorm(84, num_dims=2),
            nn.Sigmoid(),
            nn.Linear(84, 10)
        )


# 对比未使用归一化的LeNet:
# net = nn.Sequential(
#             nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_size
#             BatchNorm(6, num_dims=4),
#             nn.Sigmoid(),
#             nn.MaxPool2d(2, 2), # kernel_size, stride
#             nn.Conv2d(6, 16, 5),
#             BatchNorm(16, num_dims=4),
#             nn.Sigmoid(),
#             nn.MaxPool2d(2, 2),
#             d2l.FlattenLayer(),
#             nn.Linear(16*4*4, 120),
#             BatchNorm(120, num_dims=2),
#             nn.Sigmoid(),
#             nn.Linear(120, 84),
#             BatchNorm(84, num_dims=2),
#             nn.Sigmoid(),
#             nn.Linear(84, 10)
#         )







batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

print(net[1].gamma.view((-1,)), net[1].beta.view((-1,)))

print("*"*50)

相关文章:

  • 基于.NET Compact Framework的应用程序和库汇总
  • 《动手学深度学习》(PyTorch版)代码注释 - 28 【Batch_normalization_with_simple_way】
  • 《动手学深度学习》(PyTorch版)代码注释 - 29 【ResNet_Learning】
  • WSS3.0 和 MOSS2007 对移动设备浏览器的支持
  • 《动手学深度学习》(PyTorch版)代码注释 - 30 【DenseNet_Learning】
  • WSS3.0 和 MOSS2007 对移动设备浏览器的支持(续)
  • 《动手学深度学习》(PyTorch版)代码注释 - 31 【Language_model_data_set】
  • Windows Mobile 6 SDK 中的 Device Emulator 2.0
  • 《动手学深度学习》(PyTorch版)代码注释 - 32 【RNN_with_zero】
  • Windows Mobile 6 SDK 中的 Cellular Emulator
  • Windows Mobile 6 SDK 中的 GPS 工具
  • 《动手学深度学习》(PyTorch版)代码注释 - 33 【RNN_with_simple_way】
  • 《动手学深度学习》(PyTorch版)代码注释 - 34 【GRU_with_zero】
  • MEDC2007北京游记 - WindowsMobile Ophone
  • 《动手学深度学习》(PyTorch版)代码注释 - 35 【GRU_with_simple_way】
  • [原]深入对比数据科学工具箱:Python和R 非结构化数据的结构化
  • cookie和session
  • co模块的前端实现
  • es6
  • ES6系列(二)变量的解构赋值
  • IIS 10 PHP CGI 设置 PHP_INI_SCAN_DIR
  • mongo索引构建
  • node入门
  • Python语法速览与机器学习开发环境搭建
  • Quartz初级教程
  • Twitter赢在开放,三年创造奇迹
  • Web Storage相关
  • 阿里云应用高可用服务公测发布
  • 分享自己折腾多时的一套 vue 组件 --we-vue
  • 高度不固定时垂直居中
  • 回顾2016
  • 开源地图数据可视化库——mapnik
  • 可能是历史上最全的CC0版权可以免费商用的图片网站
  • 聊聊flink的TableFactory
  • 前端工程化(Gulp、Webpack)-webpack
  • 前端设计模式
  • 小程序button引导用户授权
  • 优化 Vue 项目编译文件大小
  • gunicorn工作原理
  • Salesforce和SAP Netweaver里数据库表的元数据设计
  • ​一些不规范的GTID使用场景
  • #QT(串口助手-界面)
  • (6)STL算法之转换
  • (cljs/run-at (JSVM. :browser) 搭建刚好可用的开发环境!)
  • (NO.00004)iOS实现打砖块游戏(九):游戏中小球与反弹棒的碰撞
  • (附源码)spring boot儿童教育管理系统 毕业设计 281442
  • (附源码)spring boot建达集团公司平台 毕业设计 141538
  • (附源码)springboot 基于HTML5的个人网页的网站设计与实现 毕业设计 031623
  • (五) 一起学 Unix 环境高级编程 (APUE) 之 进程环境
  • (一)Mocha源码阅读: 项目结构及命令行启动
  • (原創) 如何讓IE7按第二次Ctrl + Tab時,回到原來的索引標籤? (Web) (IE) (OS) (Windows)...
  • (转)ObjectiveC 深浅拷贝学习
  • *** 2003
  • .NET Core Web APi类库如何内嵌运行?
  • .NET Core 中的路径问题