当前位置: 首页 > news >正文

抑制过拟合——Dropout原理

抑制过拟合——Dropout原理

  • Dropout的工作原理
  • 实验观察

  在机器学习领域,尤其是当我们处理复杂的模型和有限的训练样本时,一个常见的问题是过拟合。简而言之,过拟合发生在模型对训练数据学得太好,以至于它捕捉到了数据中的噪声和误差,而不仅仅是底层模式。具体来说,这在神经网络训练中尤为常见,表现为在训练数据上表现优异(例如损失函数值很小,预测准确率高)而在未见过的数据(测试集)上表现不佳。

  过拟合不仅是机器学习新手容易遇到的问题,即使是经验丰富的从业者也会面临这一挑战。一个典型的解决方案是采用模型集成技术,这涉及训练多个模型并将它们的预测结合起来。但这种方法的缺点是显而易见的:它既耗时又昂贵,不仅在训练阶段,而且在模型评估和部署时也是如此。

  在这种背景下,Dropout 作为一种有效的正则化技术,可以显著减轻过拟合问题。它的基本原理是在每次训练迭代中随机“丢弃”(即暂时移除)网络中的一部分神经元。这种方法不仅简单,而且被证明在许多情况下都非常有效。

Dropout的工作原理

  在 PyTorch 中,Dropout 层的使用相当直观。通常,它被添加到神经网络的各个层之间,如下所示:

torch.nn.Dropout(p=0.5, inplace=False)

  p:这是一个关键参数,代表着每个神经元被丢弃的概率。

  在实践中,这意味着对于网络中的每个神经元,它在每次训练迭代中都有 1 − p 1-p 1p 的概率被保留, p p p 的概率被丢弃。值得注意的是,这种随机性确保了每个mini-batch都在对不完全相同的网络进行训练,从而减少过拟合的风险。

  在训练期间,对于每个训练样本,网络中的每个神经元都有概率 1 − p 1-p 1p 被保留,概率 p p p 被丢弃。如果神经元被保留,则其输出乘以 1 1 − p \frac{1}{1-p} 1p1​(这样做是为了保持该层输出的总期望值不变)。设 r j r_j rj​ 为一个随机变量,它对应于第 j j j 个神经元,且服从伯努利分布(即 r j = 1 r_j = 1 rj=1 的概率为 1 − p 1-p 1p r j = 0 r_j = 0 rj=0 的概率为 p p p)。那么在训练时,神经元的输出 y j y_j yj变为 r j × y j / ( 1 − p ) r_j \times y_j / (1-p) rj×yj/(1p)

为什么需要保持期望不变? 举个简单的例子,假设某层有两个神经元,它们的输出在没有dropout时都是1。在应用了50%的dropout后,期望只有一个神经元被激活,输出为1,另一个被丢弃,输出为0。这样,这层的平均输出变成了0.5。为了保持输出的总期望值不变,激活的神经元的输出应该乘以2,即 1 1 − p \frac{1}{1-p} 1p1​,这样平均输出才能保持为1,与没有应用dropout时相同。这样的处理有助于保持整个网络的稳定性和一致性。

  在模型预测(或测试)阶段,所有的神经元都保持激活(即不进行dropout)。因为在训练阶段,神经元的输出已经被放大了 1 1 − p \frac{1}{1-p} 1p1 倍,所以在预测时不需要进行任何调整,直接使用网络进行前向传播即可。

在这里插入图片描述

实验观察

  为了更深入地理解 Dropout 的影响,我们可以通过一个实验来观察不同的 Dropout 设置对训练过程的影响。比如,可以比较 Dropout = 0.1Dropout = 0 在训练过程中的表现差异,相关代码实现如下:

import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import timeclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linears = nn.Sequential(nn.Linear(2, 20),nn.Linear(20, 20),nn.Dropout(0.1),nn.Linear(20, 20),nn.Linear(20, 20),nn.Linear(20, 1),)def forward(self, x):_ = self.linears(x)return _lr = 0.01
iteration = 1000x1 = torch.arange(-10, 10).float()
x2 = torch.arange(0, 20).float()
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = 2*x1 - x2**2 + 1model = Model()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()start_time = time.time()
writer = SummaryWriter(comment='_随机失活')for iter in range(iteration):y_pred = model(x)loss = loss_function(y, y_pred.squeeze())loss.backward()for name, layer in model.named_parameters():writer.add_histogram(name + '_grad', layer.grad, iter)writer.add_histogram(name + '_data', layer, iter)writer.add_scalar('loss', loss, iter)optimizer.step()optimizer.zero_grad()if iter % 50 == 0:print("iter: ", iter)print("Time: ", time.time() - start_time)

