当前位置: 首页 > news >正文

GiantPandaCV | 提升分类模型acc(一):BatchSizeLARS

本文来源公众号“GiantPandaCV”,仅用于学术分享,侵权删,干货满满。

原文链接:提升分类模型acc(一):BatchSize&LARS

在使用大的bs训练情况下,会对精度有一定程度的损失,本文探讨了训练的bs大小对精度的影响,同时探究Layer-wise Adaptive Rate Scaling(LARS)是否可以有效的提升精度。

论文链接:https://arxiv.org/abs/1708.03888

论文代码: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
知乎专栏: https://zhuanlan.zhihu.com/p/406882110

1 引言

如何提升业务分类模型的性能,一直是个难题,毕竟没有99.999%的性能都会带来一定程度的风险,所以很多时候我们只能通过控制阈值来调整准召以达到想要的效果。本系列主要探究哪些模型trick和数据的方法可以大幅度让你的分类性能更上一层楼,不过要注意一点的是,tirck不一定是适用于不同的数据场景的,但是数据处理方法是普适的。本篇文章主要是对于大的bs下训练分类模型的情况,如果bs比较小的可以忽略,直接看最后的结论就好了(这个系列以后的文章讲述的方法是通用的,无论bs大小都可以用)。

2 实验配置

  • 模型:ResNet50

  • 数据:ImageNet1k

  • 环境:8xV100

3 BatchSize对精度的影响

我这里设计了4组对照实验,256, 1024, 2048和4096的batchsize,开了FP16也只能跑到了4096了。采用的是分布式训练,所以单张卡的bs就是bs = total_bs / ngpus_per_node。这里我没有使用跨卡bn,对于bs 64单卡来说理论上已经很大了,bn的作用是约束数据分布,64的bs已经可以表达一个分布的subset了,再大的bs还是同分布的,意义不大,跨卡bn的速度也更慢,所以大的bs基本可以忽略这个问题。但是对于检测的任务,跨卡bn还是有价值的,毕竟输入的分辨率大,单卡的bs比较小,一般4,8,16,这时候统计更大的bn会对模型收敛更好。

很明显可以看出来,当bs增加到4k的时候,acc下降了将近0.8%个点,1k的时候,下降了0.2%个点,所以,通常我们用大的bs训练的时候,是没办法达到最优的精度的。个人建议,使用1k的bs和0.4的学习率最优。

4 LARS(Layer-wise Adaptive Rate Scaling)

4.1. 理论分析

由于bs的增加,在同样的epoch的情况下,会使网络的weights更新迭代的次数变少,所以需要对LR随着bs的增加而线性增加,但是这样会导致上面我们看到的问题,过大的lr会导致最终的收敛不稳定,精度有所下降。

LARS代码如下:

class LARC(object):def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):self.optim = optimizerself.trust_coefficient = trust_coefficientself.eps = epsself.clip = clipdef step(self):with torch.no_grad():weight_decays = []for group in self.optim.param_groups:# absorb weight decay control from optimizerweight_decay = group['weight_decay'] if 'weight_decay' in group else 0weight_decays.append(weight_decay)group['weight_decay'] = 0for p in group['params']:if p.grad is None:continueparam_norm = torch.norm(p.data)grad_norm = torch.norm(p.grad.data)if param_norm != 0 and grad_norm != 0:# calculate adaptive lr + weight decayadaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps)# clip learning rate for LARCif self.clip:# calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`adaptive_lr = min(adaptive_lr / group['lr'], 1)p.grad.data += weight_decay * p.datap.grad.data *= adaptive_lrself.optim.step()# return weight decay control to optimizerfor i, group in enumerate(self.optim.param_groups):group['weight_decay'] = weight_decays[i]

这里有一个超参数,trust_coefficient,也就是公式里面所提到的, 这个参数对精度的影响比较大,实验部分我们会给出结论。

4.2. 实验结论

可以很明显发现,使用了LARS,设置turst_confidence为1e-3的情况下,有着明显的掉点,设置为2e-2的时候,在1k和4k的情况下,有着明显的提升,但是2k的情况下有所下降。

