当前位置: 首页 > news >正文

1.6.丢弃法

丢弃法

动机:一个好的模型需要对输入数据的扰动足够健壮,丢弃法就是在层之间加入噪音。也可以在数据中使用噪音,等价与Tikhonov正则

无偏差的加入噪音

​ 对于数据 x x x,加入噪音后的 x ′ x' x的期望值是不变的, E [ x ′ ] = x E[x']=x E[x]=x

​ 则我们可以构造出一个简单的期望运算 E [ x ′ ] = p ⋅ 0 + ( 1 − p ) ⋅ x i 1 − p = x i E[x']=p\cdot 0+(1-p)\cdot\frac{x_i}{1-p} =x_i E[x]=p0+(1p)1pxi=xi

​ 那么可以这样处理元素:

在这里插入图片描述

​ 其中丢弃概率是超参数。常用在多层感知机的隐藏层输出上。

通常将丢弃法作用在隐藏全连接层的输出上:
h = σ ( W 1 x + b 1 ) h ′ = d r o p o u t ( h ) o = W 2 h ′ + b 2 y = s o f t m a x ( o ) h=\sigma(W_1x+b_1)\\ h' = dropout(h)\\ o = W_2h' +b_2\\ y=softmax(o) h=σ(W1x+b1)h=dropout(h)o=W2h+b2y=softmax(o)
在这里插入图片描述

​ 如图本来有5个隐藏层,但丢弃函数可能取到0,那么可能会直接消失,剩下的3个隐藏层变大。

​ 丢弃项其实是正则项,只在训练中使用,他们影响模型参数的更新。

​ 在推理过程中,丢弃法直接返回输入 h = d r o p o u t ( h ) h = dropout(h) h=dropout(h),也可以保证确定性的输出

​ 实际上丢弃法的实质是每次训练中使用一个神经网络的子集来做训练, 则多次训练后得到的是多个神经网络的平均,效果自然要好一些。

​ 现在普遍将丢弃项认为是正则项,效果和正则项基本相同。

​ 在输入数据比较简单,但神经网络比较大时,dropout可能会比较有用。

​ dropout1=0.2,dropout2=0.5:

在这里插入图片描述

​ dropout1=0.dropout2=0"

在这里插入图片描述

​ 效果出乎意料的好,说明这个模型本身就没过拟合,这时候使用dropout可能效果不好。一般的小技巧是模型设大一点,然后使用dropout来进行调整。

代码实现

import torch
from torch import nn
from d2l import torch as d2ldef dropout_layer(X, dropout):assert 0 <= dropout <= 1  # 丢弃概率必须在0到1之间if dropout == 1:return torch.zeros_like(X)  # 全0则全部丢弃if dropout == 0:return X  # 0则不丢弃mask = (torch.rand(X.shape) > dropout).float()  # rand生成0到1之间的随机数return mask * X / (1.0 - dropout)num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256# dropout1, dropout2 = 0.2, 0.5
dropout1, dropout2 = 0., 0.# 定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元,有三个线性层,最后一个是输出层
class Net(nn.Module):def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,is_training=True):super(Net, self).__init__()self.num_inputs = num_inputsself.training = is_trainingself.lin1 = nn.Linear(num_inputs, num_hiddens1)self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)self.lin3 = nn.Linear(num_hiddens2, num_outputs)self.relu = nn.ReLU()def forward(self, X):H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))# 只有在训练模型时才使用dropoutif self.training == True:# 在第一个全连接层之后添加一个dropout层H1 = dropout_layer(H1, dropout1)H2 = self.relu(self.lin2(H1))if self.training == True:# 在第二个全连接层之后添加一个dropout层H2 = dropout_layer(H2, dropout2)out = self.lin3(H2)return outnet = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()'''简洁实现'''net = nn.Sequential(nn.Flatten(),nn.Linear(784, 256),nn.ReLU(),# 在第一个全连接层之后添加一个dropout层nn.Dropout(dropout1),nn.Linear(256, 256),nn.ReLU(),# 在第二个全连接层之后添加一个dropout层nn.Dropout(dropout2),nn.Linear(256, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights);

相关文章:

  • 论文复现:Predictive Control of Networked Multiagent Systems via Cloud Computing
  • x264 编码器 CAVLC 熵编码源码分析
  • Alpine Linux 轻量级Linux 适合于 docker 容器镜像
  • 浏览器缓存:强缓存与协商缓存实现原理有哪些?
  • HTTPS请求头缺少HttpOnly和Secure属性解决方案
  • 微服务实战系列之玩转Docker(二)
  • redis基本类型和订阅
  • 数据结构之初始二叉树(2)
  • docker网络互联
  • 机器学习-20-基于交互式web应用框架streamlit的基础使用教程
  • 企业如何查看员工的上网时长和记录?如何查看公司局域网员工电脑的上网记录
  • uniapp 开发 App 对接官方更新功能
  • 【Android】基础—基本布局
  • 校验el-table中表单项
  • Flink实时开发添加水印的案例分析
  • [iOS]Core Data浅析一 -- 启用Core Data
  • 【Redis学习笔记】2018-06-28 redis命令源码学习1
  • 【面试系列】之二:关于js原型
  • CentOS6 编译安装 redis-3.2.3
  • in typeof instanceof ===这些运算符有什么作用
  • input实现文字超出省略号功能
  • iOS动画编程-View动画[ 1 ] 基础View动画
  • macOS 中 shell 创建文件夹及文件并 VS Code 打开
  • Python进阶细节
  • Redis提升并发能力 | 从0开始构建SpringCloud微服务(2)
  • select2 取值 遍历 设置默认值
  • sessionStorage和localStorage
  • vue.js框架原理浅析
  • Wamp集成环境 添加PHP的新版本
  • 阿里研究院入选中国企业智库系统影响力榜
  • 从零开始的无人驾驶 1
  • 从零开始学习部署
  • 力扣(LeetCode)965
  • 嵌入式文件系统
  • 使用Envoy 作Sidecar Proxy的微服务模式-4.Prometheus的指标收集
  • 移动互联网+智能运营体系搭建=你家有金矿啊!
  • 异步
  • media数据库操作,可以进行增删改查,实现回收站,隐私照片功能 SharedPreferences存储地址:
  • Java性能优化之JVM GC(垃圾回收机制)
  • 数据可视化之下发图实践
  • #FPGA(基础知识)
  • #QT(智能家居界面-界面切换)
  • #我与Java虚拟机的故事#连载18:JAVA成长之路
  • (2024,Flag-DiT,文本引导的多模态生成,SR,统一的标记化,RoPE、RMSNorm 和流匹配)Lumina-T2X
  • (Matalb时序预测)PSO-BP粒子群算法优化BP神经网络的多维时序回归预测
  • (Pytorch框架)神经网络输出维度调试,做出我们自己的网络来!!(详细教程~)
  • (差分)胡桃爱原石
  • (动手学习深度学习)第13章 计算机视觉---微调
  • (分享)自己整理的一些简单awk实用语句
  • (附源码)spring boot儿童教育管理系统 毕业设计 281442
  • (附源码)spring boot基于小程序酒店疫情系统 毕业设计 091931
  • (论文阅读30/100)Convolutional Pose Machines
  • (每日持续更新)信息系统项目管理(第四版)(高级项目管理)考试重点整理 第13章 项目资源管理(七)
  • (十二)springboot实战——SSE服务推送事件案例实现
  • (算法)求1到1亿间的质数或素数