当前位置: 首页 > news >正文

Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别

1 model.train() 和 model.eval()用法和区别

1.1 model.train()

model.train()的作用是启用 Batch Normalization 和 Dropout

如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。

1.2 model.eval()

model.eval()的作用是不启用Batch Normalization 和 Dropout
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。

在做one classification的时候,训练集和测试集的样本分布是不一样的,尤其需要注意这一点。

1.3 分析原因

使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval。model.eval()时,框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!!

# 定义一个网络
class Net(nn.Module):def __init__(self, l1=120, l2=84):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, l1)self.fc2 = nn.Linear(l1, l2)self.fc3 = nn.Linear(l2, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 实例化这个网络Model = Net()# 训练模式使用.train()Model.train(mode=True)# 测试模型使用.eval()Model.eval()

为什么PyTorch会关注我们是训练还是评估模型?最大的原因是dropout和BN层(以dropout为例)。这项技术在训练中随机去除神经元。
在这里插入图片描述
想象一下,如果右边被删除的神经元(叉号)是唯一促成正确结果的神经元。一旦我们移除了被删除的神经元,它就迫使其他神经元训练和学习如何在没有被删除神经元的情况下保持准确。这种dropout提高了最终测试的性能,但它对训练期间的性能产生了负面影响,因为网络是不全的。

2.model.eval()和torch.no_grad()的区别

1.在PyTorch中进行validation/test时,会使用model.eval()切换到测试模式,在该模式下:

主要用于通知dropout层和BN层在train和validation/test模式间切换:
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); BN层会继续计算数据的mean和var等参数并更新。
在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
2. 该模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。

而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。

如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。

相关文章:

  • C语言之判断与循环语句知识点总结
  • 基于群居蜘蛛算法的无人机航迹规划
  • PostgreSQL 的 Replication Slot分析研究
  • 数据结构实验3
  • 树与二叉树(考研版)
  • 基于Kubesphere容器云平台物联网云平台Devops实践
  • RabbitMQ的交换机(原理及代码实现)
  • WPF:自定义按钮模板
  • python基础语法(十一)
  • 研发效能认证学员作品:快速进行持续集成应用实践丨IDCF
  • 使用pycharm远程连接到Linux服务器进行开发
  • ES6中数值扩展
  • 论文-分布式-并发控制-并发控制问题的解决方案
  • 【面试经典150 | 栈】最小栈
  • 2023辽宁省赛E
  • CSS进阶篇--用CSS开启硬件加速来提高网站性能
  • ERLANG 网工修炼笔记 ---- UDP
  • javascript 总结(常用工具类的封装)
  • Java程序员幽默爆笑锦集
  • Lucene解析 - 基本概念
  • Material Design
  • MySQL数据库运维之数据恢复
  • Python3爬取英雄联盟英雄皮肤大图
  • Rancher如何对接Ceph-RBD块存储
  • SpringCloud(第 039 篇)链接Mysql数据库,通过JpaRepository编写数据库访问
  • uva 10370 Above Average
  • webpack4 一点通
  • 从setTimeout-setInterval看JS线程
  • 官方解决所有 npm 全局安装权限问题
  • 将回调地狱按在地上摩擦的Promise
  • 蓝海存储开关机注意事项总结
  • 免费小说阅读小程序
  • 如何利用MongoDB打造TOP榜小程序
  • 如何用Ubuntu和Xen来设置Kubernetes?
  • 数据可视化之 Sankey 桑基图的实现
  • 原生Ajax
  • 蚂蚁金服CTO程立:真正的技术革命才刚刚开始
  • 直播平台建设千万不要忘记流媒体服务器的存在 ...
  • ​520就是要宠粉,你的心头书我买单
  • ​ssh免密码登录设置及问题总结
  • #pragma pack(1)
  • (C语言)求出1,2,5三个数不同个数组合为100的组合个数
  • (echarts)echarts使用时重新加载数据之前的数据存留在图上的问题
  • (Matlab)遗传算法优化的BP神经网络实现回归预测
  • (vue)el-checkbox 实现展示区分 label 和 value(展示值与选中获取值需不同)
  • (翻译)Quartz官方教程——第一课:Quartz入门
  • (强烈推荐)移动端音视频从零到上手(上)
  • (转) SpringBoot:使用spring-boot-devtools进行热部署以及不生效的问题解决
  • (转)拼包函数及网络封包的异常处理(含代码)
  • (轉貼)《OOD启思录》:61条面向对象设计的经验原则 (OO)
  • .【机器学习】隐马尔可夫模型(Hidden Markov Model,HMM)
  • .NET 4.0中使用内存映射文件实现进程通讯
  • .NET CORE 第一节 创建基本的 asp.net core
  • .Net Core与存储过程(一)
  • .net mvc 获取url中controller和action