当前位置: 首页 > news >正文

PyTorch 的 torch.nn 模块学习

torch.nn 是 PyTorch 中专门用于构建和训练神经网络的模块。它的整体架构分为几个主要部分,每部分的原理、要点和使用场景如下:

1. nn.Module

  • 原理和要点nn.Module 是所有神经网络组件的基类。任何神经网络模型都应该继承 nn.Module,并实现其 forward 方法。
  • 使用场景:用于定义和管理神经网络模型,包括层、损失函数和自定义的前向传播逻辑。
  • 主要 API 和使用场景
    __init__: 初始化模型参数。
    forward: 定义前向传播逻辑。
    parameters: 返回模型的所有参数。
import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear = nn.Linear(10, 1)def forward(self, x):return self.linear(x)model = MyModel()
print(model)

2. Layers(层)

  • 原理和要点:层是神经网络的基本构建块,包括全连接层、卷积层、池化层等。每种层执行特定类型的操作,并包含可学习的参数。
  • 使用场景:用于构建神经网络的各个组成部分,如特征提取、降维等。
2.1 nn.Linear(全连接层)
linear = nn.Linear(10, 5)
input = torch.randn(1, 10)
output = linear(input)
print(output)
2.2 nn.Conv2d(二维卷积层)
conv = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
input = torch.randn(1, 1, 5, 5)
output = conv(input)
print(output)
2.3 nn.MaxPool2d(二维最大池化层)
maxpool = nn.MaxPool2d(kernel_size=2)
input = torch.randn(1, 1, 4, 4)
output = maxpool(input)
print(output)

3. Loss Functions(损失函数)

  • 原理和要点:损失函数用于衡量模型预测与真实值之间的差异,指导模型优化过程。
  • 使用场景:用于计算训练过程中需要最小化的误差。
3.1 nn.MSELoss(均方误差损失)
mse_loss = nn.MSELoss()
input = torch.randn(3, 5)
target = torch.randn(3, 5)
loss = mse_loss(input, target)
print(loss)
3.2 nn.CrossEntropyLoss(交叉熵损失)
cross_entropy_loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])
loss = cross_entropy_loss(input, target)
print(loss)

4. Optimizers(优化器)

  • 原理和要点:优化器用于调整模型参数,以最小化损失函数。
  • 使用场景:用于训练模型,通过反向传播更新参数。
4.1 torch.optim.SGD(随机梯度下降)
import torch.optim as optimmodel = MyModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()# Training loop
for epoch in range(100):optimizer.zero_grad()output = model(torch.randn(1, 10))loss = criterion(output, torch.randn(1, 1))loss.backward()optimizer.step()
4.2 torch.optim.Adam(自适应矩估计)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Training loop
for epoch in range(100):optimizer.zero_grad()output = model(torch.randn(1, 10))loss = criterion(output, torch.randn(1, 1))loss.backward()optimizer.step()

5. Activation Functions(激活函数)

  • 原理和要点:激活函数引入非线性,使模型能够拟合复杂的函数。
  • 使用场景:用于激活输入,增加模型表达能力。
5.1 nn.ReLU(修正线性单元)
relu = nn.ReLU()
input = torch.randn(2)
output = relu(input)
print(output)

6. Normalization Layers(归一化层)

  • 原理和要点:归一化层用于标准化输入,改善训练的稳定性和速度。
  • 使用场景:用于标准化激活值,防止梯度爆炸或消失。
6.1 nn.BatchNorm2d(二维批量归一化)
batch_norm = nn.BatchNorm2d(3)
input = torch.randn(1, 3, 5, 5)
output = batch_norm(input)
print(output)

7. Dropout Layers(丢弃层)

  • 原理和要点:Dropout 层通过在训练过程中随机丢弃一部分神经元来防止过拟合。
  • 使用场景:用于防止模型过拟合,增加模型的泛化能力。
7.1 nn.Dropout
dropout = nn.Dropout(p=0.5)
input = torch.randn(2, 3)
output = dropout(input)
print(output)

8. Container Modules(容器模块)

  • 原理和要点:容器模块用于组合多个层,构建复杂的神经网络结构。
  • 使用场景:用于组合多个层,形成更复杂的网络结构。
8.1 nn.Sequential(顺序容器)
model = nn.Sequential(nn.Linear(10, 20),nn.ReLU(),nn.Linear(20, 5)
)
input = torch.randn(1, 10)
output = model(input)
print(output)
8.2 nn.ModuleList(模块列表)
layers = nn.ModuleList([nn.Linear(10, 20),nn.ReLU(),nn.Linear(20, 5)
])input = torch.randn(1, 10)
for layer in layers:input = layer(input)
print(input)

9. Functional API (torch.nn.functional)

  • 原理和要点:包含大量用于深度学习的无状态函数,这些函数通常是操作层的底层实现。
  • 使用场景:用于在前向传播中灵活调用函数。
9.1 F.relu(ReLU 激活函数)
import torch.nn.functional as Finput = torch.randn(2)
output = F.relu(input)
print(output)
9.2 F.cross_entropy(交叉熵损失函数)
input = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])
loss = F.cross_entropy(input, target)
print(loss)
9.3 F.conv2d(二维卷积)
input = torch.randn(1, 1, 5, 5)
weight = torch.randn(3, 1, 3, 3)  # Manually defined weights
output = F.conv2d(input, weight)
print(output)