这里我们使用 TensorBoardX 进行结果的可视化展示。

  通过观察模型训练1000轮后的线性层梯度分布,可以发现,应用 Dropout 后的模型梯度通常会更加分散和多样化。这种梯度的多样性有助于防止模型过于依赖训练数据中的特定模式,从而减轻过拟合。

在这里插入图片描述

  同样值得注意的是,模型的损失曲线也会受到影响。加入 Dropout 通常会使损失曲线出现更多的波动(例如,图中的蓝色曲线),这反映了模型在学习过程中的不稳定性。然而,这种不稳定性通常是可接受的,因为它反映了模型正在学习更多的泛化模式而不是简单地记住训练数据。

在这里插入图片描述

相关文章:

  • C#学习-8课时
  • 2023.11.29 深度学习框架理解
  • 2023年c语言程序设计大赛
  • springmvc(基础学习整合)
  • 性能优化中使用Profiler进行页面卡顿的排查及解决方式
  • Android——资源IDnonFinalResIds和“Attribute value must be constant”错误
  • ELFK集群部署(Filebeat+ELK) 本地收集nginx日志 远程收集多个日志
  • 异常 Exception 练习题 (未完成)
  • 合并PDF出现OOM异常
  • Oracle SQL优化
  • 【小白进阶】Linux 调试大法——gdb
  • 软件测评中心▏软件集成测试和功能测试之间的区别和联系简析
  • 02、Tensorflow实现手写数字识别(数字0-9)
  • 在线文库系统 转码功能源代码展示 支持文档在线预览查阅功能
  • Linux “grep“ 命令
  • IE9 : DOM Exception: INVALID_CHARACTER_ERR (5)
  • 「译」Node.js Streams 基础
  • 【399天】跃迁之路——程序员高效学习方法论探索系列(实验阶段156-2018.03.11)...
  • 【知识碎片】第三方登录弹窗效果
  • android 一些 utils
  • css的样式优先级
  • open-falcon 开发笔记(一):从零开始搭建虚拟服务器和监测环境
  • Spring Cloud Feign的两种使用姿势
  • TiDB 源码阅读系列文章(十)Chunk 和执行框架简介
  • VuePress 静态网站生成
  • Vue学习第二天
  • 关于使用markdown的方法(引自CSDN教程)
  • 使用Envoy 作Sidecar Proxy的微服务模式-4.Prometheus的指标收集
  • 为物联网而生:高性能时间序列数据库HiTSDB商业化首发!
  • 学习笔记DL002:AI、机器学习、表示学习、深度学习,第一次大衰退
  • 东超科技获得千万级Pre-A轮融资,投资方为中科创星 ...
  • # Python csv、xlsx、json、二进制(MP3) 文件读写基本使用
  • # 再次尝试 连接失败_无线WiFi无法连接到网络怎么办【解决方法】
  • (C语言)深入理解指针2之野指针与传值与传址与assert断言
  • (delphi11最新学习资料) Object Pascal 学习笔记---第5章第5节(delphi中的指针)
  • (二十五)admin-boot项目之集成消息队列Rabbitmq
  • (附源码)ssm失物招领系统 毕业设计 182317
  • (深度全面解析)ChatGPT的重大更新给创业者带来了哪些红利机会
  • (十八)devops持续集成开发——使用docker安装部署jenkins流水线服务
  • (一)基于IDEA的JAVA基础12
  • ****** 二十三 ******、软设笔记【数据库】-数据操作-常用关系操作、关系运算
  • ***监测系统的构建(chkrootkit )
  • ./configure,make,make install的作用(转)
  • .apk文件,IIS不支持下载解决
  • .babyk勒索病毒解析:恶意更新如何威胁您的数据安全
  • .NET CLR Hosting 简介
  • .NET CORE 第一节 创建基本的 asp.net core
  • .net 逐行读取大文本文件_如何使用 Java 灵活读取 Excel 内容 ?
  • .Net6支持的操作系统版本(.net8已来,你还在用.netframework4.5吗)
  • .NET企业级应用架构设计系列之技术选型
  • .sh
  • ;号自动换行
  • ??myeclipse+tomcat
  • @JsonFormat与@DateTimeFormat注解的使用
  • [ vulhub漏洞复现篇 ] Grafana任意文件读取漏洞CVE-2021-43798