当前位置: 首页 > news >正文

深度学习调参基础

文章目录

  • 深度学习调参基础
    • 1.需要调节的超参数有哪些?
    • 2.什么时候需要调参?
    • 3.如何调参?
      • 3.1过拟合情况调参
      • 3.2欠拟合情况调参
      • 3.3收敛但震荡情况调参
      • 3.4不收敛情况调参
    • 4.调参示例
    • 参考

深度学习调参基础

1.需要调节的超参数有哪些?

  • 和网络结构相关的参数:神经网络的网络层数、不同层的类别和搭建顺序、隐藏层神经元的参数设置、LOSS层的选择、正则化参数
  • 和训练过程相关的参数:网络权重初始化方法、学习率使用策略、迭代次数、Batch的大小、输入数据相关

2.什么时候需要调参?

QQ_1720585309283

  • 恰好拟合(一般不需要调参)
  • 过拟合
  • 欠拟合
  • 收敛但震荡
  • 不收敛

3.如何调参?

3.1过拟合情况调参

  • 增加数据量。收集更多的训练数据,或者通过数据增强(Data Augmentation)的方法来增加数据量。
  • 使用正则化技术。L1或L2正则化、Dropout、早停(Early Stopping)
  • 减少模型复杂度。减少模型的参数数量,例如减少层数或者每层的神经元数量。
  • 使用交叉验证。通过交叉验证来评估模型性能,选择最佳的超参数。
  • 调整学习率。使用学习率衰减(Learning Rate Decay)来逐步减小学习率,从而让模型在训练后期更稳定。
  • 调整批量大小。增加或减少批量大小(Batch Size),不同的数据集和模型可能需要不同的批量大小来达到最优效果。

3.2欠拟合情况调参

  • 增加模型复杂度。增加神经网络的层数或每层的神经元数量。
  • 训练更长时间。增加训练轮数(Epochs)
  • 调整学习率。适当增加学习率,以加快模型的收敛速度。
  • 减少正则化。减少或移除正则化项(例如L2正则化),以允许模型在训练数据上拟合得更好。降低或移除Dropout层,以减少训练过程中神经元的随机丢弃。
  • 优化数据处理。确保数据预处理和归一化步骤没有问题,使数据分布适合模型训练。
  • 使用更大的批量大小。

3.3收敛但震荡情况调参

  • 降低学习率。学习率过高可能导致模型在收敛过程中震荡。适当降低学习率,模型的更新步骤会变小,从而有助于稳定收敛。
  • 使用学习率调度器(如学习率衰减、余弦退火等)来动态调整学习率。
  • 增加批量大小。
  • 使用梯度裁剪。对梯度进行裁剪(Gradient Clipping),将梯度的最大范数限制在一个固定值以内,防止梯度爆炸和震荡。
  • 增加正则化。增加L2正则化或者增加Dropout的比例,可以使模型的权重更新更为平滑,从而减少震荡。
  • 确保数据没有问题,数据预处理和归一化步骤正确。
  • 适当简化模型架构,减少过深或过宽的网络结构。

3.4不收敛情况调参

  • 增加或降低学习率。学习率过高可能导致模型参数更新过大,无法收敛。学习率过低,模型可能收敛得太慢或者陷入局部极小值。
  • 增加或降低模型复杂度。模型过于简单,无法拟合数据。模型过于复杂,难以训练。
  • 改变激活函数。使用不同的激活函数,如ReLU、Leaky ReLU、ELU、Swish等。
  • 调整优化器。尝试不同的优化器,如Adam、RMSprop、SGD with Momentum等。
  • 增加批量大小。
  • 增加训练次数。
  • 使用合适的权重初始化方法,如He初始化或Xavier初始化,确保模型在训练初期不会因为不合理的权重导致无法收敛。

4.调参示例

补充知识:

len(dataloader):返回batch的数量,即一个数据集总共有多少个 batch。

len(dataloader.dataset):返回数据集中样本的数量,即 dataset 的长度。

