当前位置: 首页 > news >正文

13.DataLoader 的使用

DataLoader 的使用

  • dataset:告诉程序中数据集的位置,数据集中索引,数据集中有多少数据(想象成一叠扑克牌)
  • dataloader:加载器,将数据加载到神经网络中,每次从dataset中取数据,通过dataloader中的参数可以设置如何取数据(想象成抓的一组牌)

torch.utils.data

参数介绍

参数如下(大部分有默认值,实际中只需要设置少量的参数即可):

  • dataset:只有dataset没有默认值,只需要将之前自定义的dataset实例化,再放到dataloader中即可
  • batch_size:每次抓牌抓几张
  • shuffle:打乱与否,值为True的话两次打牌时牌的顺序是不一样。默认为False,但一般用True
  • num_workers:加载数据时采用单个进程还是多个进程,多进程的话速度相对较快,默认为0(主进程加载)。Windows系统下该值>0会有问题(报错提示:BrokenPipeError)
  • drop_last:100张牌每次取3张,最后会余下1张,这时剩下的这张牌是舍去还是不舍去。值为True代表舍去这张牌、不取出,False代表要取出该张牌

image-20240718181740205

示例
import torchvision
from torch.utils.data import DataLoader# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
#加载测试数据集,batch_size=4即每次取4个数据集打包
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)

image-20240720154545533

输出结果:

torch.Size([3, 32, 32])   #三通道,32×32大小
3   #类别为3

image-20240720154614241

测试数据集CIFAR10中getitem返回的数据类型为img,target

image-20240720154909056

image-20240720155510287

dataset

__getitem()__:return img,target

dataloader(batch_size=4):从dataset中取4个数据

img0,target0 = dataset[0]
img1,target1 = dataset[1]
img2,target2 = dataset[2]
img3,target3 = dataset[3]

把 img 0-3 进行打包,记为imgs;target 0-3 进行打包,记为targets;作为dataloader中的返回

for data in test_loader:imgs,targets = dataprint(imgs.shape)print(targets)

image-20240720155804688

输出:

torch.Size([4, 3, 32, 32])   #4张图片,三通道,32×32
tensor([1, 1, 7, 3])  #4个target进行一个打包

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

数据是随机取的(断点debug一下,可以看到采样器sampler是随机采样的),所以两次的 target 0 并不一样

batch_size

# 用上节课torchvision提供的自定义的数据集
# CIFAR10原本是PIL Image,需要转换成tensorimport torchvision.datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# 准备的测试数据集
test_data = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor())# 加载测试集
test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=False)
#batch_size=4,意味着每次从test_data中取4个数据进行打包writer = SummaryWriter("dataloader")
step=0
for data in test_loader:imgs,targets = data  #imgs是tensor数据类型writer.add_images("test_data",imgs,step)step=step+1writer.close()

image-20240722014509468

运行后在 terminal 里输入:

 tensorboard --logdir="dataloader"

运行结果如图,滑动滑块即是每一次取数据时的batch_size张图片:

image-20240722014641511

由于 drop_last 设置为 False,所以最后16张图片(没有凑齐64张)显示如下:

image-20240722014707142

drop_last

若将 drop_last 设置为 True,最后16张图片(step 156)会被舍去,结果如图:

image-20240722014853921

shuffle

shuffle的作用:一个 for data in test_loader 循环,就意味着打完一轮牌(抓完一轮数据),在下一轮再进行抓取时,第二次数据是否与第一次数据一样。值为True的话,会重新洗牌(一般都设置为True)

shuffle为False的话两轮取的图片是一样的

在外面再套一层 for epoch in range(2) 的循环来验证一下

# shuffle为True
for epoch in range(2):step=0for data in test_loader:imgs,targets = data  #imgs是tensor数据类型writer.add_images("Epoch:{}".format(epoch),imgs,step)step=step+1

image-20240722015501306

shuffle为False结果如下:

可以看出两次 step 155 的图片一样

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

shuffle为True结果如下:

可以看出即使是同样的 step 155,两轮抓取的图片不一样

image-20240722015848282
出两次 step 155 的图片一样

[外链图片转存中…(img-F8G96Zxa-1724861448845)]

shuffle为True结果如下:

可以看出即使是同样的 step 155,两轮抓取的图片不一样

[外链图片转存中…(img-Aru5xvXY-1724861448846)]

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 三级_网络技术_52_应用题
  • 深度学习中Embedding的理解
  • pytorch的继承方法
  • 探索数字沙龙——文本描述生成数字人3D发型的专业工具
  • 数据结构与算法再探(二)高精度计算
  • ActiveMQ指南
  • SpringBoot项目目录介绍(SpringBoot学2)
  • pycharm中opencv-python和opencv-contrib安装
  • 【XR】SDK的接口规划与设计
  • K8S对接Ceph分布式存储
  • apache服务器的配置(服务名httpd,端口80 , 443)
  • Ubuntu安装交叉编译工具链(gcc-linaro-6.3.1-2017.05-x86_64_aarch64-linux-gnu)
  • 中文乱码解决方案
  • R语言论文插图模板第8期—特征渲染的散点图
  • YoloV8改进策略:主干网络改进|CAS-ViT在YoloV8中的创新应用与显著性能提升
  • 【Leetcode】104. 二叉树的最大深度
  • 【前端学习】-粗谈选择器
  • AHK 中 = 和 == 等比较运算符的用法
  • Angular 2 DI - IoC DI - 1
  • Docker 笔记(1):介绍、镜像、容器及其基本操作
  • JavaScript服务器推送技术之 WebSocket
  • Laravel 实践之路: 数据库迁移与数据填充
  • Logstash 参考指南(目录)
  • maven工程打包jar以及java jar命令的classpath使用
  • node 版本过低
  • Python连接Oracle
  • spark本地环境的搭建到运行第一个spark程序
  • SpiderData 2019年2月23日 DApp数据排行榜
  • Travix是如何部署应用程序到Kubernetes上的
  • 世界上最简单的无等待算法(getAndIncrement)
  • 适配mpvue平台的的微信小程序日历组件mpvue-calendar
  • 数据库写操作弃用“SELECT ... FOR UPDATE”解决方案
  • 突破自己的技术思维
  • 文本多行溢出显示...之最后一行不到行尾的解决
  • 学习笔记:对象,原型和继承(1)
  • No resource identifier found for attribute,RxJava之zip操作符
  • ​ArcGIS Pro 如何批量删除字段
  • ​Spring Boot 分片上传文件
  • #{}和${}的区别?
  • #HarmonyOS:基础语法
  • $.ajax中的eval及dataType
  • (145)光线追踪距离场柔和阴影
  • (MonoGame从入门到放弃-1) MonoGame环境搭建
  • (翻译)terry crowley: 写给程序员
  • (附源码)apringboot计算机专业大学生就业指南 毕业设计061355
  • (附源码)springboot 个人网页的网站 毕业设计031623
  • (附源码)计算机毕业设计SSM疫情社区管理系统
  • (三)docker:Dockerfile构建容器运行jar包
  • (四)js前端开发中设计模式之工厂方法模式
  • (四)Linux Shell编程——输入输出重定向
  • (完整代码)R语言中利用SVM-RFE机器学习算法筛选关键因子
  • (原創) 物件導向與老子思想 (OO)
  • (转)Unity3DUnity3D在android下调试
  • .net core 使用js,.net core 使用javascript,在.net core项目中怎么使用javascript
  • .NET 设计模式—简单工厂(Simple Factory Pattern)