当前位置: 首页 > news >正文

Pytorch学习率lr衰减(decay)(scheduler)

1、手动修改optimizer中的lr

import matplotlib.pyplot as plt
from torch import nn
import torch

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Linear(10,10)
    def forward(self, input):
        out = self.net(input)
        return out

model = Net()
LR = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr = LR)
lr_list = []
for epoch in range(100):
    if epoch % 5 == 0:
        for p in optimizer.param_groups:
            p['lr'] *= 0.9#注意这里
    lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100), lr_list, color = 'r')
plt.show()

2. 动态调整学习率

torch.optim.lr_scheduler

torch.optim.lr_scheduler上,基于当前epoch的数值,为我们封装了几种相应的动态学习率调整方法

① lr_scheduler.LambdaLR

torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)

lr_lambda 会接收到一个int参数:epoch,然后根据epoch计算出对应的lr。如果设置多个lambda函数的话,会分别作用于Optimizer中的不同的params_group

import matplotlib.pyplot as plt
from torch import nn
import torch
from torch import optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Linear(10,10)
    def forward(self, input):
        out = self.net(input)
        return out

import numpy as np 
lr_list = []
model = Net()
LR = 0.01
optimizer = optim.Adam(model.parameters(),lr = LR)

lambda1 = lambda epoch:np.sin(epoch) / epoch
scheduler = optim.lr_scheduler.LambdaLR(optimizer,lr_lambda = lambda1)

for epoch in range(100):
    scheduler.step()
    lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100),lr_list,color = 'r')
plt.show()

 ② lr_scheduler.StepLR 阶梯式衰减

torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)

每个一定的epoch,lr会自动乘以gamma进行阶梯式衰减

注意:pytorch1.1.0之后scheduler.step()要放在optimizer.step()之后!!!

import matplotlib.pyplot as plt
from torch import nn
import torch
from torch import optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Linear(10,10)
    def forward(self, input):
        out = self.net(input)
        return out

lr_list = []
model = Net()
LR = 0.01
optimizer = optim.Adam(model.parameters(),lr = LR)
scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma = 0.8)
for epoch in range(100):
    scheduler.step()
    lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100),lr_list,color = 'r')
plt.show()

③ lr_scheduler.MultiStepLR——多阶梯式衰减

三段式lr,epoch进入milestones范围内即乘以gamma,离开milestones范围之后再乘以gamma。这种衰减方式也是在学术论文中最常见的方式,一般手动调整也会采用这种方法。

import matplotlib.pyplot as plt
from torch import nn
import torch
from torch import optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Linear(10,10)
    def forward(self, input):
        out = self.net(input)
        return out

lr_list = []
model = Net()
LR = 0.01
optimizer = optim.Adam(model.parameters(),lr = LR)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer,milestones=[20,80],gamma = 0.9)
for epoch in range(100):
    scheduler.step()
    lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100),lr_list,color = 'r')
plt.show()

④ExponentialLR——指数连续衰减

torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1)

 每个epoch中lr都乘以gamma

import matplotlib.pyplot as plt
from torch import nn
import torch
from torch import optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Linear(10,10)
    def forward(self, input):
        out = self.net(input)
        return out

lr_list = []
model = Net()
LR = 0.01
optimizer = optim.Adam(model.parameters(),lr = LR)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
for epoch in range(100):
    scheduler.step()
    lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
plt.plot(range(100),lr_list,color = 'r')
plt.show()

⑤ ReduceLROnPlateau

在发现loss不再降低或者acc不再提高之后,降低学习率。

torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

各参数意义如下:

mode:'min'模式检测metric是否不再减小,'max'模式检测metric是否不再增大;

factor: 触发条件后lr*=factor;

patience:不再减小(或增大)的累计次数;

verbose:触发条件后print;

threshold:只关注超过阈值的显著变化;

threshold_mode:有rel和abs两种阈值计算模式,rel规则:max模式下如果超过best(1+threshold)为显著,min模式下如果低于best(1-threshold)为显著;abs规则:max模式下如果超过best+threshold为显著,min模式下如果低于best-threshold为显著;

cooldown:触发一次条件后,等待一定epoch再进行检测,避免lr下降过速;

min_lr:最小的允许lr;

eps:如果新旧lr之间的差异小与1e-8,则忽略此次更新。

class torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10,
verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-8)

# patience=10代表的是耐心值为10,
# 当loss出现10次不变化时,即开始调用learning rate decat功能
optimizer = torch.optim.SGD(model.parameters(),
                            args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
scheduler = ReduceLROnPlateau(optimizer, 'min')
# min代表希望的目标减少的loss

scheduler.step(loss_val)
# 设置监听的是loss

相关文章:

  • Pytorch获取中间变量的梯度grad
  • Pytorch梯度裁剪 nn.utils.clip_grad_norm_()
  • Layer Normalization(LN) 层标准化
  • TF_CPP_MIN_LOG_LEVEL
  • Python sys.argv
  • pytorch模型可复现设置(cudnn.benchmark 加速卷积运算 cudnn.deterministic)
  • Python sys.stdout
  • Python vars()函数
  • Python类的self
  • Python输出numpy array带逗号和不带逗号
  • center loss 中心损失
  • torch与lua的关系
  • Python类super(super().__init__())
  • 自回归模型(Autoregressive model)(auto)
  • Pytorch tensorboard与tensorboardX的区别
  • 【面试系列】之二:关于js原型
  • 【前端学习】-粗谈选择器
  • extract-text-webpack-plugin用法
  • jdbc就是这么简单
  • k8s如何管理Pod
  • python学习笔记 - ThreadLocal
  • SegmentFault 社区上线小程序开发频道,助力小程序开发者生态
  • uni-app项目数字滚动
  • Work@Alibaba 阿里巴巴的企业应用构建之路
  • 程序员最讨厌的9句话,你可有补充?
  • 代理模式
  • 分享自己折腾多时的一套 vue 组件 --we-vue
  • 前端面试之CSS3新特性
  • 前嗅ForeSpider中数据浏览界面介绍
  • 如何用Ubuntu和Xen来设置Kubernetes?
  • 网页视频流m3u8/ts视频下载
  • 学习HTTP相关知识笔记
  • 移动端 h5开发相关内容总结(三)
  • 异步
  • Redis4.x新特性 -- 萌萌的MEMORY DOCTOR
  • Spark2.4.0源码分析之WorldCount 默认shuffling并行度为200(九) ...
  • !!java web学习笔记(一到五)
  • #define与typedef区别
  • #我与Java虚拟机的故事#连载03:面试过的百度,滴滴,快手都问了这些问题
  • (173)FPGA约束:单周期时序分析或默认时序分析
  • (Matalb回归预测)PSO-BP粒子群算法优化BP神经网络的多维回归预测
  • (poj1.2.1)1970(筛选法模拟)
  • (Redis使用系列) SpirngBoot中关于Redis的值的各种方式的存储与取出 三
  • (附表设计)不是我吹!超级全面的权限系统设计方案面世了
  • (机器学习-深度学习快速入门)第三章机器学习-第二节:机器学习模型之线性回归
  • (十五)Flask覆写wsgi_app函数实现自定义中间件
  • (四)JPA - JQPL 实现增删改查
  • (一) storm的集群安装与配置
  • (转)C#开发微信门户及应用(1)--开始使用微信接口
  • (转)scrum常见工具列表
  • (转)使用VMware vSphere标准交换机设置网络连接
  • (总结)Linux下的暴力密码在线破解工具Hydra详解
  • ./configure,make,make install的作用(转)
  • .NET Core使用NPOI导出复杂,美观的Excel详解
  • .net Signalr 使用笔记