"""
coding:utf-8
* @Author:FHTT-Tian
* @name:Adjust Parameter.py
* @Time:2024/7/10 星期三 16:39
* @Description: 调参示例代码,未调参之前
"""# 手写数字识别数据集mnist
import torch
import torchvision.datasets as dataset
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch import nn, optim
from torch.utils.data import DataLoader# 定义超参数
batch_size = 64
hidden_size = 64  # 神经元个数
learning_rate = 0.001
num_epochs = 10
input_size = 784  # 28*28
num_classes = 10# 定义存放loss的列表
train_loss_list = []
test_loss_list = []# 对图片进行预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,), )])# 下载数据集并预处理
trainset = dataset.MNIST(root="./MINST", train=True, download=True, transform=transform)
testset = dataset.MNIST(root="./MINST", train=False, download=True, transform=transform)# dataloader设置,加载数据集
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)# 构建网络
class Net(nn.Module):def __init__(self, input_size, hideen_size, num_classes):super().__init__()self.fc1 = nn.Linear(input_size, hideen_size)self.relu = nn.ReLU()self.fc2 = nn.Linear(hideen_size, num_classes)def forward(self, x):out = self.fc1(x.view(-1, input_size))out = self.relu(out)out = self.fc2(out)return out# 网络实例化
model = Net(input_size, hidden_size, num_classes)# 定义损失
criterion = nn.CrossEntropyLoss()# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=learning_rate)# 训练
total_step = len(train_loader)  # len(dataloader):返回batch的数量,即一个数据集总共有多少个batch
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()train_loss_list.append(loss.item())  # train_loss_list 里面存的是每个batch对应的loss,loss.item() 返回的是当前批次的损失值。if (i + 1) % 100 == 0:print("Epoch [{}/{}], Step [{}/{}], Train Loss:{:.4f}".format(epoch + 1, num_epochs, i + 1, total_step,loss.item()))# 设置模型为评估模式model.eval()with torch.no_grad():test_loss = 0.0for images, labels in test_loader:outputs = model(images)loss = criterion(outputs, labels)test_loss += loss.item() * images.size(0)  # 当前批次的损失乘以该批次的样本数,然后累加到 test_loss。为了计算整个测试数据集的总损失# len(dataloader.dataset):返回数据集中样本的数量,即dataset的长度。将累加的总损失除以测试数据集的总样本数,以获得平均损失。test_loss /= len(test_loader.dataset)test_loss_list.extend([test_loss] * total_step)  # 将平均测试损失值 test_loss 复制成长度为 total_step 的列表# 设置模型为训练模式model.train()print("Epoch [{}/{}], Test Loss:{:.4f}".format(epoch + 1, num_epochs, test_loss))# 绘制训练与测试的loss曲线
plt.plot(train_loss_list, label="Train Loss")
plt.plot(test_loss_list, label="Test Loss")
plt.title("Model Loss")
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

image-20240710210858471

  • 将SGD优化器改为Adam优化器查看模型效果:
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

image-20240710212008854

  • batch_size修改为128,learning_rate修改为0.0001,查看模型的效果。

image-20240710212648519

总结:如果出现过拟合、欠拟合、收敛但震荡或不收敛的情况,尝试使用相应的调参方法进行调整,以期得到较好的结果。建议在一个好的骨干网络下修改模型,这样大多数参数已经调得很好,不需要我们调整。因此,主要任务是添加模块,如果效果不佳,则更换模块✌。

参考

  • 炼丹笔记六 : 调参技巧
  • 自动调参工具Optuna
  • Pytorch 数据加载—Dataset和DataLoader详解

😃😃😃

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • MySQL DDL
  • 使用Docker制作python项目镜像
  • DP(2) | Java | LeetCode 62, 63, 343, 96 做题总结(96 未完)
  • 7月11日学习打卡,数据结构栈
  • vue3项目打包的时候,怎么区别测试环境,和本地环境
  • 代码随想录算法训练营第9天
  • 142. 两个字符串的最小 ASCII 删除总和(卡码网周赛第二十五期(23年B站笔试真题))
  • java使用easypoi模版导出word详细步骤
  • 我被手机所伤,竟如此憔悴。
  • 假期笔记1:anaconda的安装与pycharm中的引用
  • Linux 程序卡死的特殊处理
  • 进度条提示-在python程序中使用避免我误以为挂掉了
  • 微服务的分布式事务解决方案
  • Linux 初识
  • 通过Arcgis从逐月平均气温数据中提取并计算年平均气温
  • ES6指北【2】—— 箭头函数
  • canvas 绘制双线技巧
  • const let
  • Electron入门介绍
  • ES6 ...操作符
  • git 常用命令
  • Java-详解HashMap
  • Js实现点击查看全文(类似今日头条、知乎日报效果)
  • Making An Indicator With Pure CSS
  • OpenStack安装流程(juno版)- 添加网络服务(neutron)- controller节点
  • Python - 闭包Closure
  • Python打包系统简单入门
  • SpiderData 2019年2月13日 DApp数据排行榜
  • Synchronized 关键字使用、底层原理、JDK1.6 之后的底层优化以及 和ReenTrantLock 的对比...
  • webpack+react项目初体验——记录我的webpack环境配置
  • 基于webpack 的 vue 多页架构
  • 两列自适应布局方案整理
  • 前端面试总结(at, md)
  • 使用docker-compose进行多节点部署
  • 它承受着该等级不该有的简单, leetcode 564 寻找最近的回文数
  • 微信小程序设置上一页数据
  • 在Docker Swarm上部署Apache Storm:第1部分
  • CMake 入门1/5:基于阿里云 ECS搭建体验环境
  • HanLP分词命名实体提取详解
  • 函数计算新功能-----支持C#函数
  • # Kafka_深入探秘者(2):kafka 生产者
  • # 学号 2017-2018-20172309 《程序设计与数据结构》实验三报告
  • # 执行时间 统计mysql_一文说尽 MySQL 优化原理
  • #android不同版本废弃api,新api。
  • #LLM入门|Prompt#2.3_对查询任务进行分类|意图分析_Classification
  • #微信小程序:微信小程序常见的配置传值
  • #我与Java虚拟机的故事#连载02:“小蓝”陪伴的日日夜夜
  • (11)MATLAB PCA+SVM 人脸识别
  • (DenseNet)Densely Connected Convolutional Networks--Gao Huang
  • (ISPRS,2023)深度语义-视觉对齐用于zero-shot遥感图像场景分类
  • (Mac上)使用Python进行matplotlib 画图时,中文显示不出来
  • (动态规划)5. 最长回文子串 java解决
  • (翻译)Entity Framework技巧系列之七 - Tip 26 – 28
  • (附源码)ssm高校实验室 毕业设计 800008
  • (剑指Offer)面试题34:丑数