当前位置: 首页 > news >正文

30分钟吃掉pytorch中的各种归一化层

一,归一化层概述

归一化技术对于训练深度神经网络非常重要。

它们的主要作用是让模型的中间层的输入分布稳定在合适的范围,加快模型训练过程的收敛速度,并提升模型对输入变动的抗干扰能力。

各种归一化层使用的公式都是一样的,如下所示:

其中的 和 是可学习的参数。注意到,当 恰好取标准差,恰好取均值时,归一化层刚好是一个恒等变换。这就能够保证归一化层在最坏的情况下,可学习为一个恒等变换,不会给模型带来负面影响。

本文节选自 eat pytorch in 20 days 的 《5-2,模型层》前半部分。公众号后台回复关键词:pytorch,获取本文全部源代码和吃货本货BiliBili视频讲解哦🍉🍉

二,BatchNorm和LayerNorm的差别?

pytorch中内置的归一化层包括 nn.BatchNorm2d(1d), nn.LayerNorm, nn.GroupNorm, nn.InstanceNorm2d 等等。

其中最常用的是BatchNorm2d(1d)和LayerNorm。

不同的归一化层的差异主要是计算均值和方差时候参与计算的数据不一样。

BatchNorm是在样本维度进行归一化(一个批次内不同样本的相同特征计算均值和方差),而LayerNorm是在特征维度进行归一化(同一个样本的不同特征计算均值和方差)。

BatchNorm常用于结构化数据(BatchNorm1D)和图片数据(BatchNorm2D),LayerNorm常用于文本数据。

import torch 
from torch import nn 

batch_size, channel, height, width = 32, 16, 128, 128

tensor = torch.arange(0,32*16*128*128).view(32,16,128,128).float() 

bn = nn.BatchNorm2d(num_features=channel,affine=False)
bn_out = bn(tensor)


channel_mean = torch.mean(bn_out[:,0,:,:]) 
channel_std = torch.std(bn_out[:,0,:,:])
print("channel mean:",channel_mean.item())
print("channel std:",channel_std.item())
channel mean: 1.1920928955078125e-07
channel std: 1.0000009536743164
import torch 
from torch import nn 

batch_size, sequence, features = 32, 100, 2048
tensor = torch.arange(0,32*100*2048).view(32,100,2048).float() 

ln = nn.LayerNorm(normalized_shape=[features],
                  elementwise_affine = False)

ln_out = ln(tensor)

token_mean = torch.mean(ln_out[0,0,:]) 
token_std = torch.std(ln_out[0,0,:])
print("token_mean:",token_mean.item())
print("token_mean:",token_std.item())
token_mean: -5.8673322200775146e-08
token_mean: 1.0002442598342896

三,为什么不同类型的数据要使用不同的归一化层?

  • 结构化数据通常使用BatchNorm1D归一化 【结构化数据的主要区分度来自每个样本特征在全体样本中的排序,将全部样本的某个特征都进行相同的放大缩小平移操作,样本间的区分度基本保持不变,所以结构化数据可以做BatchNorm,但LayerNorm会打乱全体样本根据某个特征的排序关系,引起区分度下降】

29fbfdba82fab667c83e268c7006dbe0.jpeg
  • 图片数据最常用的是BatchNorm2D,有些场景也会用LayerNorm,GroupNorm或者InstanceNorm【图片数据的主要区分度来自图片中的纹理结构,所以图片数据的归一化一定要在图片的宽高方向上操作以保持纹理结构,此外在Batch维度上操作还能够引入少许的正则化,对提升精度有进一步的帮助。】

2e766719f19fbfbed2e54532039a9cd1.jpeg
  • 文本数据一般都使用LayerNorm归一化 【文本数据的主要区分度来自于词向量(Embedding向量)的方向,所以文本数据的归一化一定要在 特征(通道)维度上操作 以保持 词向量方向不变。此外文本数据还有一个重要的特点是不同样本的序列长度往往不一样,所以不可以在Sequence和Batch维度上做归一化,否则将不可避免地将padding位置对应的向量和普通的词向混合起来进行归一,这会让变成padding对应的向量变成非零向量,从而对梯度产生不合预期的影响。即使做特殊处理让padding位置的向量不参与归一化保持为0值,由于样本间序列长度的差异,也会造成参与运算的归一的数据量在不同样本和批次间剧烈波动,不利于模型的稳定学习。】

fde868160d27d7cce84eac92b71c6bde.jpeg
  • 有论文提出了一种可自适应学习的归一化:SwitchableNorm,可应用于各种场景且有一定的效果提升。【SwitchableNorm是将BN、LN、IN结合,赋予权重,让网络自己去学习归一化层应该使用什么方法。参考论文:https://arxiv.org/pdf/1806.10779.pdf】

