当前位置: 首页 > news >正文

pytorch 笔记:torch.optim(基类的基本操作)

1 基本用法

1.1 构建优化器

  • 要构建一个优化器,需要提供一个包含参数的可迭代对象来进行优化。
  • 然后,可以指定特定于优化器的选项,如学习率、权重衰减等
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)

1.2 按参数设置选项

  • 不是传递一个 Variable 的可迭代对象,而是传递一个字典的可迭代对象
  • 每个字典定义一个单独的参数组,并且应包含一个 params 键,包含属于该组的参数列表
optim.SGD([{'params': model.base.parameters(), 'lr': 1e-2},{'params': model.classifier.parameters()}
], lr=1e-3, momentum=0.9)
  • 对于 model.base.parameters(),设置了学习率为 1e-2【相当于覆盖了默认的1e-3】
  • 对于 model.classifier.parameters(),没有单独设置学习率,因此它会使用构造 optim.SGD 时设定的默认学习率,这里是 1e-3
  • 对于这边所有的参数,都有0.9的动量

另一个例子:

bias_params = [p for name, p in self.named_parameters() if 'bias' in name]
others = [p for name, p in self.named_parameters() if 'bias' not in name]
#把参数中的偏置和非偏置分开optim.SGD([{'params': others},{'params': bias_params, 'weight_decay': 0}
], weight_decay=1e-2, lr=1e-2)
#为偏置项设置了 0 的 weight_decay

1.3 进行优化步骤

所有优化器都实现了一个 step() 方法,用于更新参数

optimizer.step()

这是大多数优化器支持的简化版本。一旦使用例如 backward() 计算了梯度,就可以调用该函数。

2 add_param_group

  • torch.optim.Optimizer.add_param_group 方法是一个很有用的功能,允许你在训练过程中向优化器添加新的参数组。
  • 这在微调预训练网络时特别有用,例如当你决定解冻某些层,让它们在训练过程中变得可训练,并需要被优化器管理
    • 当你开始训练时,可能有些网络层是冻结的(不更新权重)。随着训练的进行,你可能想要解冻一些层以进行微调
    • 通过使用 add_param_group 方法,可以在不重启训练和不重新创建优化器的情况下,将这些新解冻的层添加到优化器中进行优化。

参数:

param_group (dict)

这个字典指定了应该被优化的张量(Tensors),以及与该参数组相关的特定优化选项。

这些选项可以包括学习率 (lr)、动量 (momentum)、权重衰减 (weight_decay) 等

举例:

假设有一个已经训练的模型,现在在特定的epoch之后,要解冻最后几层以进行微调


# 假设model是你的模型,model.classifier是需要最后解冻的部分optimizer = optim.SGD([param for name, param in model.named_parameters() if 'classifier' not in name], lr=0.01, momentum=0.9)
# 初始化优化器,先只包括除classifier之外的其他部分# 训练模型
num_epochs = 20
for epoch in range(num_epochs):if epoch == 10:   for param in model.classifier.parameters():param.requires_grad = True# 到达第10个epoch,解冻classifier部分的参数optimizer.add_param_group({'params': model.classifier.parameters(), 'lr': 0.001})# 添加新的参数组到优化器for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()# 正常的一套训练过程

3 state_dict

  • torch.optim.Optimizer.state_dict() 方法返回一个字典,描述了优化器的当前状态。
  • 这是PyTorch中模型保存和加载机制的重要组成部分,特别是在需要暂停和恢复训练过程时
  • 返回的字典包含两个主要条目:

state
  • 一个字典,保存当前优化状态。
  • 内容因不同的优化器类别而异,但有一些共同的特征。
    • 例如,状态是按参数保存的,而参数本身并不被保存。
    • 这个字典将参数ID映射到与每个参数相关的状态字典。
