当前位置: 首页 > news >正文

PyTorch中DataLoader及其与enumerate()用法介绍

文章目录

    • DataLoader,何许类?
      • Map-style datasets
      • iterable-style datasets
      • Data loading order and sampling
      • Loading Batched and Non-Batched Data
      • Single- and Multi-process Data Loading
      • Memory Pinning
    • DataLoader、图片、张量关系
    • 批处理样本操作

DataLoader,何许类?

DataLoader隶属PyTorch中torch.utils.data下的一个类,官方文档如下介绍:

At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class. It represents a Python iterable over a dataset, with support for

  • map-style and iterable-style datasets,
  • customizing data loading order,
  • automatic batching,
  • single- and multi-process data loading,
  • automatic memory pinning.

Map-style datasets

map一词除了我们熟知的地图外,其实还有映射的意思。这一应用在我之前写过一篇基于参考点的非支配遗传算法-NSGA-III(一)中就提及过“映射”关系,大家可以自行去查看原为对于“映射”关系的英文描述。

在DataLoader中映射关系是表示的索引到数据之间的关系,其定义:实现
_ getitem_ () and len() protocol,且将data sample与indices/keys(可能是非整数)映射起来的dataset。例如dataset[idx]可读得第idx张图片和对应的label。

需要说明的是,任何继承torch.utils.data.Data类子类军需要重载_getitem_()及_len_()两个函数,且子类在init函数产生的数据路径,将作为DataLoader参数DataSets的实参。两者之间的关系我们将在下文代码中介绍。

iterable-style datasets

定义:为IterableDataset子类的一个实例,实现了__iter()__ protocol,并表示对data sample的迭代。这类dataset适用于对数据的random read开销较大或不合适时,且batch size取决于数据时。例如iter(dataset),可以返回从dataset或远程服务器等读到的数据流。

Data loading order and sampling

For iterable-style datasets, data loading order is entirely controlled by the user-defined iterable. This allows easier implementations of chunk-reading and dynamic batch size (e.g., by yielding a batched sample at each time).
也就是说可以很容易的实现批处理,是通过块来读数据的

The rest of this section concerns the case with map-style datasets. torch.utils.data.Sampler classes are used to specify the sequence of indices/keys used in data loading. They represent iterable objects over the indices to datasets. E.g., in the common case with stochastic gradient decent (SGD), a Sampler could randomly permute a list of indices and yield each one at a time, or yield a small number of them for mini-batch SGD.
torch.utils.data.Sampler类用于指定数据加载中使用的索引/键的顺序。它们代表数据集索引上的可迭代对象。例如,在SGD常见情况下,Sampler可以随机排列一列索引,一次生成每个索引,或者为小批量SGD生成少量索引。

A sequential or shuffled sampler will be automatically constructed based on the shuffle argument to a DataLoader. Alternatively, users may use the sampler argument to specify a custom Sampler object that at each time yields the next index/key to fetch.

DataLoader 的 shuffle 参数,将自动构造顺序或随机排序的采样器。
可以一次生成批量索引列表的自定义采样器作为batch_sampler参数。也可以通过batch_size和drop_last参数启用自动批处理。iterable-style datasets 不能和 sample/ batch_sample 一起使用, 因为iterable-style datasets 没有 index 和 key的概念。

Loading Batched and Non-Batched Data

DataLoader supports automatically collating individual fetched data samples into batches via arguments batch_size, drop_last, batch_sampler, and collate_fn (which has a default function).

Automatic batching (default)
This is the most common case, and corresponds to fetching a minibatch of data and collating them into batched samples, i.e., containing Tensors with one dimension being the batch dimension (usually the first).
包含一个批处理维,用来表示样本批处理后的大小,批处理后的样本称作“批处理样本”。 如一组样本又1600个,每个批处理包含八个样本,每个样本是大小为480*640的RGB图,则会生成200*3*480*640个张量数据。