四,BatchNorm补充问题

(1)BatchNorm放在激活函数前还是激活函数后?

原始论文认为将BatchNorm放在激活函数前效果较好,后面的研究一般认为将BatchNorm放在激活函数之后更好。

(2)BatchNorm在训练过程和推理过程的逻辑是否一样?

不一样!训练过程BatchNorm的均值和方差和根据mini-batch中的数据估计的,而推理过程中BatchNorm的均值和方差是用的训练过程中的全体样本估计的。因此预测过程是稳定的,相同的样本不会因为所在批次的差异得到不同的结果,但训练过程中则会受到批次中其他样本的影响所以有正则化效果。

(3)BatchNorm的精度效果与batch_size大小有何关系?

如果受到GPU内存限制,不得不使用很小的batch_size,训练阶段时使用的mini-batch上的均值和方差的估计和预测阶段时使用的全体样本上的均值和方差的估计差异可能会较大,效果会变差。这时候,可以尝试LayerNorm或者GroupNorm等归一化方法。

本文节选自 eat pytorch in 20 days 的 《5-2,模型层》前半部分。公众号后台回复关键词:pytorch,获取本文全部源代码和吃货本货BiliBili视频讲解哦🍉🍉

b9f92ee9d8125d818dd01f541b8062f2.png

相关文章:

  • Three.js 这样写就有阴影效果啦
  • Cravatar头像
  • Python-爬虫 (BS4数据解析)
  • java基于ssm+vue+elementui的多用户博客管理系统
  • java毕业设计网站swing mysql实现的仓库商品管理系统[包运行成功]
  • java毕业设计论文题目基于SSM实现的小区物业管理系统[包运行成功]
  • “蔚来杯“2022牛客暑期多校训练营10 EF题解
  • 人工智能科学计算库—Numpy教程
  • i.MX6ULL应用移植 | 基于ubuntu base 16.04搭建python3.9+pip3环境
  • vim文本编辑器
  • 网课搜题接口
  • 网课查题API接口(免费)
  • 超分辨率重建DRRN
  • MacOS 环境编译 JVM 源码
  • Linux内核互斥技术1
  • [译]如何构建服务器端web组件,为何要构建?
  • ES6系列(二)变量的解构赋值
  • ES6系统学习----从Apollo Client看解构赋值
  • HTTP请求重发
  • IOS评论框不贴底(ios12新bug)
  • LeetCode29.两数相除 JavaScript
  • Linux快速复制或删除大量小文件
  • php ci框架整合银盛支付
  • windows下mongoDB的环境配置
  • 如何实现 font-size 的响应式
  • 使用Tinker来调试Laravel应用程序的数据以及使用Tinker一些总结
  • 手机app有了短信验证码还有没必要有图片验证码?
  • 数组的操作
  • 06-01 点餐小程序前台界面搭建
  • ​ 全球云科技基础设施:亚马逊云科技的海外服务器网络如何演进
  • ​水经微图Web1.5.0版即将上线
  • ​香农与信息论三大定律
  • #100天计划# 2013年9月29日
  • #android不同版本废弃api,新api。
  • #define,static,const,三种常量的区别
  • #vue3 实现前端下载excel文件模板功能
  • ${ }的特别功能
  • %check_box% in rails :coditions={:has_many , :through}
  • (c语言)strcpy函数用法
  • (Matlab)基于蝙蝠算法实现电力系统经济调度
  • (八十八)VFL语言初步 - 实现布局
  • (第27天)Oracle 数据泵转换分区表
  • (二)JAVA使用POI操作excel
  • (分布式缓存)Redis持久化
  • (力扣记录)1448. 统计二叉树中好节点的数目
  • (排序详解之 堆排序)
  • (三)Honghu Cloud云架构一定时调度平台
  • (删)Java线程同步实现一:synchronzied和wait()/notify()
  • (十八)用JAVA编写MP3解码器——迷你播放器
  • (转)从零实现3D图像引擎:(8)参数化直线与3D平面函数库
  • .NET Compact Framework 多线程环境下的UI异步刷新
  • .net 生成二级域名
  • .NET 实现 NTFS 文件系统的硬链接 mklink /J(Junction)
  • .NET 中各种混淆(Obfuscation)的含义、原理、实际效果和不同级别的差异(使用 SmartAssembly)
  • .net 重复调用webservice_Java RMI 远程调用详解,优劣势说明