当前位置: 首页 > news >正文

pytorch笔记:named_parameters

  • named_parameters 是 PyTorch 中一个非常有用的函数,用于访问模型中所有定义的参数及其对应的名称。
  • 它是 torch.nn.Module 类的方法之一,返回一个生成器,生成 (name, parameter) 对,name 是参数的名称,parameter 是对应的参数张量。

1 举例

1.0 创建模型


import torch
import torch.nn as nn# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 64, 5)self.fc1 = nn.Linear(64 * 4 * 4, 500)self.fc2 = nn.Linear(500, 10)def forward(self, x):x = torch.relu(self.conv1(x))x = torch.relu(self.conv2(x))x = x.view(-1, 64 * 4 * 4)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 实例化模型
model_tst = SimpleModel()

1.1 应用1:打印模型的所有参数及其名称

for name, param in model_tst.named_parameters():print(name, param.shape)'''
conv1.weight torch.Size([20, 1, 5, 5])
conv1.bias torch.Size([20])
conv2.weight torch.Size([64, 20, 5, 5])
conv2.bias torch.Size([64])
fc1.weight torch.Size([500, 1024])
fc1.bias torch.Size([500])
fc2.weight torch.Size([10, 500])
fc2.bias torch.Size([10])
conv1.weight torch.Size([20, 1, 5, 5])
conv1.bias torch.Size([20])
conv2.weight torch.Size([64, 20, 5, 5])
conv2.bias torch.Size([64])
fc1.weight torch.Size([500, 1024])
fc1.bias torch.Size([500])
fc2.weight torch.Size([10, 500])
fc2.bias torch.Size([10])
'''

1.2 应用2:冻结特定层的参数

假设我们只想训练全连接层,而冻结卷积层的参数:

for name, param in model_tst.named_parameters():if 'conv' in name:param.requires_grad = False

1.3 应用3:自定义优化器参数

可以使用 named_parameters 创建自定义的参数组,以便对不同的参数组应用不同的学习率:

optimizer = torch.optim.SGD([{'params': [param for name, param in model_tst.named_parameters() if 'conv' in name], 'lr': 0.01},{'params': [param for name, param in model_tst.named_parameters() if 'fc' in name], 'lr': 0.1}
], momentum=0.9)

相关文章:

  • springboot 集成阿里云 OSS
  • 41、web基础和http协议
  • SpringMVC系列二: 请求方式介绍
  • 电脑系统重装怎么操作?分享四个win10重装系统方法
  • 更改ip后还被封是ip质量的原因吗?
  • DDei在线设计器-API-DDeiSheet
  • Discuz动漫二次元风格网站模板
  • [经验] candy是什么意思英语翻译 #笔记#其他#职场发展
  • AIGC发展方向和前景
  • 晨持绪科技:开好一家抖音小店运营怎么做
  • 未来几年大多数人会面临的困境
  • 软件工程考试题备考
  • 应用分发也叫APP分发
  • mysql 查询排名,包括并列排名和连续排名
  • 阿里云PAI大模型评测最佳实践
  • 【108天】Java——《Head First Java》笔记(第1-4章)
  • 【RocksDB】TransactionDB源码分析
  • 2018以太坊智能合约编程语言solidity的最佳IDEs
  • Android路由框架AnnoRouter:使用Java接口来定义路由跳转
  • DataBase in Android
  • eclipse的离线汉化
  • ECS应用管理最佳实践
  • EventListener原理
  • Flannel解读
  • gitlab-ci配置详解(一)
  • IE报vuex requires a Promise polyfill in this browser问题解决
  • Invalidate和postInvalidate的区别
  • JavaScript实现分页效果
  • jquery ajax学习笔记
  • JS专题之继承
  • Laravel核心解读--Facades
  • Netty+SpringBoot+FastDFS+Html5实现聊天App(六)
  • react-core-image-upload 一款轻量级图片上传裁剪插件
  • SpiderData 2019年2月13日 DApp数据排行榜
  • Transformer-XL: Unleashing the Potential of Attention Models
  • 案例分享〡三拾众筹持续交付开发流程支撑创新业务
  • 半理解系列--Promise的进化史
  • 从@property说起(二)当我们写下@property (nonatomic, weak) id obj时,我们究竟写了什么...
  • 看域名解析域名安全对SEO的影响
  • 聊一聊前端的监控
  • 盘点那些不知名却常用的 Git 操作
  • 山寨一个 Promise
  • 无服务器化是企业 IT 架构的未来吗?
  • [地铁译]使用SSD缓存应用数据——Moneta项目: 低成本优化的下一代EVCache ...
  • 通过调用文摘列表API获取文摘
  • !!Dom4j 学习笔记
  • (3)(3.5) 遥测无线电区域条例
  • (笔试题)合法字符串
  • (二)springcloud实战之config配置中心
  • (转)关于如何学好游戏3D引擎编程的一些经验
  • *p++,*(p++),*++p,(*p)++区别?
  • .gitignore
  • .NET Core MongoDB数据仓储和工作单元模式封装
  • .NET Micro Framework初体验
  • .NET Micro Framework初体验(二)