当前位置: 首页 > news >正文

transformers进行学习率调整lr_scheduler(warmup)

一、get_scheduler实现warmup

1、warmup基本思想

Warmup(预热)是深度学习训练中的一种技巧,旨在逐步增加学习率以稳定训练过程,特别是在训练的早期阶段。它主要用于防止在训练初期因学习率过大导致的模型参数剧烈波动或不稳定。预热阶段通常是指在训练开始时,通过多个步长逐步将学习率从一个较低的值增加到目标值(通常是预定义的最大学习率)。

2、warmup基本实现

from transformers import get_schedulerscheduler = get_scheduler(name="cosine",  # 可以选择 'linear', 'cosine', 'polynomial', 'constant', 'constant_with_warmup'optimizer=optimizer,num_warmup_steps=100,  # 预热步数num_training_steps=num_training_steps  # 总的训练步数
)#linear:线性学习率下降
#cosine:余弦退火
#polynomial:多项式衰减
#constant:常数学习率
#constant_with_warmup:预热后保持常数# 上述代码等价于
from transformers import get_cosine_scheduler_with_warmupscheduler = get_cosine_scheduler_with_warmup(optimizer=optimizer,num_warmup_steps=100,  # 预热步数num_training_steps=num_training_steps  # 总的训练步数
)# 同理等价于linear, polynomial, constant分别等价于
from transformers import (get_constant_schedule, get_polynomial_decay_schedule_with_warmup, get_linear_schedule_with_warmup)

 二、各种warmup策略学习率变化规律

1、get_constant_schedule学习率变化规律

2、get_cosine_schedule_with_warmup学习率变化规律

3、get_cosine_with_hard_restarts_schedule_with_warmup学习率变化规律

4、get_linear_schedule_with_warmup学习率变化规律

5、get_polynomial_decay_schedule_with_warmup学习率变化规律(power=2, power=1类似于linear)

6、注意事项

  • 如果网络中不同框架采用不同的学习率,上述的warmup策略仍然有效(如图二、5中所示) 
  • 给schduler设置的number_training_steps一定要和训练过程相匹配,如下所示。

7、可视化学习率过程

import matplotlib.pyplot as plt
from transformers import get_scheduler
from torch.optim import AdamW
import torch
import math# 定义一些超参数learning_rate = 1e-3  # 初始学习率# 假设有一个模型
model = torch.nn.Linear(10, 2)# 获得训练总的步数
epochs = 50
batch_size = 32
#train_loader = ***
#num_train_loader = len(train_loader)
num_train_loader = 1235num_training_steps = epochs * math.ceil(num_train_loader/batch_size) # 总的训练步数# 定义优化器
optimizer = AdamW(model.parameters(), lr=learning_rate)# 创建学习率调度器
scheduler = get_scheduler(name="cosine",  # 可以选择 'linear', 'cosine', 'polynomial', 'constant', 'constant_with_warmup'optimizer=optimizer,num_warmup_steps=100,  # 预热步数num_training_steps=num_training_steps  # 总的训练步数
)# 存储每一步的学习率
learning_rates = []# for step in range(num_training_steps):
#    optimizer.step()
#    scheduler.step()
#    learning_rates.append(optimizer.param_groups[0]['lr'])for epoch in range(epochs):# for batch in train_loader:for step in range(0, num_train_loader, batch_size):optimizer.zero_grad()# loss.backward()optimizer.step()scheduler.step()learning_rates.append(optimizer.param_groups[0]['lr'])# 绘制学习率曲线
plt.plot(learning_rates)
plt.xlabel("Training Steps")
plt.ylabel("Learning Rate")
plt.title("Learning Rate Schedule")
plt.show()

实验结果:

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 【详细的springboot自动装载原理】
  • 异常处理和swagger使用
  • Vue3时间选择器datetimerange在数据库存开始时间和结束时间
  • 人工智能技术的分析与探讨
  • Ubuntu-文件管理器中鼠标右键添加文本文件
  • 测试面试宝典(二十八)—— 请问黑盒测试和白盒测试有哪些方法?
  • Android 生成Excel并导出全流程
  • 【JAVA】Hutool CollUtil.sort 方法:多场景下的排序解决方案
  • MYSQL 第四次作业
  • 函数调用时参数是如何从右至左入栈的
  • MySQL-视图、存储过程和触发器
  • HTML常用的转义字符——怎么在网页中写“<div></div>”?
  • ArcGIS Pro SDK (九)几何 15 转换
  • 6.3 面向对象技术-设计模式
  • C++ - 基于多设计模式下的同步异步⽇志系统
  • 《Javascript高级程序设计 (第三版)》第五章 引用类型
  • 【Linux系统编程】快速查找errno错误码信息
  • 【附node操作实例】redis简明入门系列—字符串类型
  • 【译】React性能工程(下) -- 深入研究React性能调试
  • 【跃迁之路】【444天】程序员高效学习方法论探索系列(实验阶段201-2018.04.25)...
  • Android框架之Volley
  • Git 使用集
  • Javascript基础之Array数组API
  • nodejs调试方法
  • Object.assign方法不能实现深复制
  • Puppeteer:浏览器控制器
  • Python socket服务器端、客户端传送信息
  • session共享问题解决方案
  • Spring技术内幕笔记(2):Spring MVC 与 Web
  • text-decoration与color属性
  • Vue.js源码(2):初探List Rendering
  • vue:响应原理
  • 从@property说起(二)当我们写下@property (nonatomic, weak) id obj时,我们究竟写了什么...
  • 分享一个自己写的基于canvas的原生js图片爆炸插件
  • 分享自己折腾多时的一套 vue 组件 --we-vue
  • 机器人定位导航技术 激光SLAM与视觉SLAM谁更胜一筹?
  • 理清楚Vue的结构
  • 爬虫进阶 -- 神级程序员:让你的爬虫就像人类的用户行为!
  • 什么软件可以提取视频中的音频制作成手机铃声
  • Java总结 - String - 这篇请使劲喷我
  • 移动端高清、多屏适配方案
  • 正则表达式-基础知识Review
  • ​ssh免密码登录设置及问题总结
  • ​比特币大跌的 2 个原因
  • #pragam once 和 #ifndef 预编译头
  • #数据结构 笔记三
  • (附源码)php新闻发布平台 毕业设计 141646
  • (紀錄)[ASP.NET MVC][jQuery]-2 純手工打造屬於自己的 jQuery GridView (含完整程式碼下載)...
  • (经验分享)作为一名普通本科计算机专业学生,我大学四年到底走了多少弯路
  • (学习日记)2024.03.12:UCOSIII第十四节:时基列表
  • (一)kafka实战——kafka源码编译启动
  • (原創) 博客園正式支援VHDL語法著色功能 (SOC) (VHDL)
  • (转) SpringBoot:使用spring-boot-devtools进行热部署以及不生效的问题解决
  • (转)Oracle 9i 数据库设计指引全集(1)
  • (转)甲方乙方——赵民谈找工作