深度学习-4-PyTorch中的数据加载器Dataset和DataLoader
参考Pytorch的torch.utils.data中Dataset以及DataLoader等详解
在我们进行深度学习的过程中,不免要用到数据集,那么数据集是如何加载到我们的模型中进行训练的呢?
1 模块torch.utils.data
torch.utils.data是PyTorch提供的一个模块,用于处理和加载数据。
该模块提供了一系列工具类和函数,用于创建、操作和批量加载数据集。
下面是torch.utils.data模块中一些常用的类和函数:
(1)Dataset: 定义了抽象的数据集类,用户可以通过继承该类来构建自己的数据集。Dataset 类提供了两个必须实现的方法【下划线getitem下划线】 用于访问单个样本,【下划线len下划线】用于返回数据集的大小。
(2)TensorDataset: 继承自Dataset类,用于将张量数据打包成数据集。它接受多个张量作为输入,并按照第一个输入张量的大小来确定数据集的大小。
(3)DataLoader: 数据加载器类,用于批量加载数据集。它接受一个数据集对象作为输入,并提供多种数据加载和预处理的功能,如设置批量大小、多线程数据加载和数据打乱等。
(4)Subset: 数据集的子集类,用于从数据集中选择指定的样本。
(5)random_split: 将一个数据集随机划分为多个子集,可以指定划分的比例或指定每个子集的大小。
(6)ConcatDataset