When batch_size (default 1) is not None, the data loader yields batched samples instead of individual samples. batch_size and drop_last arguments are used to specify how the data loader obtains batches of dataset keys. For map-style datasets, users can alternatively specify batch_sampler, which yields a list of keys at a time.
批处理的个数由每个批处理大小及drop_last(最后不够一个批处理的样本处理过程)决定,每个批处理样本索引可以是任意的,这个可以通过shuttle来决定。

Single- and Multi-process Data Loading

默认情况下,DataLoader使用单进程数据加载。在Python进程中,全局解释器锁(GIL)防止跨线程真正地完全并行化Python代码。为了避免在加载数据时阻塞计算代码,PyTorch提供了一个简单的开关,只需将参数 num_workers 设置为正整数即可执行多进程数据加载。

Memory Pinning

对于数据加载,将pin_memory = True传递给DataLoader将自动将获取的数据张量放入固定内存中,从而能够更快地将数据传输到支持CUDA的GPU。

DataLoader、图片、张量关系

为更好的解释四者之间的关系,我这里直接附上代码,通过注释和说明方式来解释。

def train(config):
    # 将参数和缓冲区转移到GPU
    dehaze_net = net.dehaze_net().cuda()
    # Applies fn recursively to every submodule (as returned by .children()) as well as self.
    # Typical use includes initializing the parameters of a model (see also torch.nn.init).
    # torch.nn.Module.apply(fn): fn (Module -> None) – function to be applied to each submodule
    dehaze_net.apply(weights_init)

    # train_dataset and val_dataset目的是获取训练集和验证集数据的文件名,除了个数不一样外,两者init函数所获得的属性一致
    train_dataset = dataloader.dehazing_loader(config.orig_images_path,
                                               config.hazy_images_path)
    # mode覆盖
    val_dataset = dataloader.dehazing_loader(config.orig_images_path,
                                             config.hazy_images_path, mode="val")

    # 返回两个DataLoader实例对象集,个数为 (the number of dataset)/batch_size,会调用len()函数
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True,
                                               num_workers=config.num_workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.val_batch_size, shuffle=True,
                                             num_workers=config.num_workers, pin_memory=True)

    criterion = nn.MSELoss().cuda()
    # torch.nn.Module.parameters()- Returns an iterator over module parameters.
    #To construct an Optimizer you have to give it an iterable containing the parameters (all should be Variable s) to optimize. Then, you can specify optimizer-specific options such as the learning rate, weight decay, etc.
    optimizer = torch.optim.Adam(dehaze_net.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    # Sets the module in training mode.
    dehaze_net.train()

说明:1、dehazing_loader()函数是为了获取训练集和测试集数据路径的,该类继承了Data类; def init(self, iterable, start=0): # known special case of enumerate.init
“”" Initialize self. See help(type(self)) for accurate signature. “”"
pass
2、获取后的数据,需要借助DataLoader类来实现数据的批处理及张量的表示(前边我们已经说了,任何继承Data类的子类均将重载_getitem_()及_len(),而_getitem()调用就是在DataLoader类调用时被调用的)

批处理样本操作

我们在获得了批处理样本后(如train_loader),如何实现对于每个批处理样本进行操作呢,这里我们可通过enumerate()来实现。我们可以在pycharm中查看enumerate()函数定义:

 builtins.py
 
   def __init__(self, iterable, start=0): # known special case of enumerate.__init__
        """ Initialize self.  See help(type(self)) for accurate signature. """
        pass

self指代的就是数据对象,iterable代表数据的个数,从0开始;返回值有两个:一个是序号,一个是数据。
那我们的批处理样本数据可以通过以下代码实现操作

        for iteration, (img_orig, img_haze) in enumerate(train_loader):
            img_orig = img_orig.cuda()
            img_haze = img_haze.cuda()

说明:1、iteration也就是上边的序号,指代批处理的索引;
2、(img_orig, img_haze)表示数据,这里我们采用了list形式来保存数据元素。若批处理大小设置为8,则img_orig及img_haze均为8*3*480*640的张量数据

最后附上各Variables之间的关系图
在这里插入图片描述
从上边的关系图中也可以看到train_dataset及train_loader最终存储的是数据路径,即data_list。

相关文章:

  • mac 使用 PyQt5 和 py_designer 搭建窗体
  • 嵌入式SQL开发
  • 对于HTTP协议,什么是长连接和短连接?
  • ReentrantLock读写锁
  • leetcode 1680. Concatenation of Consecutive Binary Numbers(连接连续的二进制数)
  • Python数据分析之时间序列的处理
  • 【PHP】如何搭建服务器环境 原生篇 | Ubuntu 18.04 + PHP8.1 + MySQL5.7 + Nginx 1.4
  • 【c语言】数据在内存中的存储
  • 数据结构考试必须要掌握的重点知识
  • 进程管理4——进程优先级
  • 外网访问内网80端口【内网穿透】
  • Android性能优化技术,在大厂中为何这么看重?进大厂必学好
  • 基于自建数据集【海底生物检测】使用YOLOv5-v6.1/2版本构建目标检测模型超详细教程
  • 水平分表之基因法
  • Gorm笔记
  • 【跃迁之路】【699天】程序员高效学习方法论探索系列(实验阶段456-2019.1.19)...
  • 8年软件测试工程师感悟——写给还在迷茫中的朋友
  • FastReport在线报表设计器工作原理
  • Javascript弹出层-初探
  • JavaScript新鲜事·第5期
  • Java读取Properties文件的六种方法
  • KMP算法及优化
  • SpringBoot几种定时任务的实现方式
  • ucore操作系统实验笔记 - 重新理解中断
  • Vue UI框架库开发介绍
  • 测试开发系类之接口自动化测试
  • 分享一份非常强势的Android面试题
  • 浮现式设计
  • 记一次用 NodeJs 实现模拟登录的思路
  • 如何设计一个比特币钱包服务
  • 深入浏览器事件循环的本质
  • 使用API自动生成工具优化前端工作流
  • 世界编程语言排行榜2008年06月(ActionScript 挺进20强)
  • 手写双向链表LinkedList的几个常用功能
  • 腾讯视频格式如何转换成mp4 将下载的qlv文件转换成mp4的方法
  • 提醒我喝水chrome插件开发指南
  • 自制字幕遮挡器
  • 1.Ext JS 建立web开发工程
  • ​Kaggle X光肺炎检测比赛第二名方案解析 | CVPR 2020 Workshop
  • #ifdef 的技巧用法
  • $jQuery 重写Alert样式方法
  • (10)工业界推荐系统-小红书推荐场景及内部实践【排序模型的特征】
  • (4)事件处理——(7)简单事件(Simple events)
  • (solr系列:一)使用tomcat部署solr服务
  • (每日持续更新)信息系统项目管理(第四版)(高级项目管理)考试重点整理 第13章 项目资源管理(七)
  • (南京观海微电子)——COF介绍
  • (四)TensorRT | 基于 GPU 端的 Python 推理
  • (源码版)2024美国大学生数学建模E题财产保险的可持续模型详解思路+具体代码季节性时序预测SARIMA天气预测建模
  • (转)Sql Server 保留几位小数的两种做法
  • ***linux下安装xampp,XAMPP目录结构(阿里云安装xampp)
  • ./indexer: error while loading shared libraries: libmysqlclient.so.18: cannot open shared object fil
  • .bat批处理(十):从路径字符串中截取盘符、文件名、后缀名等信息
  • .equal()和==的区别 怎样判断字符串为空问题: Illegal invoke-super to void nio.file.AccessDeniedException
  • .FileZilla的使用和主动模式被动模式介绍
  • .Net CoreRabbitMQ消息存储可靠机制