10. Parameter (torch.nn.Parameter)

  • 原理和要点torch.nn.Parametertorch.Tensor 的一种特殊子类,用于表示模型的可学习参数。它们在 nn.Module 中会自动注册为参数。
  • 使用场景:用于定义模型中的可学习参数。
示例代码:
class MyModelWithParam(nn.Module):def __init__(self):super(MyModelWithParam, self).__init__()self.my_param = nn.Parameter(torch.randn(10, 10))def forward(self, x):return x @ self.my_parammodel = MyModelWithParam()
input = torch.randn(1, 10)
output = model(input)
print(output)# 查看模型参数
for name, param in model.named_parameters():print(name, param.size())

综合示例

下面是一个结合上述各个部分的综合示例:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimclass MyComplexModel(nn.Module):def __init__(self):super(MyComplexModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3)self.bn1 = nn.BatchNorm2d(32)self.conv2 = nn.Conv2d(32, 64, kernel_size=3)self.bn2 = nn.BatchNorm2d(64)self.dropout = nn.Dropout(0.25)self.fc1 = nn.Linear(64*12*12, 128)self.fc2 = nn.Linear(128, 10)self.custom_param = nn.Parameter(torch.randn(128, 128))def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = F.max_pool2d(x, 2)x = F.relu(self.bn2(self.conv2(x)))x = F.max_pool2d(x, 2)x = self.dropout(x)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = x @ self.custom_paramx = self.fc2(x)return F.log_softmax(x, dim=1)model = MyComplexModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(10):optimizer.zero_grad()input = torch.randn(64, 1, 28, 28)target = torch.randint(0, 10, (64,))output = model(input)loss = criterion(output, target)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')

通过以上示例,可以更清晰地理解 torch.nn 模块的整体架构、原理、要点及其具体使用场景。

相关文章:

  • 正则表达式----IP地址合法性判断
  • 啵啵啵啵啵啵啵啵啵啵啵啵啵啵啵
  • Java面试——中间件
  • 嵌入式Linux系统编程 — 2.1 标准I/O库简介
  • cs与msf权限传递
  • 最大矩形问题
  • 如何给 MySQL 表和列授予权限?(官方版)
  • HBuilderX编写APP一、获取token
  • Polar Web【简单】upload1
  • 【Meetup】探索Apache SeaTunnel的二次开发与实战案例
  • 数据结构初阶 堆(一)
  • PostgreSQL的视图pg_class
  • 【STM32】STM32F103C6T6标准外设库
  • 【傅里叶变换】 关于 Matlab ifft(Y,n) 在C#中实现遇到的问题
  • YOLOv8---seg实例分割(制作数据集,训练模型,预测结果)
  • 分享一款快速APP功能测试工具
  • 【翻译】Mashape是如何管理15000个API和微服务的(三)
  • CSS实用技巧干货
  • ES6 ...操作符
  • jquery ajax学习笔记
  • Laravel5.4 Queues队列学习
  • Linux编程学习笔记 | Linux多线程学习[2] - 线程的同步
  • Linux中的硬链接与软链接
  • Quartz初级教程
  • RedisSerializer之JdkSerializationRedisSerializer分析
  • springboot_database项目介绍
  • SQL 难点解决:记录的引用
  • 成为一名优秀的Developer的书单
  • 基于webpack 的 vue 多页架构
  • 如何用Ubuntu和Xen来设置Kubernetes?
  • 什么软件可以提取视频中的音频制作成手机铃声
  • 使用Maven插件构建SpringBoot项目,生成Docker镜像push到DockerHub上
  • 数据结构java版之冒泡排序及优化
  • 因为阿里,他们成了“杭漂”
  • PostgreSQL之连接数修改
  • ​​快速排序(四)——挖坑法,前后指针法与非递归
  • ​LeetCode解法汇总307. 区域和检索 - 数组可修改
  • ​埃文科技受邀出席2024 “数据要素×”生态大会​
  • #NOIP 2014# day.1 生活大爆炸版 石头剪刀布
  • #我与Java虚拟机的故事#连载15:完整阅读的第一本技术书籍
  • (1)Map集合 (2)异常机制 (3)File类 (4)I/O流
  • (k8s中)docker netty OOM问题记录
  • (poj1.2.1)1970(筛选法模拟)
  • (阿里云万网)-域名注册购买实名流程
  • (二)linux使用docker容器运行mysql
  • (二)学习JVM —— 垃圾回收机制
  • (附源码)ssm教材管理系统 毕业设计 011229
  • (附源码)计算机毕业设计高校学生选课系统
  • (每日持续更新)jdk api之StringBufferInputStream基础、应用、实战
  • (译) 理解 Elixir 中的宏 Macro, 第四部分:深入化
  • (转)ObjectiveC 深浅拷贝学习
  • (转)平衡树
  • (转)四层和七层负载均衡的区别
  • .java 9 找不到符号_java找不到符号
  • .NET HttpWebRequest、WebClient、HttpClient