【AI】PyTorch入门(三):数据集和数据加载器
1、简述
PyTorch中对数据集的描述类都是从torch.utils.data.Dataset继承,然后交给torch.utils.data.DataLoader来处理。
在模型训练时,只操作 torch.utils.data.DataLoader 即可,这样就将数据集的代码和训练的代码分离,方便维护。
2、PyTorch预加载数据集
PyTorch针对图像、音频、文本都提供了常用的内置数据集。这些数据集继承自torch.utils.data.Dataset,并且实现了如下方法,可以传递给torch.utils.data.DataLoader
__getitem__
__len__
数据加载示例:
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads)
以图像数据集为例:
导入图像数据集类:torchvision.datasets
数据集类别有图像分类、图像检测和分割、光流、图像配对、图片说明、视频分类
官网文档:https://pytorch.org/vision/stable/datasets.html
3、加载数据集
以FashionMNIST为例.
Fashion-MNIST:替代MNIST手写数字集的图像数据集,该数据集由衣服、鞋子等服饰组成,包含70000张图像,其中60000张训练图像加10000张测试图像,图像大小为28x28,单通道,共分10个类,如下图,每三行为一类。
加载代码:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data", # 存储训练/测试数据的路径
train=True, # 指定训练=True或测试数据=False
download=True, # 如果数据不可用,则从网络下载数据到root指定目录中
transform=ToTensor() # 转换特征值,类似的 target_transform 用于转换标签
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
4、数据可视化
FashionMNIST标签是0~9,分别对应10种服装,如下面的labels_map
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
# figsize:指定figure的宽和高,单位为英寸,下面有函数详解
figure = plt.figure(figsize=(8, 8))
# 显示3x3=9张图片
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
# 获取一个随机索引号
sample_idx = torch.randint(len(training_data), size=(1,)).item()
# 获取随机索引号对应的图像和标签
img, label = training_data[sample_idx]
# 添加子画布,原型add_subplot(nrows, ncols, index, **kwargs)
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
# 不显示坐标,下面有张显示坐标的截图
plt.axis("off")
# 处理图像
plt.imshow(img.squeeze(), cmap="gray")
# 显示画面
plt.show()
plt.figure函数原型:
figure(num=None, figsize=None, dpi=None, facecolor=None, edgecolor=None, frameon=True)
参数:
num:图像编号或名称,数字为编号 ,字符串为名称
figsize:指定figure的宽和高,单位为英寸;1英寸等于2.5cm
dpi参数指定绘图对象的分辨率,即每英寸多少个像素,缺省值为80
facecolor:背景颜色
edgecolor:边框颜色
frameon:是否显示边框
显示坐标的话,是这样紫的
5、自定义数据
自定义 Dataset 类必须实现三个函数:
__init__:在实例化 Dataset 对象时运行一次
__len__:数据集数据个数
__getitem__:获取指定数据集
下面的自定义数据,描述文件为CSV,以逗号分割的数据,第一列为图片文件名,第二列为图片的标注信息,例如:
20220102001.jpg, car
20220102002.jpg, people
...
import os
# pandas 数据处理,可以处理CSV格式数据
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
# csv格式:第一列为img_labels.iloc[idx, 0]为图片文件名称
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
## csv格式:第二列为img_labels.iloc[idx, 1]为图片对应的标识
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
6、加载自定义数据
from torch.utils.data import DataLoader
train_dataset = CustomImageDataset("train.csv", True)
test_dataset = CustomImageDataset("test.csv", True)
# shuffle:是否打乱顺序,训练时需要打乱,测试时不需要
train_dataloader = DataLoader(train_dataset , batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset , batch_size=64)