LARS一定程度上可以提升精度,但是强依赖超参,还是需要细致的调参训练。

5 结论

  • 8卡进行分布式训练,使用1k的bs可以很好的平衡acc&speed。

  • LARS一定程度上可以提升精度,但是需要调参,做业务可以不用考虑,刷点的话要好好训练。

6 结束语

本文是提升分类模型acc系列的第一篇,后续会讲解一些通用的trick和数据处理的方法,敬请关注。

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

相关文章:

  • 【Java每日一题】2.和数最大操作II-动态规划
  • 顶级域名和二级域名的区别
  • SOA设计的标准要求
  • SAP HCM HR_PAD_HIRE_EMPLOYEE 自定义信息类型字段保存问题
  • 标题:深入探索Linux中的`ausyscall`
  • SpringCloud整合OpenFeign实现微服务间的通信
  • Visual Studio Code 怎么恢复默认布置
  • 计算机组成结构—IO方式
  • SpringCache和SpringTask
  • 【ARM64 常见汇编指令学习 19.2 -- ARM64 地址加载指令 ADR 详细介绍】
  • 高防CDN是如何应对DDoS和CC攻击的
  • 堆排序-调整算法
  • wireshark 标记自己想要的数据包
  • C++ OpenCV 图像分类魔法:探索神奇的模型与代码
  • 【上篇】从 YOLOv1 到 YOLOv8 的 YOLO 物体检测模型历史
  • 《用数据讲故事》作者Cole N. Knaflic:消除一切无效的图表
  • 230. Kth Smallest Element in a BST
  • Java编程基础24——递归练习
  • Terraform入门 - 1. 安装Terraform
  • yii2中session跨域名的问题
  • 大数据与云计算学习:数据分析(二)
  • 京东美团研发面经
  • 入口文件开始,分析Vue源码实现
  • 山寨一个 Promise
  • 中文输入法与React文本输入框的问题与解决方案
  • 你学不懂C语言,是因为不懂编写C程序的7个步骤 ...
  • ​DB-Engines 12月数据库排名: PostgreSQL有望获得「2020年度数据库」荣誉?
  • ​LeetCode解法汇总307. 区域和检索 - 数组可修改
  • #!/usr/bin/python与#!/usr/bin/env python的区别
  • #laravel 通过手动安装依赖PHPExcel#
  • #如何使用 Qt 5.6 在 Android 上启用 NFC
  • ( )的作用是将计算机中的信息传送给用户,计算机应用基础 吉大15春学期《计算机应用基础》在线作业二及答案...
  • (4)事件处理——(2)在页面加载的时候执行任务(Performing tasks on page load)...
  • (Redis使用系列) Springboot 实现Redis 同数据源动态切换db 八
  • (附源码)springboot家庭装修管理系统 毕业设计 613205
  • (十三)Flask之特殊装饰器详解
  • (转)memcache、redis缓存
  • (转)项目管理杂谈-我所期望的新人
  • (自适应手机端)行业协会机构网站模板
  • .net core + vue 搭建前后端分离的框架
  • .net core MVC 通过 Filters 过滤器拦截请求及响应内容
  • .net core 微服务_.NET Core 3.0中用 Code-First 方式创建 gRPC 服务与客户端
  • .NET Core6.0 MVC+layui+SqlSugar 简单增删改查
  • .NET Remoting学习笔记(三)信道
  • .net 调用海康SDK以及常见的坑解释
  • .Net程序猿乐Android发展---(10)框架布局FrameLayout
  • .NET导入Excel数据
  • .NET文档生成工具ADB使用图文教程
  • @GlobalLock注解作用与原理解析
  • [ 物联网 ]拟合模型解决传感器数据获取中数据与实际值的误差的补偿方法
  • [AIGC] 解题神器:Python中常用的高级数据结构
  • [Android] Binder 里的 Service 和 Interface 分别是什么
  • [BZOJ 3680]吊打XXX(模拟退火)
  • [C/C++]关于C++11中的std::move和std::forward
  • [C++] 小游戏 斗破苍穹 2.11.6 版本 zty出品