param_groups
  • 一个列表,包含所有参数组,每个参数组都是一个字典。
  • 每个参数组包含了特定于优化器的元数据,如学习率和权重衰减,以及该组中参数的参数ID列表
{"state": {0: {"momentum_buffer": tensor(...), ...},1: {"momentum_buffer": tensor(...), ...},2: {"momentum_buffer": tensor(...), ...},3: {"momentum_buffer": tensor(...), ...}},"param_groups": [{"lr": 0.01,"weight_decay": 0,..."params": [0]},{"lr": 0.001,"weight_decay": 0.5,..."params": [1, 2, 3]}]
}

4 load_state_dict

  • torch.optim.Optimizer.load_state_dict 方法用于加载优化器的状态。这对于恢复训练过程或从先前的训练状态继续训练非常有用。
# 假设在训练过程中,我们保存了优化器的状态
saved_state_dict = optimizer.state_dict()# 现在我们需要重新加载优化器的状态
optimizer.load_state_dict(saved_state_dict)

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 【Mode Management】ECU上下电过程CanSM为什么会多次设置CandTrcv和CanController模式
  • MySQL学习作业二
  • 计算机组成原理面试知识点总结1
  • git使用以及理解
  • CSPVD 智慧工地安全帽安全背心检测开发包
  • 代码随想录学习 day54 图论 Bellman_ford 队列优化算法(又名SPFA) 学习
  • WebKit 引擎:CSS 悬停效果的魔法师
  • “论系统安全架构设计及其应用”,写作框架,软考高级论文,系统架构设计师论文
  • Grafana :利用Explore方式实现多条件查询
  • python基础语法 007 文件操作-2文件支持模式文件的内置函数
  • 数据库基础与安装MYSQL数据库
  • 解决云服务器CPU占用率接近100%问题
  • 二叉树基础及实现(一)
  • Java 写一个可以持续发送消息的socket服务端
  • c++初阶篇(三):内联函数及auto关键字
  • 分享的文章《人生如棋》
  • Android 架构优化~MVP 架构改造
  • GDB 调试 Mysql 实战(三)优先队列排序算法中的行记录长度统计是怎么来的(上)...
  • IndexedDB
  • js操作时间(持续更新)
  • JS正则表达式精简教程(JavaScript RegExp 对象)
  • JWT究竟是什么呢?
  • Laravel 菜鸟晋级之路
  • mysql中InnoDB引擎中页的概念
  • Next.js之基础概念(二)
  • PHP 程序员也能做的 Java 开发 30分钟使用 netty 轻松打造一个高性能 websocket 服务...
  • rabbitmq延迟消息示例
  • vue-loader 源码解析系列之 selector
  • 包装类对象
  • 构建工具 - 收藏集 - 掘金
  • 技术攻略】php设计模式(一):简介及创建型模式
  • 坑!为什么View.startAnimation不起作用?
  • 前端知识点整理(待续)
  • 深入浅出webpack学习(1)--核心概念
  • 用Visual Studio开发以太坊智能合约
  • 第二十章:异步和文件I/O.(二十三)
  • ​RecSys 2022 | 面向人岗匹配的双向选择偏好建模
  • $nextTick的使用场景介绍
  • (09)Hive——CTE 公共表达式
  • (27)4.8 习题课
  • (C语言)字符分类函数
  • (二)七种元启发算法(DBO、LO、SWO、COA、LSO、KOA、GRO)求解无人机路径规划MATLAB
  • (二)延时任务篇——通过redis的key监听,实现延迟任务实战
  • (论文阅读30/100)Convolutional Pose Machines
  • (四)linux文件内容查看
  • (限时免费)震惊!流落人间的haproxy宝典被找到了!一切玄妙尽在此处!
  • (学习日记)2024.04.10:UCOSIII第三十八节:事件实验
  • (一)Dubbo快速入门、介绍、使用
  • (转)C语言家族扩展收藏 (转)C语言家族扩展
  • (轉)JSON.stringify 语法实例讲解
  • (自用)仿写程序
  • .MyFile@waifu.club.wis.mkp勒索病毒数据怎么处理|数据解密恢复
  • .NET 4.0网络开发入门之旅-- 我在“网” 中央(下)
  • .net core 连接数据库,通过数据库生成Modell
  • .NET Core 网络数据采集 -- 使用AngleSharp做html解析