当前位置: 首页 > news >正文

pytorch中一些最基本函数和类

1.Tensor操作

Tensor是PyTorch中最基本的数据结构,类似于NumPy的数组,但可以在GPU上运行加速计算。

  示例:创建和操作Tensor

import torch# 创建一个零填充的Tensor
x = torch.zeros(3, 3)
print(x)# 加法操作
y = torch.ones(3, 3)
z = x + y
print(z)# 在GPU上创建Tensor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.zeros(3, 3, device=device)
print(x)
运行结果:

2. nn.Module和自定义模型

  nn.Module是PyTorch中定义神经网络模型的基类,所有的自定义模型都应该继承自它。

示例:定义一个简单的全连接神经网络模型

import torch
import torch.nn as nn# 自定义模型类
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc = nn.Linear(10, 5)  # 线性层:输入维度为10,输出维度为5def forward(self, x):x = self.fc(x)return x# 创建模型实例
model = SimpleNet()
print(model)
运行结果:

3. DataLoader和Dataset

 DataLoader用于批量加载数据Dataset定义了数据集的接口,自定义数据集需继承自它。

示例:加载自定义数据集

import torch
from torch.utils.data import Dataset, DataLoader# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, data, targets):self.data = dataself.targets = targetsdef __len__(self):return len(self.data)def __getitem__(self, index):x = self.data[index]y = self.targets[index]return x, y# 假设有一些数据和标签
data = torch.randn(100, 10)  # 100个样本,每个样本10维
targets = torch.randint(0, 2, (100,))  # 100个随机标签,0或1# 创建数据集实例
dataset = CustomDataset(data, targets)# 创建数据加载器
batch_size = 10
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 打印一个batch的数据
for batch in dataloader:inputs, labels = batchprint(inputs.shape, labels.shape)break
运行结果: 

4. 优化器和损失函数

   优化器用于更新模型参数以减少损失,损失函数用于计算预测值与实际值之间的差异。

示例:使用优化器和损失函数

import torch
import torch.nn as nn
import torch.optim as optim# 定义模型(假设已定义好)
model = SimpleNet()# 定义损失函数
criterion = nn.CrossEntropyLoss()# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)# 前向传播、损失计算、反向传播和优化过程请参考前面完整示例的训练循环部分。
运行结果: 

5. nn.functional中的函数

  nn.functional提供了各种用于构建神经网络的函数,如激活函数池化操作等。

示例:使用ReLU激活函数

import torch
import torch.nn.functional as F# 创建一个Tensor
x = torch.randn(3, 3)# 使用ReLU激活函数
output = F.relu(x)
print(output)
运行结果: 

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 集群架构-web服务器(接入负载均衡+数据库+会话保持redis)--15454核心配置详解
  • 华为USG6000V防火墙安全策略用户认证
  • 01.Verilog基础语法
  • js中scrollIntoView第一次不生效,第二次生效
  • Linux C++ 055-设计模式之状态模式
  • React 的生命周期方法有哪些?
  • [论文笔记]构建基于RAG聊天机器人的要素
  • pycharm2020 相比pycarm2017更新内容
  • redis安装,启动客户端、验证(redis第一次作业)
  • 深入Laravel的魔法核心:依赖注入的工作原理
  • 智慧煤矿:AI视频智能监管解决方案引领行业新变革
  • 【Java】:浅克隆和深克隆
  • Java设计模式的7个设计原则
  • [计算机基础]一、计算机组成原理
  • 在 Windows 上开发.NET MAUI 应用_1.安装开发环境
  • [微信小程序] 使用ES6特性Class后出现编译异常
  • 【前端学习】-粗谈选择器
  • CAP 一致性协议及应用解析
  • CentOS从零开始部署Nodejs项目
  • css布局,左右固定中间自适应实现
  • Flex布局到底解决了什么问题
  • Javascript基础之Array数组API
  • Promise初体验
  • Redis字符串类型内部编码剖析
  • Vim Clutch | 面向脚踏板编程……
  • Vue--数据传输
  • 前端js -- this指向总结。
  • 前端代码风格自动化系列(二)之Commitlint
  • 跳前端坑前,先看看这个!!
  • 优化 Vue 项目编译文件大小
  • 你对linux中grep命令知道多少?
  • 《码出高效》学习笔记与书中错误记录
  • 小白应该如何快速入门阿里云服务器,新手使用ECS的方法 ...
  • ​Linux Ubuntu环境下使用docker构建spark运行环境(超级详细)
  • ​一帧图像的Android之旅 :应用的首个绘制请求
  • #Datawhale AI夏令营第4期#AIGC文生图方向复盘
  • #pragma once
  • (C#)一个最简单的链表类
  • (cljs/run-at (JSVM. :browser) 搭建刚好可用的开发环境!)
  • (Demo分享)利用原生JavaScript-随机数-实现做一个烟花案例
  • (Java入门)抽象类,接口,内部类
  • (ResultSet.TYPE_SCROLL_INSENSITIVE,ResultSet.CONCUR_READ_ONLY)讲解
  • (ZT)北大教授朱青生给学生的一封信:大学,更是一个科学的保证
  • (阿里巴巴 dubbo,有数据库,可执行 )dubbo zookeeper spring demo
  • (力扣记录)235. 二叉搜索树的最近公共祖先
  • (三)Honghu Cloud云架构一定时调度平台
  • (三)mysql_MYSQL(三)
  • (转)iOS字体
  • (转)LINQ之路
  • (转)mysql使用Navicat 导出和导入数据库
  • (转)微软牛津计划介绍——屌爆了的自然数据处理解决方案(人脸/语音识别,计算机视觉与语言理解)...
  • (转载)Linux网络编程入门
  • (转载)PyTorch代码规范最佳实践和样式指南
  • **PHP分步表单提交思路(分页表单提交)
  • .bat批处理(三):变量声明、设置、拼接、截取