当前位置: 首页 > news >正文

【PyTorch】(二)加载数据集

文章目录

  • 1. 创建数据集
    • 1.1. 直接继承Dataset类
    • 1.2. 使用TensorDataset类
  • 2. 加载数据集
  • 3. 将数据转移到GPU

1. 创建数据集

主要是将数据集读入内存,并用Dataset类封装。

1.1. 直接继承Dataset类

必须要重写__getitem__方法,用于根据索引获得相应样本数据。必要时还可以重写__len__方法,用于返回数据集的大小。

from torch.utils.data import Datasetclass BostonHousingDataset(Dataset):"""定义波士顿房价数据集"""def __init__(self):self.data = np.load('../dataset/boston_housing/boston_housing.npz')def __getitem__(self, index):return self.data['x'][index], self.data['y'][index]def __len__(self):return self.data['x'].shape[0]

1.2. 使用TensorDataset类

将多个张量组合成一个数据集,要保证所有张量的第一个维度相等,保证每批样本数据格式相同。

import torch
from torch.utils.data import TensorDatasetdata = np.load('../dataset/boston_housing/boston_housing.npz')
X = torch.tensor(data['x'])
y = torch.tensor(data['y'])
dataset = TensorDataset(X, y)

2. 加载数据集

使用DataLoader类将Dataset封装的数据集分成批次并进行迭代,以便于模型训练。DataLoader常用参数如下:

  • dataset
    要加载的数据集。
  • batch_size
    每个数据批次中包含的样本数。默认为1。
  • shuffle
    是否打乱数据集。默认为False。
  • num_workers
    使用几个进程来加载数据。默认为0,即在主进程中加载数据。
  • drop_last
    当数据集样本数不能被batch_size整除时,是否舍弃最后一个不完整的batch。默认为False。
from torch.utils.data import DataLoaderdataloader = DataLoader(dataset, batch_size=16, shuffle=True)

3. 将数据转移到GPU

一般在要运算时才将数据转移到GPU,有以下两种方法:

  1. var.to(device)
  2. var.cuda()
import torchdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for X,y in dataloader:# 将数据转移到GPUX = X.to(device)y = y.to(device)# 也可以X = X.cuda()y = y.cuda()

相关文章:

  • 如何使用内网穿透实现无公网ip环境访问VScode远程开发
  • pip安装、更新、卸载
  • CTA-GAN:基于生成对抗性网络的主动脉和颈动脉非集中CT血管造影 CT到增强CT的合成技术
  • Java中xml映射文件是干什么的
  • 开闭原则:提高扩展性的小技巧
  • 计算机视觉面试题-03
  • LeetCode算法题解(动态规划,背包问题)|LeetCode416. 分割等和子集
  • 【Spring之AOP底层源码解析】
  • vue 表格虚拟滚动
  • Web3之L2 ZK-Rollup 方案-StarkNet
  • 【Lustre相关】功能实践-01-Lustre集群部署配置
  • 鸿蒙4.0开发笔记之ArkTS装饰器语法基础@Extend扩展组件样式与stateStyles多态样式(十一)
  • FastDFS+Nginx - 本地搭建文件服务器同时实现在外远程访问「内网穿透」
  • 数据结构学习笔记——二叉树的遍历和链式存储代码实现二叉树
  • qt 容器QStringList的常见使用
  • CentOS学习笔记 - 12. Nginx搭建Centos7.5远程repo
  • ES2017异步函数现已正式可用
  • export和import的用法总结
  • interface和setter,getter
  • JS+CSS实现数字滚动
  • maya建模与骨骼动画快速实现人工鱼
  • ReactNative开发常用的三方模块
  • Spring思维导图,让Spring不再难懂(mvc篇)
  • vue+element后台管理系统,从后端获取路由表,并正常渲染
  • 从地狱到天堂,Node 回调向 async/await 转变
  • 动手做个聊天室,前端工程师百无聊赖的人生
  • 给新手的新浪微博 SDK 集成教程【一】
  • 马上搞懂 GeoJSON
  • 吐槽Javascript系列二:数组中的splice和slice方法
  • 微信小程序:实现悬浮返回和分享按钮
  • 正则表达式小结
  • ​香农与信息论三大定律
  • ​一、什么是射频识别?二、射频识别系统组成及工作原理三、射频识别系统分类四、RFID与物联网​
  • #etcd#安装时出错
  • #ifdef 的技巧用法
  • $con= MySQL有关填空题_2015年计算机二级考试《MySQL》提高练习题(10)
  • (6)设计一个TimeMap
  • (poj1.3.2)1791(构造法模拟)
  • (仿QQ聊天消息列表加载)wp7 listbox 列表项逐一加载的一种实现方式,以及加入渐显动画...
  • (附源码)springboot学生选课系统 毕业设计 612555
  • (附源码)ssm经济信息门户网站 毕业设计 141634
  • (附源码)ssm考生评分系统 毕业设计 071114
  • (附源码)计算机毕业设计大学生兼职系统
  • (转)ObjectiveC 深浅拷贝学习
  • (转)编辑寄语:因为爱心,所以美丽
  • *++p:p先自+,然后*p,最终为3 ++*p:先*p,即arr[0]=1,然后再++,最终为2 *p++:值为arr[0],即1,该语句执行完毕后,p指向arr[1]
  • .apk文件,IIS不支持下载解决
  • .net 4.0 A potentially dangerous Request.Form value was detected from the client 的解决方案
  • .Net 代码性能 - (1)
  • .net 写了一个支持重试、熔断和超时策略的 HttpClient 实例池
  • .Net环境下的缓存技术介绍
  • .net最好用的JSON类Newtonsoft.Json获取多级数据SelectToken
  • // an array of int
  • @cacheable 是否缓存成功_Spring Cache缓存注解
  • @cacheable 是否缓存成功_让我们来学习学习SpringCache分布式缓存,为什么用?