当前位置: 首页 > news >正文

详细介绍pytorch重要的API

文章目录

      • 1. Tensor(张量)
        • 创建 Tensor
        • 操作 Tensor
      • 2. Autograd(自动求导)
      • 3. nn.Module(神经网络模块)
      • 4. Optimizer(优化器)
      • 5. Loss Function(损失函数)
      • 6. DataLoader(数据加载器)
      • 7. torchvision(视觉工具包)
        • 数据集
        • 模型
        • 图像转换
      • 8. torch.nn.functional(函数式接口)
      • 9. torch.utils.data(数据工具)
        • Dataset
        • DataLoader
      • 10. torch.cuda(GPU 支持)


1. Tensor(张量)

Tensor 是 PyTorch 中最基本的数据结构,提供了多种创建和操作张量的函数。

创建 Tensor
import torch# 创建一个 Tensor
x = torch.tensor([1.0, 2.0, 3.0])# 从 NumPy 数组创建 Tensor
import numpy as np
np_array = np.array([1.0, 2.0, 3.0])
x = torch.from_numpy(np_array)# 创建全零或全一 Tensor
x = torch.zeros(3, 4)
x = torch.ones(3, 4)# 创建随机 Tensor
x = torch.randn(3, 4)
操作 Tensor
# 基本操作
x = torch.tensor([1.0, 2.0, 3.0])
y = x + 2
z = x * y# 索引和切片
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(x[0, 1])  # 输出 2
print(x[:, 1])  # 输出 tensor([2, 5])

2. Autograd(自动求导)

Autograd 提供了自动求导功能,可以跟踪 Tensor 上的操作并计算梯度。

# 创建一个需要求导的 Tensor
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2
y.backward(torch.tensor([1.0, 1.0, 1.0]))
print(x.grad)  # 输出梯度

3. nn.Module(神经网络模块)

nn.Module 是构建神经网络的基础类,提供了定义网络层和前向传播的方法。

import torch.nn as nn
import torch.nn.functional as Fclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(10, 20)self.fc2 = nn.Linear(20, 1)def forward(self, x):x = F.relu(self.fc1(x))x = self.fc2(x)return xmodel = SimpleNN()

4. Optimizer(优化器)

优化器用于更新神经网络的参数,提供了多种优化算法。

import torch.optim as optimmodel = SimpleNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 优化步骤
optimizer.zero_grad()  # 清空梯度
loss = compute_loss(model, data)  # 计算损失
loss.backward()  # 反向传播
optimizer.step()  # 更新参数

5. Loss Function(损失函数)

损失函数用于衡量模型预测值与真实值之间的差异,提供了多种损失函数。

criterion = nn.MSELoss()# 计算损失
output = model(data)
loss = criterion(output, target)

6. DataLoader(数据加载器)

DataLoader 用于加载数据集,并提供批量加载、数据打乱等功能。

from torch.utils.data import DataLoader
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 迭代数据
for images, labels in train_loader:# 训练代码pass

7. torchvision(视觉工具包)

torchvision 提供了常用的数据集、模型架构和图像转换工具。

数据集
from torchvision import datasetstrain_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
模型
from torchvision import models# 加载预训练的 ResNet-50 模型
model = models.resnet50(pretrained=True)
图像转换
from torchvision import transformstransform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

8. torch.nn.functional(函数式接口)

torch.nn.functional 提供了一些常用的函数式操作,如激活函数、损失函数等。

import torch.nn.functional as Fx = torch.randn(10)
y = F.relu(x)

9. torch.utils.data(数据工具)

torch.utils.data 提供了数据加载和预处理的工具,包括 DatasetDataLoader

Dataset
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data, targets):self.data = dataself.targets = targetsdef __getitem__(self, index):return self.data[index], self.targets[index]def __len__(self):return len(self.data)
DataLoader
from torch.utils.data import DataLoaderdata = torch.randn(100, 10)
targets = torch.randint(0, 2, (100,))
dataset = CustomDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

10. torch.cuda(GPU 支持)

torch.cuda 提供了 GPU 计算的支持,包括设备管理、内存管理等。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 将数据移动到 GPU
images, labels = images.to(device), labels.to(device)

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 靠谱是性价比最高的社交名片:一个靠谱的人往往有这4种品质!
  • 算法的学习笔记—二叉树的镜像(牛客JZ27)
  • Spring 中ConfigurableBeanFactory
  • Redis的热key以及Big(大)key是什么?如何解决Redis的热key以及Big(大)key问题?
  • arcgis打开不同tif格式编码的栅格数据
  • 【卡码网Python基础课 21.图形的面积】
  • 高速信号的眼图、加重、均衡
  • Spire.PDF for .NET【文档操作】演示:检测 PDF 文件是否为 Portfolio
  • Airtest 的使用
  • 对比state和props的区别?
  • C语言——操作符详解
  • C++ STL sort_heap 用法
  • XSS---DOM破坏靶场复现
  • mybatisplus多数据源中关于不同类型的(mysql,oracle)数据库分页问题解决
  • 【Angular18】封装自定义组件
  • 【跃迁之路】【444天】程序员高效学习方法论探索系列(实验阶段201-2018.04.25)...
  • C++11: atomic 头文件
  • Docker下部署自己的LNMP工作环境
  • es6(二):字符串的扩展
  • GitUp, 你不可错过的秀外慧中的git工具
  • IE报vuex requires a Promise polyfill in this browser问题解决
  • java8 Stream Pipelines 浅析
  • JavaScript 基础知识 - 入门篇(一)
  • Linux各目录及每个目录的详细介绍
  • MySQL-事务管理(基础)
  • Ruby 2.x 源代码分析:扩展 概述
  • spring-boot List转Page
  • SSH 免密登录
  • 读懂package.json -- 依赖管理
  • 自动记录MySQL慢查询快照脚本
  • 7行Python代码的人脸识别
  • Semaphore
  • 如何用纯 CSS 创作一个货车 loader
  • 如何在 Intellij IDEA 更高效地将应用部署到容器服务 Kubernetes ...
  • 数据可视化之下发图实践
  • ​LeetCode解法汇总1410. HTML 实体解析器
  • # Redis 入门到精通(八)-- 服务器配置-redis.conf配置与高级数据类型
  • #define MODIFY_REG(REG, CLEARMASK, SETMASK)
  • #vue3 实现前端下载excel文件模板功能
  • #传输# #传输数据判断#
  • #在线报价接单​再坚持一下 明天是真的周六.出现货 实单来谈
  • (14)学习笔记:动手深度学习(Pytorch神经网络基础)
  • (附源码)计算机毕业设计ssm电影分享网站
  • (六)Hibernate的二级缓存
  • (四)Tiki-taka算法(TTA)求解无人机三维路径规划研究(MATLAB)
  • (原創) 如何使用ISO C++讀寫BMP圖檔? (C/C++) (Image Processing)
  • (转)Windows2003安全设置/维护
  • .net core Swagger 过滤部分Api
  • .Net CoreRabbitMQ消息存储可靠机制
  • .Net 基于.Net8开发的一个Asp.Net Core Webapi小型易用框架
  • .NET委托:一个关于C#的睡前故事
  • .net下的富文本编辑器FCKeditor的配置方法
  • .net知识和学习方法系列(二十一)CLR-枚举
  • /etc/sudoer文件配置简析
  • @RequestMapping 和 @GetMapping等子注解的区别及其用法