当前位置: 首页 > news >正文

PyTorch DataLoader 学习

1. DataLoader的核心概念

DataLoader是PyTorch中一个重要的类,用于将数据集(dataset)和数据加载器(sampler)结合起来,以实现批量数据加载和处理。它可以高效地处理数据加载、多线程加载、批处理和数据增强等任务。

核心参数

  • dataset: 数据集对象,必须是继承自torch.utils.data.Dataset的类。
  • batch_size: 每个批次的大小。
  • shuffle: 是否在每个epoch开始时打乱数据。
  • sampler: 定义数据加载顺序的对象,通常与shuffle互斥。
  • num_workers: 使用多少个子进程加载数据。
  • collate_fn: 如何将单个样本合并为一个批次的函数。
  • pin_memory: 是否将数据加载到CUDA固定内存中。

2. 基本使用方法

定义数据集类

首先定义一个数据集类,该类需要继承自torch.utils.data.Dataset并实现__len____getitem__方法。

import torch
from torch.utils.data import Dataset, DataLoaderclass CustomDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):sample = {'data': self.data[idx], 'label': self.labels[idx]}return sample# 创建一些示例数据
data = torch.randn(100, 3, 64, 64)  # 100个样本,每个样本为3x64x64的图像
labels = torch.randint(0, 2, (100,))  # 100个标签,0或1dataset = CustomDataset(data, labels)

创建DataLoader

使用自定义数据集类创建DataLoader对象。

batch_size = 4
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

迭代DataLoader

遍历DataLoader获取批量数据。

for batch in dataloader:data, labels = batch['data'], batch['label']print(data.shape, labels.shape)

3. 进阶技巧

自定义collate_fn

如果需要自定义如何将样本合并为批次,可以定义自己的collate_fn函数。

def custom_collate_fn(batch):data = [item['data'] for item in batch]labels = [item['label'] for item in batch]return {'data': torch.stack(data), 'label': torch.tensor(labels)}dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=custom_collate_fn)

使用Sampler

Sampler定义了数据加载的顺序。可以自定义一个Sampler来实现更复杂的数据加载策略。

from torch.utils.data import Samplerclass CustomSampler(Sampler):def __init__(self, data_source):self.data_source = data_sourcedef __iter__(self):return iter(range(len(self.data_source)))def __len__(self):return len(self.data_source)custom_sampler = CustomSampler(dataset)
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=custom_sampler, num_workers=2)

数据增强

在图像处理中,数据增强(Data Augmentation)是提高模型泛化能力的一种有效方法。可以使用torchvision.transforms进行数据增强。

import torchvision.transforms as transformstransform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])dataset = CustomDataset(data, labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

4. 实战示例:CIFAR-10数据集

以下是使用CIFAR-10数据集的完整示例代码,包括数据加载、数据增强和模型训练。

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10# 定义数据增强和标准化
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])# 加载训练和测试数据集
trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)testset = CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)# 定义简单的卷积神经网络
import torch.nn as nn
import torch.nn.functional as Fclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(64 * 8 * 8, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 8 * 8)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 创建模型、定义损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型
for epoch in range(10):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 100}')running_loss = 0.0print('Finished Training')# 测试模型
correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')

5. 数据加载加速技巧

使用多进程数据加载

通过设置num_workers参数,可以启用多进程数据加载,加速数据读取过程。

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

使用pin_memory

如果使用GPU进行训练,将pin_memory设置为True可以加速数据传输。

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

预取数据

使用prefetch_factor参数来预取数据,以减少数据加载等待时间。

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, prefetch_factor=2)

6. 处理不规则数据

在某些情况下,数据样本可能不规则,例如变长序列。可以使用自定义的collate_fn来处理这种数据。

def custom_collate_fn(batch):batch = sorted(batch, key=lambda x: len(x['data']), reverse=True)data = [item['data'] for item in batch]labels = [item['label'] for item in batch]data_padded = torch.nn.utils.rnn.pad_sequence(data, batch_first=True)labels = torch.tensor(labels)return {'data': data_padded, 'label': labels}dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=custom_collate_fn)

7. 使用中应注意的问题

数据加载效率

设置num_workers

  • 多线程数据加载: num_workers参数决定了用于数据加载的子进程数量。合理设置num_workers可以显著提升数据加载速度。一般来说,设置为CPU核心数的一半或等于核心数是一个不错的选择,但需要根据具体情况进行调整。
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

使用pin_memory

  • 固定内存: 当使用GPU进行训练时,将pin_memory设置为True可以加速数据从CPU传输到GPU的速度。固定内存使得数据可以直接从页面锁定内存复制到GPU内存。
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

预取数据

  • 预取因子: 使用prefetch_factor参数来预取数据,以减少数据加载等待时间。默认情况下,预取因子为2。
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, prefetch_factor=2)

数据集与DataLoader的兼容性

正确实现 __getitem____len__

  • 数据集类的实现: 确保自定义数据集类正确实现了__getitem____len__方法,确保DataLoader能够正确地索引和迭代数据。
class CustomDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):sample = {'data': self.data[idx], 'label': self.labels[idx]}return sample

数据增强与预处理

