当前位置: 首页 > news >正文

Batch Size 不同对evaluation performance的影响

目录

    • 问题描述
    • 如果是bug
    • batch size的设置问题
    • 尝试使用GroupNorm解决batchsize不同带来的问题
      • 归一化的分类
    • 参考文章

问题描述

深度学习网络训练时,使用较小的batch size训练网络后,如果换用较大的batch size进行evaluation,网络的预测能力会显著下降。如果evaluation的batch size和train的batch size大小相同时,则不会遇到此类问题。

PyTorch Forums – Performance highly degraded when eval() is activated in the test phase

如果是bug

  1. metric会根据batch_size的大小变化(但并不显著),metric按每个batch分别进行计算
  2. 缺失model.eval()指令:with torch.no_grad() 对dropout和batch normalization不起固定作用。
    1. nn.Dropout层参数不会固定
    2. nn.BatchNorm2d()
      1. PyTorch – BatchNorm2d BatchNorm2d函数中的参数track_running_stats:trainningtrack_running_statstrack_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。
      2. trainning=False, track_running_stats=True。这个是期望中的测试阶段的设置,此时BN会用之前训练好的模型中的(假设已经保存下了)running_meanrunning_var并且不会对其进行更新。一般来说,只需要设置model.eval()其中model中含有BN层,即可实现这个功能。
  3. Dataloader中加入了随机处理,例如RandomCrop
  4. 没有固定随机种子

batch size的设置问题

如果batch size较小,会导致上述running_mean和running_var不准确。参考文章Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift,当模型训练完成后,

x ^ = x − E [ x ] V a r [ x ] + ϵ \hat{x} = \frac{x-E[x]}{\sqrt{Var[x]+\epsilon}} x^=Var[x]+ϵ xE[x]

其中, V a r [ x ] = m m − 1 E B [ σ B 2 ] Var[x]=\frac{m}{m-1}E_B[\sigma_B^2] Var[x]=m1mEB[σB2],the expectation is over training mini-batches of size m and σ B 2 \sigma_B^2 σB2 are their sample variances.

尝试使用GroupNorm解决batchsize不同带来的问题

归一化的分类

归一化的分类
LN 和 IN 在视觉识别上的成功率都是很有限的,对于训练序列模型(RNN/LSTM)或生成模型(GAN)很有效。

所以,在视觉领域,BN用的比较多,GN就是为了改善BN的不足而来的。

GN 把通道分为组,并计算每一组之内的均值和方差,以进行归一化。GN 的计算与批量大小无关,其精度也在各种批量大小下保持稳定。可以看到,GN和LN很像。

参考文章

pytorch 每次测试结果不同
Batch Normalization
深度学习中的组归一化(GroupNorm)

相关文章:

  • Stream toArray 好过collect
  • 常用知识点问答
  • 【Spring Boot】Java 持久层 API:JPA
  • 数据结构-第七章(B树和B+树)
  • 每日一道算法题 判断子序列
  • linux 环境报错:Peer reports incompatible or unsupported protocol version
  • sheng的学习笔记-hadoop,MapReduce,yarn,hdfs框架原理
  • 不使用AMap.DistrictSearch,通过poi数据绘制省市县区块
  • 巴西市场有哪些电商平台?巴西最畅销的产品有哪些?
  • 揭秘,PyArmor库让你的Python代码更安全
  • Linux 程序打包
  • 时尚品牌GOODBAI好人好事系列纪录片——Jupiter乐队的热血与梦想
  • ubuntu 18 虚拟机安装(3)安装mysql
  • Hadoop3:参数调优-核心参数NameNode内存配置、并发数配置、回收站配置
  • JAVA学习-练习试用Java实现“天际线问题”
  • Hibernate【inverse和cascade属性】知识要点
  • Mysql数据库的条件查询语句
  • node.js
  • PHP 小技巧
  • 前端性能优化--懒加载和预加载
  • 微服务框架lagom
  • 微信小程序实战练习(仿五洲到家微信版)
  • 小程序、APP Store 需要的 SSL 证书是个什么东西?
  • AI又要和人类“对打”,Deepmind宣布《星战Ⅱ》即将开始 ...
  • 交换综合实验一
  • ​LeetCode解法汇总2182. 构造限制重复的字符串
  • (1)Android开发优化---------UI优化
  • (2)MFC+openGL单文档框架glFrame
  • (3)Dubbo启动时qos-server can not bind localhost22222错误解决
  • (LNMP) How To Install Linux, nginx, MySQL, PHP
  • (Mirage系列之二)VMware Horizon Mirage的经典用户用例及真实案例分析
  • (笔记)Kotlin——Android封装ViewBinding之二 优化
  • (第61天)多租户架构(CDB/PDB)
  • (含笔试题)深度解析数据在内存中的存储
  • (力扣记录)235. 二叉搜索树的最近公共祖先
  • (四) Graphivz 颜色选择
  • (五) 一起学 Unix 环境高级编程 (APUE) 之 进程环境
  • (已解决)报错:Could not load the Qt platform plugin “xcb“
  • (转)winform之ListView
  • (轉貼) VS2005 快捷键 (初級) (.NET) (Visual Studio)
  • .FileZilla的使用和主动模式被动模式介绍
  • .gitignore文件—git忽略文件
  • .NET Core工程编译事件$(TargetDir)变量为空引发的思考
  • .net core开源商城系统源码,支持可视化布局小程序
  • .NET Core实战项目之CMS 第一章 入门篇-开篇及总体规划
  • .net framwork4.6操作MySQL报错Character set ‘utf8mb3‘ is not supported 解决方法
  • .NetCore项目nginx发布
  • .Net环境下的缓存技术介绍
  • .NET建议使用的大小写命名原则
  • .NET中 MVC 工厂模式浅析
  • ?.的用法
  • [ 隧道技术 ] 反弹shell的集中常见方式(二)bash反弹shell
  • [\u4e00-\u9fa5] //匹配中文字符
  • [【JSON2WEB】 13 基于REST2SQL 和 Amis 的 SQL 查询分析器
  • [2024] 十大免费电脑数据恢复软件——轻松恢复电脑上已删除文件