当前位置: 首页 > news >正文

parameters()函数 --- 获取模型参数量

  parameters() 函数是 PyTorch 中 torch.nn.Module 类的一个方法,用于返回模型中所有可训练的参数。下面是对这个函数的详细解释:

1. parameters() 方法工作机制

parameters() 方法工作机制:定义一个模型,通常会将多个层(如卷积层、线性层等)组合在一起,这些层就是主模块的子模块。而这些子模块中有些也可能包含自己的子模块,形成一个递归的层次结构。parameters() 方法会自动遍历整个层次结构,获取每个模块和子模块中的可训练参数。

2. 返回值

  • 返回一个迭代器,其中包含所有可训练的参数。

 3. 示例

import torch
import torch.nn as nn# 定义一个包含多个子模块的模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)  # 子模块1self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)  # 子模块2self.fc = nn.Linear(32 * 16 * 16, 10)  # 子模块3def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)  # Flattening for the fully connected layerx = self.fc(x)return x# 实例化模型
model = MyModel()# 获取模型的参数
for param in model.parameters():print(param.size())

      上述代码中Conv2d函数bias没有指定值,则取默认值bias=True,也就是说,在上述代码中,bias是存在的。

分析:

  • MyModel 包含了 3 个层:两个卷积层(conv1conv2)以及一个全连接层(fc)。
  • 这些层都是 MyModel子模块
  • 当你调用 model.parameters() 时,它不仅返回 MyModel 自己的参数,还会递归地返回所有子模块(即 conv1conv2fc)的参数。

3. 递归参数获取

parameters() 函数递归地遍历模块中的所有子模块,获取每个模块的参数。每个 nn.Module 对象的 parameters() 方法会遍历它自己和它的所有子模块,将所有可训练参数打包在一起。

你可以通过以下代码验证这一点:

# 打印每个模块的参数名和大小 
for name, param in model.named_parameters(): print(f"参数名: {name}, 大小: {param.size()}")

输出:

参数名: conv1.weight, 大小: torch.Size([16, 3, 3, 3]) 
参数名: conv1.bias, 大小: torch.Size([16]) 
参数名: conv2.weight, 大小: torch.Size([32, 16, 3, 3]) 
参数名: conv2.bias, 大小: torch.Size([32]) 
参数名: fc.weight, 大小: torch.Size([10, 8192]) 
参数名: fc.bias, 大小: torch.Size([10])

这里你会看到,不仅 MyModel 中的卷积层 conv1conv2 的权重和偏置参数被列出,连全连接层 fc 的参数也被列出了。

4. 参数过滤(过滤出可训练的参数)

可以通过条件过滤来获取特定类型的参数,例如仅获取可训练的参数:

trainable_params = [p for p in model.parameters() if p.requires_grad]

5. 计算参数总量

可以结合 numel() 方法来计算模型的参数总量:

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 
print(f"总参数量: {total_params}")

在这个示例中,p.numel() 返回参数的元素数量,if p.requires_grad 确保只计算需要梯度的参数(即可训练的参数)。运行这个代码将输出模型的参数总量。 

 

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • ConcurrentHashMap的使用
  • 如何选择光伏业务监管系统软件
  • 2024.09.18 leetcode 每日一题
  • 排序算法C++
  • AWS EKS 中的负载均衡和 TLS 配置:全面指南
  • Matplotlib-数据可视化详解
  • QT| QT配置CUDA
  • R语言APSIM模型进阶应用与参数优化、批量模拟实践技术
  • C++(学习)2024.9.23
  • ubuntu如何进行切换内核版本全教程
  • LLM - 理解 多模态大语言模型(MLLM) 的 幻觉(Hallucination) 与相关技术 (七)
  • leetcode91. 解码方法,动态规划
  • 最新Kali Linux超详细安装教程(附镜像包)
  • 『C/C++』整型和字符串相互转换
  • itextsharp报错 PdfReader not opened with owner password
  • 【干货分享】SpringCloud微服务架构分布式组件如何共享session对象
  • 【跃迁之路】【733天】程序员高效学习方法论探索系列(实验阶段490-2019.2.23)...
  • 0x05 Python数据分析,Anaconda八斩刀
  • CAP理论的例子讲解
  • ES10 特性的完整指南
  • gitlab-ci配置详解(一)
  • IOS评论框不贴底(ios12新bug)
  • iOS筛选菜单、分段选择器、导航栏、悬浮窗、转场动画、启动视频等源码
  • JavaScript函数式编程(一)
  • JDK 6和JDK 7中的substring()方法
  • magento 货币换算
  • Mybatis初体验
  • PAT A1050
  • Phpstorm怎样批量删除空行?
  • Python3爬取英雄联盟英雄皮肤大图
  • Redis的resp协议
  • spring security oauth2 password授权模式
  • SQLServer插入数据
  • vue--为什么data属性必须是一个函数
  • Vue源码解析(二)Vue的双向绑定讲解及实现
  • Web Storage相关
  • web标准化(下)
  • 阿里云爬虫风险管理产品商业化,为云端流量保驾护航
  • 不发不行!Netty集成文字图片聊天室外加TCP/IP软硬件通信
  • 初识MongoDB分片
  • 初探 Vue 生命周期和钩子函数
  • 猴子数据域名防封接口降低小说被封的风险
  • 解决jsp引用其他项目时出现的 cannot be resolved to a type错误
  • 如何打造100亿SDK累计覆盖量的大数据系统
  • 深入 Nginx 之配置篇
  • 微信小程序实战练习(仿五洲到家微信版)
  • ​【数据结构与算法】冒泡排序:简单易懂的排序算法解析
  • ​1:1公有云能力整体输出,腾讯云“七剑”下云端
  • ​软考-高级-信息系统项目管理师教程 第四版【第23章-组织通用管理-思维导图】​
  • # windows 安装 mysql 显示 no packages found 解决方法
  • # 数仓建模:如何构建主题宽表模型?
  • ## 临床数据 两两比较 加显著性boxplot加显著性
  • ###51单片机学习(2)-----如何通过C语言运用延时函数设计LED流水灯
  • #pragma预处理命令
  • (javascript)再说document.body.scrollTop的使用问题