数据增强

  • 变换操作: 在图像处理中,数据增强可以提高模型的泛化能力。可以使用torchvision.transforms进行数据增强和标准化。
import torchvision.transforms as transformstransform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])dataset = CustomDataset(data, labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

数据加载过程中的内存问题

避免内存泄漏

  • 防止内存泄漏: 在使用DataLoader时,尤其是多进程加载时,注意内存泄漏问题。确保在训练过程中及时释放不再使用的数据。

合理设置batch_size

  • 批次大小: 根据GPU显存和内存大小合理设置batch_size。过大可能导致内存不足,过小可能导致计算效率低。
batch_size = 64  # 根据实际情况调整
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

数据顺序与随机性

shufflesampler

  • 数据随机性: 在训练集上使用shuffle=True,可以在每个epoch开始时打乱数据,防止模型过拟合。
  • 使用Sampler: 对于特殊的数据加载顺序需求,可以自定义Sampler。
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

数据不一致性

自定义collate_fn

  • 处理变长序列:在处理变长序列或不规则数据时,自定义collate_fn函数,确保每个批次的数据能够正确合并。
def custom_collate_fn(batch):data = [item['data'] for item in batch]labels = [item['label'] for item in batch]return {'data': torch.stack(data), 'label': torch.tensor(labels)}dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=custom_collate_fn)

数据加载调试

调试与错误处理

  • 调试: 在数据加载过程中,可以打印或检查部分数据样本,确保数据预处理和加载过程正确无误。
  • 错误处理: 使用try-except块捕捉并处理数据加载中的异常,防止程序崩溃。
for i, data in enumerate(dataloader, 0):try:inputs, labels = data['data'], data['label']# 数据处理和训练代码except Exception as e:print(f"Error loading data at batch {i}: {e}")

性能优化

数据加载性能

  • Profile数据加载: 使用profiling工具(如PyTorch的torch.utils.bottleneck)分析数据加载和训练过程中的性能瓶颈,进行相应优化。
import torch.utils.bottleneck# 在命令行运行以下命令进行性能分析
# python -m torch.utils.bottleneck <script.py>

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 2024前端面试真题【CSS篇】
  • 【数据结构】线性表----队列详解
  • 【2024_CUMCM】时间序列3-一元时间序列分析的模型
  • Spring容器加载Bean和JVM加载类
  • 【网络安全】Oracle:SSRF获取元数据
  • Python编程学习笔记(3)--- 操作列表
  • C++的入门基础(二)
  • vue 画二维码及长按保存
  • 基于TCP的在线词典系统(分阶段实现)(阻塞io和多路io复用(select)实现)
  • 【Linux】 GCC/G++与Makefile使用
  • Android Spinner
  • 数据结构和算法(0-1)----递归
  • ArduPilot开源代码之OpticalFlow_backend
  • arm64架构下源码编译安装kafka —— 筑梦之路
  • 【C++】———— 继承
  • JS 中的深拷贝与浅拷贝
  • -------------------- 第二讲-------- 第一节------在此给出链表的基本操作
  • 《用数据讲故事》作者Cole N. Knaflic:消除一切无效的图表
  • 【vuex入门系列02】mutation接收单个参数和多个参数
  • axios 和 cookie 的那些事
  • iOS帅气加载动画、通知视图、红包助手、引导页、导航栏、朋友圈、小游戏等效果源码...
  • mongo索引构建
  • MySQL数据库运维之数据恢复
  • PAT A1120
  • php面试题 汇集2
  • Spark学习笔记之相关记录
  • 区块链分支循环
  • 什么软件可以剪辑音乐?
  • 使用Tinker来调试Laravel应用程序的数据以及使用Tinker一些总结
  • 通过git安装npm私有模块
  • 在 Chrome DevTools 中调试 JavaScript 入门
  • 【干货分享】dos命令大全
  • 哈罗单车融资几十亿元,蚂蚁金服与春华资本加持 ...
  • ​​​​​​​开发面试“八股文”:助力还是阻力?
  • ​ubuntu下安装kvm虚拟机
  • !!Dom4j 学习笔记
  • # Java NIO(一)FileChannel
  • # windows 安装 mysql 显示 no packages found 解决方法
  • #快捷键# 大学四年我常用的软件快捷键大全,教你成为电脑高手!!
  • #数据结构 笔记一
  • (02)vite环境变量配置
  • (差分)胡桃爱原石
  • (多级缓存)多级缓存
  • (接口封装)
  • (十)c52学习之旅-定时器实验
  • (十三)Java springcloud B2B2C o2o多用户商城 springcloud架构 - SSO单点登录之OAuth2.0 根据token获取用户信息(4)...
  • (四)库存超卖案例实战——优化redis分布式锁
  • (一)WLAN定义和基本架构转
  • *(长期更新)软考网络工程师学习笔记——Section 22 无线局域网
  • .NET Micro Framework初体验(二)
  • .net 连接达梦数据库开发环境部署
  • .NET/C# 在 64 位进程中读取 32 位进程重定向后的注册表
  • .NET8使用VS2022打包Docker镜像
  • @DataRedisTest测试redis从未如此丝滑
  • [ C++ ] template 模板进阶 (特化,分离编译)