当前位置: 首页 > news >正文

Layer Normalization(LN) 层标准化

CNN用BN, RNN用LN

BN又叫纵向规范化,LN又叫横向规范化

LN也是因为Transformer才为人们所熟知的

BN并不适用于RNN等动态网络和batchsize较小的时候效果不好。Layer Normalization(LN)的提出有效的解决BN的这两个问题。LN和BN不同点是归一化的维度是互相垂直的

叫Layer norm,其实是对单个样本做的,对batch的每一个样本做, 如果一个batch有n个feature,他就做n次。就像BN有c个channel时做c次一样。

之所以叫layer norm是因为三维的时候,一个样本就是一层

时序特征并不能用Batch Normalization,因为一个batch中的序列有长有短。就像图中画的,蓝线是BN取的,橙线是LN取的

此外,BN 的一个缺点是需要较大的 batchsize 才能合理估训练数据的均值和方差,这导致内存很可能不够用,同时它也很难应用在训练数据长度不同的 RNN 模型上。

LN需要注意的地方

  • 不再有running_mean和running_var
  • gamma和beta为逐元素的

LayerNorm中不会像BatchNorm那样跟踪统计全局的均值方差,因此train()和eval()对LayerNorm没有影响

其实在eval模式下,只有BatchNorm会屏蔽,其他Norm函数不会屏蔽

LN在PyTorch中的实现

torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None)
  • normalized_shape:(int/list/torch.Size)该层的特征维度,即要被标准化的维度。
  • eps:分母修正项。
  • elementwise_affine:是否需要affine transform,这里也提醒你是逐元素的仿射变换。

对于image

import torch
from torch import nn
N, C, H, W = 20, 5, 10, 10
input = torch.randn(N, C, H, W)
layer_norm = nn.LayerNorm([C, H, W]) # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
output = layer_norm(input)

对于NLP

import torch
from torch import nn
batch, sentence_length, embedding_dim = 20, 5, 10
input = torch.randn(batch, sentence_length, embedding_dim)
layer_norm = nn.LayerNorm(embedding_dim)
output = layer_norm(input)

手动实现

import torch
import torch.nn as nn

ln_layer = nn.LayerNorm(normalized_shape=[3,20,20], elementwise_affine=False)
input = torch.randn(8, 3, 20, 20)
ln_outputs = ln_layer(input)

mean = torch.mean(input, dim=[1,2,3], keepdim=True)
var = torch.var(input, dim=[1,2,3], keepdim=True, unbiased=False)
output = (input-mean) / (torch.sqrt(var) + ln_layer.eps)

assert torch.allclose(ln_outputs, output, rtol=1e-4)

偶尔还会报错,多运行几次。

这说明我们自己算的,虽然计算方法是对的,在数值上还和官方实现的有1e-6,1e-7等的差异,因为官方的实现可能是C语言底层,然后开了cuda.benmark等优化

为什么用torch.var和torch.sqrt(torch.var) 而不是直接用torch.std, 因为官网的公式是这么写的

    直接用std也可以

【关于 BatchNorm vs LayerNorm】那些你不知道的事-技术圈

相关文章:

  • TF_CPP_MIN_LOG_LEVEL
  • Python sys.argv
  • pytorch模型可复现设置(cudnn.benchmark 加速卷积运算 cudnn.deterministic)
  • Python sys.stdout
  • Python vars()函数
  • Python类的self
  • Python输出numpy array带逗号和不带逗号
  • center loss 中心损失
  • torch与lua的关系
  • Python类super(super().__init__())
  • 自回归模型(Autoregressive model)(auto)
  • Pytorch tensorboard与tensorboardX的区别
  • Pytorch中的BN和IN(affine仿射, track_running_stats)
  • Pytorch修改tensor值
  • Siamese Network(孪生网络/连体网络) (few-shot learning)
  • ----------
  • 深入了解以太坊
  • (ckeditor+ckfinder用法)Jquery,js获取ckeditor值
  • Android 架构优化~MVP 架构改造
  • gcc介绍及安装
  • GraphQL学习过程应该是这样的
  • java2019面试题北京
  • JAVA并发编程--1.基础概念
  • Java方法详解
  • java架构面试锦集:开源框架+并发+数据结构+大企必备面试题
  • mockjs让前端开发独立于后端
  • PHP 小技巧
  • python_bomb----数据类型总结
  • Quartz实现数据同步 | 从0开始构建SpringCloud微服务(3)
  • TCP拥塞控制
  • 对象引论
  • 基于阿里云移动推送的移动应用推送模式最佳实践
  • 利用阿里云 OSS 搭建私有 Docker 仓库
  • 浏览器缓存机制分析
  • 实现菜单下拉伸展折叠效果demo
  • 《天龙八部3D》Unity技术方案揭秘
  • ​Spring Boot 分片上传文件
  • (cljs/run-at (JSVM. :browser) 搭建刚好可用的开发环境!)
  • (JSP)EL——优化登录界面,获取对象,获取数据
  • (搬运以学习)flask 上下文的实现
  • (原創) 如何將struct塞進vector? (C/C++) (STL)
  • (转)http-server应用
  • (转)LINQ之路
  • (转)负载均衡,回话保持,cookie
  • .locked1、locked勒索病毒解密方法|勒索病毒解决|勒索病毒恢复|数据库修复
  • .NET CORE 2.0发布后没有 VIEWS视图页面文件
  • .NET I/O 学习笔记:对文件和目录进行解压缩操作
  • .NET 依赖注入和配置系统
  • .NET构架之我见
  • @Autowired @Resource @Qualifier的区别
  • @Autowired多个相同类型bean装配问题
  • @RequestBody与@ModelAttribute
  • [20190401]关于semtimedop函数调用.txt
  • [ABC294Ex] K-Coloring
  • [android] 切换界面的通用处理