当前位置: 首页 > news >正文

Pytorch代码:打印模型每层的参数数量和总参数量

这个代码片段定义了一个函数 print_model_parameters,它的作用是打印每层的参数数量以及模型的总参数量。下面是对这个函数的详细解释,重点解释 named_parametersrequires_gradnumel 参数的含义:

# 打印每层的参数数量和总参数量
def print_model_parameters(model):total_params = 0for name, param in model.named_parameters():if param.requires_grad:print(f"{name}: {param.numel()} parameters")total_params += param.numel()print(f"For now parameters: {total_params}")print(f"Total parameters: {total_params}")

具体步骤和解释

  1. 定义和初始化

    def print_model_parameters(model):total_params = 0
    

    这个函数接收一个模型对象 model,并初始化一个变量 total_params 用于累积总参数量。

  2. 遍历模型参数

    for name, param in model.named_parameters():
    

    这里使用了 model.named_parameters() 方法,该方法返回一个生成器,生成模型中所有参数的名称和参数张量。它返回的是 (name, parameter) 形式的元组。

    • named_parameters:这是一个PyTorch模型的方法,它返回模型中所有参数的名称和参数本身。参数的名称是字符串类型,而参数是一个 torch.Tensor 对象。
  3. 判断参数是否需要梯度更新

    if param.requires_grad:
    

    每个参数张量都有一个 requires_grad 属性,这个属性是一个布尔值。如果 requires_gradTrue,表示这个参数在训练过程中需要计算梯度并进行更新。

    • requires_grad:这是一个布尔值属性,表示该参数是否需要在训练过程中计算梯度。如果是 True,则该参数会在反向传播时计算并存储梯度。
  4. 打印参数数量并累加

    print(f"{name}: {param.numel()} parameters")
    total_params += param.numel()
    print(f"For now parameters: {total_params}")
    

    对于需要梯度的参数,打印其名称和参数数量,并将该参数的数量累加到 total_params 中。

    • numel:这是一个方法,返回张量中所有元素的数量。例如,一个形状为 (3, 4) 的张量调用 numel() 方法会返回 12,因为这个张量有12个元素。
  5. 打印总参数量

    print(f"Total parameters: {total_params}")
    

    最后,打印模型的总参数数量。

总结

这个函数通过 model.named_parameters() 遍历模型的所有参数,检查每个参数的 requires_grad 属性,只有在 requires_gradTrue 时才计算并打印参数数量,同时累加总参数量。 numel() 方法用于获取每个参数张量的元素数量,从而帮助统计参数数量。最后打印总参数量,提供了对模型规模的一个直观了解。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 计算机基础(Windows 10+Office 2016)教程 —— 第5章 文档编辑软件Word 2016(下)
  • 机械学习—零基础学习日志(高数22——泰勒公式理解深化)
  • 初识云计算
  • AR眼镜:重型机械维修保养新利器
  • LVS-DR模式集群:案例与概念
  • vue项目部署在子路径中前端配置
  • 运输层 可靠数据传输原理——1、构造可靠数据传输协议
  • 尚硅谷谷粒商城项目笔记——五、使用docker安装mysql
  • 深入理解Vue slot的原理
  • K8S中Containerd之ctr和crictl简介以及常见操作
  • vs+qt项目转qt creator
  • 云原生真机实验
  • 高翔【自动驾驶与机器人中的SLAM技术】学习笔记(五)卡尔曼滤波器一:认知卡尔曼滤波器;协方差矩阵与方差;
  • 一个java类实现UDP代理转发
  • MySQL--查询数据
  • 实现windows 窗体的自己画,网上摘抄的,学习了
  • [译] 理解数组在 PHP 内部的实现(给PHP开发者的PHP源码-第四部分)
  • 【跃迁之路】【733天】程序员高效学习方法论探索系列(实验阶段490-2019.2.23)...
  • 30秒的PHP代码片段(1)数组 - Array
  • Android 控件背景颜色处理
  • Asm.js的简单介绍
  • centos安装java运行环境jdk+tomcat
  • chrome扩展demo1-小时钟
  • CODING 缺陷管理功能正式开始公测
  • js递归,无限分级树形折叠菜单
  • learning koa2.x
  • node和express搭建代理服务器(源码)
  • Phpstorm怎样批量删除空行?
  • PHP面试之三:MySQL数据库
  • SpiderData 2019年2月13日 DApp数据排行榜
  • STAR法则
  • vagrant 添加本地 box 安装 laravel homestead
  • vue总结
  • Xmanager 远程桌面 CentOS 7
  • 对象管理器(defineProperty)学习笔记
  • - 概述 - 《设计模式(极简c++版)》
  • 融云开发漫谈:你是否了解Go语言并发编程的第一要义?
  • 设计模式(12)迭代器模式(讲解+应用)
  • 微信端页面使用-webkit-box和绝对定位时,元素上移的问题
  • 学习Vue.js的五个小例子
  • 由插件封装引出的一丢丢思考
  • 主流的CSS水平和垂直居中技术大全
  • 关于Android全面屏虚拟导航栏的适配总结
  • ​io --- 处理流的核心工具​
  • "无招胜有招"nbsp;史上最全的互…
  • ######## golang各章节终篇索引 ########
  • #100天计划# 2013年9月29日
  • #define与typedef区别
  • (2024)docker-compose实战 (8)部署LAMP项目(最终版)
  • (ros//EnvironmentVariables)ros环境变量
  • (TOJ2804)Even? Odd?
  • (vue)el-checkbox 实现展示区分 label 和 value(展示值与选中获取值需不同)
  • (超简单)使用vuepress搭建自己的博客并部署到github pages上
  • (初研) Sentence-embedding fine-tune notebook
  • (代码示例)使用setTimeout来延迟加载JS脚本文件