当前位置: 首页 > news >正文

pytorch中对nn.BatchNorm2d()函数的理解

pytorch中对BatchNorm2d函数的理解

  • 简介
  • 计算
  • 3. Pytorch的nn.BatchNorm2d()函数
  • 4 代码示例

简介

机器学习中,进行模型训练之前,需对数据做归一化处理,使其分布一致。在深度神经网络训练过程中,通常一次训练是一个batch,而非全体数据。每个batch具有不同的分布产生了internal covarivate shift问题——在训练过程中,数据分布会发生变化,对下一层网络的学习带来困难。Batch Normalization强行将数据拉回到均值为0,方差为1的正太分布上,一方面使得数据分布一致,另一方面避免梯度消失。

计算

如图所示:
在这里插入图片描述
在这里插入图片描述

3. Pytorch的nn.BatchNorm2d()函数

其主要需要输入4个参数:
(1)num_features:输入数据的shape一般为[batch_size, channel, height, width], num_features为其中的channel;
(2)eps: 分母中添加的一个值,目的是为了计算的稳定性,默认:1e-5;
(3)momentum: 一个用于运行过程中均值和方差的一个估计参数,默认值为0.1.
在这里插入图片描述
(4)affine:当设为true时,给定可以学习的系数矩阵 γ \gamma γ β \beta β

4 代码示例

import torchdata = torch.ones(size=(2, 2, 3, 4))
data[0][0][0][0] = 25
print("data = ", data)print("\n")print("=========================使用封装的BatchNorm2d()计算================================")
BN = torch.nn.BatchNorm2d(num_features=2, eps=0, momentum=0)
BN_data = BN(data)
print("BN_data = ", BN_data)print("\n")print("=========================自行计算================================")
x = torch.cat((data[0][0], data[1][0]), dim=1)      # 1.将同一通道进行拼接(即把同一通道当作一个整体)
x_mean = torch.Tensor.mean(x)                       # 2.计算同一通道所有制的均值(即拼接后的均值)
x_var = torch.Tensor.var(x, False)                  # 3.计算同一通道所有制的方差(即拼接后的方差)# 4.使用第一个数按照公式来求BatchNorm后的值
bn_first = ((data[0][0][0][0] - x_mean) / ( torch.pow(x_var, 0.5))) * BN.weight[0] + BN.bias[0]
print("bn_first = ", bn_first)

在这里插入图片描述
在这里插入图片描述

相关文章:

  • 时序预测 | MATLAB实现WOA-CNN-GRU-Attention时间序列预测(SE注意力机制)
  • 部署ruoyi-vue-plus和ruoyi-app
  • Spring底层原理学习笔记--第五讲--(常见工厂后处理器与工厂后处理器模拟实现)
  • Sass 最基础的语法
  • Maven-依赖管理机制
  • 【大数据分布并行处理】单元测试(三)
  • CMOS介绍
  • [HXPCTF 2021]includer‘s revenge
  • MYSQL---基础篇
  • 4.HTML网页开发的工具
  • Clickhouse学习笔记(11)—— 数据一致性
  • ELK分布式日志
  • TypeScript: 判断两个数组的内容是否相等
  • 解决游戏找不到x3daudio1_7.dll文件的5个方法,快速修复dll问题
  • Ubuntu 20.04编译Chrome浏览器
  • [js高手之路]搞清楚面向对象,必须要理解对象在创建过程中的内存表示
  • “寒冬”下的金三银四跳槽季来了,帮你客观分析一下局面
  • Android 控件背景颜色处理
  • CentOS从零开始部署Nodejs项目
  • ECMAScript6(0):ES6简明参考手册
  • Linux Process Manage
  • mysql常用命令汇总
  • php面试题 汇集2
  • Python十分钟制作属于你自己的个性logo
  • Work@Alibaba 阿里巴巴的企业应用构建之路
  • 安卓应用性能调试和优化经验分享
  • 开发了一款写作软件(OSX,Windows),附带Electron开发指南
  • 那些年我们用过的显示性能指标
  • 微信开源mars源码分析1—上层samples分析
  • const的用法,特别是用在函数前面与后面的区别
  • ionic异常记录
  • ###C语言程序设计-----C语言学习(3)#
  • #define MODIFY_REG(REG, CLEARMASK, SETMASK)
  • #HarmonyOS:软件安装window和mac预览Hello World
  • #NOIP 2014# day.1 T3 飞扬的小鸟 bird
  • (9)STL算法之逆转旋转
  • (function(){})()的分步解析
  • (MATLAB)第五章-矩阵运算
  • (MIT博士)林达华老师-概率模型与计算机视觉”
  • (图)IntelliTrace Tools 跟踪云端程序
  • (原+转)Ubuntu16.04软件中心闪退及wifi消失
  • (转)http协议
  • (转载)Google Chrome调试JS
  • .libPaths()设置包加载目录
  • .net 4.0发布后不能正常显示图片问题
  • .NET Standard 支持的 .NET Framework 和 .NET Core
  • .NET/C# 检测电脑上安装的 .NET Framework 的版本
  • .Net的DataSet直接与SQL2005交互
  • .NET中 MVC 工厂模式浅析
  • .Net中的设计模式——Factory Method模式
  • /bin/rm: 参数列表过长"的解决办法
  • @Autowired和@Resource装配
  • @DependsOn:解析 Spring 中的依赖关系之艺术
  • [ vulhub漏洞复现篇 ] Apache Flink目录遍历(CVE-2020-17519)
  • [Android] Implementation vs API dependency