当前位置: 首页 > news >正文

【PyTorch】使用容器(Containers)进行网络层管理(Module)

文章目录

  • 前言
  • 一、Sequential
  • 二、ModuleList
  • 三、ModuleDict
  • 四、ParameterList & ParameterDict
  • 总结


前言

当深度学习模型逐渐变得复杂,在编写代码时便会遇到诸多麻烦,此时便需要Containers的帮助。Containers的作用是将一部分网络层模块化,从而更方便地管理和调用。本文介绍PyTorch库常用的nn.Sequential,nn.ModuleList,nn.ModuleDict容器以及nn.ParameterList & ParameterDict参数容器。


一、Sequential

Sequential是最为常用的容器,它的功能也十分简单直接-将多个网络层按照固定的顺序连接,从前往后依次执行。比如在AlexNet中,多次需要conv+relu+maxpool的组合,此时便可以将其放入Sequential容器,便于在forward中调用。
下面来看PyTorch官方代码示例:

model = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())# Using Sequential with OrderedDict. This is functionally the# same as the above codemodel = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(1,20,5)),('relu1', nn.ReLU()),('conv2', nn.Conv2d(20,64,5)),('relu2', nn.ReLU())]))

示例中展示了两种Sequential使用方法:1,直接串联各个网络层。2,使用OrderedDict为每个module取名。这两种方法是等效的。


二、ModuleList

"顾名思义"ModuleList的作用如同Python的列表,将各个层存入一个类似于List的结构中,从而可以利用索引来进行调用。
注意这里是类似于list的结构,那为什么我们不直接用list呢?
ModuleList是专门为Pytorch中的神经网络模块(即继承自nn.Module的类)设计的容器。它确保所有添加到其中的模块都会正确地注册到网络中,以便进行参数管理和梯度更新。当模型被保存或加载时,nn.ModuleList中的模块也会相应地被保存或加载。而Python的列表是一个通用的容器,可以存储任意类型的对象。它没有专门为神经网络模块设计,因此不会进行参数的自动注册或管理。
代码示例:

class MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])# self.linears = [nn.Linear(10, 10) for i in range(10)]    def forward(self, x):for sub_layer in self.linears:x = sub_layer(x)return x

三、ModuleDict

ModuleDict是一个类似python字典的容器,相比于ModuleList,它的优点在于可以利用名字来调用网络层,这就避免了必须记住网络层具体元素才能调用的麻烦。
代码示例:

 class MyModule2(nn.Module):def __init__(self):super(MyModule2, self).__init__()self.choices = nn.ModuleDict({'conv': nn.Conv2d(3, 16, 5),'pool': nn.MaxPool2d(3)})self.activations = nn.ModuleDict({'lrelu': nn.LeakyReLU(),'prelu': nn.PReLU()})def forward(self, x, choice, act):x = self.choices[choice](x)x = self.activations[act](x)return x

四、ParameterList & ParameterDict

除了Module有容器,Parameter也有容器。与ModuleList和ModuleDict类似的,Paramter也有List和Dict,使用方法一样。

class MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.params = nn.ParameterDict({'left': nn.Parameter(torch.randn(5, 10)),'right': nn.Parameter(torch.randn(5, 10))})def forward(self, x, choice):x = self.params[choice].mm(x)return x# ParameterListclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])def forward(self, x):# ParameterList can act as an iterable, or be indexed using intsfor i, p in enumerate(self.params):x = self.params[i // 2].mm(x) + p.mm(x)return x

这是专门为Pytorch中的参数(如权重和偏置)设计的容器。它确保添加到其中的参数会被正确地注册到网络中,以便进行参数管理和梯度更新。与module类似,参数容器中的参数也会被包含在网络的参数列表中,并在模型保存和加载时被正确处理。


总结

容器是pytorch框架对网络进行组织管理的实用工具,合理运用可以极大提高代码的可读性与可维护性。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 峟思投入式水位计的安全操作指南
  • AD元器件库中参数的设计
  • Java Spring Boot 项目中的密码加密与验证开发案例手册
  • FPGA技术赋能云数据中心:提高性能与效率
  • 数据治理与数据管理的区别:深入剖析与理解
  • [Go]-抢购类业务方案
  • Qt QSerialPort数据发送和接收DataComm
  • 对浏览器事件循环机制的理解
  • Redis 篇-深入了解基于 Redis 实现消息队列(比较基于 List 实现消息队列、基于 PubSub 发布订阅模型之间的区别)
  • JDBC简介与应用:Java数据库连接的核心概念和技术
  • 【Redis】Redis 典型应用 - 缓存 (Cache) 原理与策略
  • BuripSuiteProfessional 抓取HTTPS配置
  • Java实现简易计算器功能(idea)
  • day5 QT
  • 多级缓存的设计与实现
  • 《深入 React 技术栈》
  • 【知识碎片】第三方登录弹窗效果
  • Brief introduction of how to 'Call, Apply and Bind'
  • CSS魔法堂:Absolute Positioning就这个样
  • CSS实用技巧
  • gitlab-ci配置详解(一)
  • HTML5新特性总结
  • JavaSE小实践1:Java爬取斗图网站的所有表情包
  • node 版本过低
  • rc-form之最单纯情况
  • React中的“虫洞”——Context
  • Yii源码解读-服务定位器(Service Locator)
  • 第13期 DApp 榜单 :来,吃我这波安利
  • 翻译 | 老司机带你秒懂内存管理 - 第一部(共三部)
  • 复习Javascript专题(四):js中的深浅拷贝
  • 如何设计一个微型分布式架构?
  • 如何用vue打造一个移动端音乐播放器
  • 设计模式 开闭原则
  • 收藏好这篇,别再只说“数据劫持”了
  • 我有几个粽子,和一个故事
  • 小程序、APP Store 需要的 SSL 证书是个什么东西?
  • Java性能优化之JVM GC(垃圾回收机制)
  • 阿里云IoT边缘计算助力企业零改造实现远程运维 ...
  • 蚂蚁金服CTO程立:真正的技术革命才刚刚开始
  • #每日一题合集#牛客JZ23-JZ33
  • (3)Dubbo启动时qos-server can not bind localhost22222错误解决
  • (Java岗)秋招打卡!一本学历拿下美团、阿里、快手、米哈游offer
  • (zhuan) 一些RL的文献(及笔记)
  • (三)模仿学习-Action数据的模仿
  • (四)库存超卖案例实战——优化redis分布式锁
  • (一)【Jmeter】JDK及Jmeter的安装部署及简单配置
  • (原创) cocos2dx使用Curl连接网络(客户端)
  • (转)Scala的“=”符号简介
  • (转)项目管理杂谈-我所期望的新人
  • .java 9 找不到符号_java找不到符号
  • .L0CK3D来袭:如何保护您的数据免受致命攻击
  • .NET : 在VS2008中计算代码度量值
  • .NET Core 控制台程序读 appsettings.json 、注依赖、配日志、设 IOptions
  • .NET Core实战项目之CMS 第一章 入门篇-开篇及总体规划
  • .NET 常见的偏门问题