当前位置: 首页 > news >正文

【AI】PyTorch入门(三):数据集和数据加载器

1、简述

PyTorch中对数据集的描述类都是从torch.utils.data.Dataset继承,然后交给torch.utils.data.DataLoader来处理。
在模型训练时,只操作 torch.utils.data.DataLoader 即可,这样就将数据集的代码和训练的代码分离,方便维护。

2、PyTorch预加载数据集

PyTorch针对图像、音频、文本都提供了常用的内置数据集。这些数据集继承自torch.utils.data.Dataset,并且实现了如下方法,可以传递给torch.utils.data.DataLoader

__getitem__
__len__

数据加载示例:

imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=args.nThreads)

以图像数据集为例:
导入图像数据集类:torchvision.datasets
数据集类别有图像分类、图像检测和分割、光流、图像配对、图片说明、视频分类

官网文档:https://pytorch.org/vision/stable/datasets.html

3、加载数据集

以FashionMNIST为例.
Fashion-MNIST:替代MNIST手写数字集的图像数据集,该数据集由衣服、鞋子等服饰组成,包含70000张图像,其中60000张训练图像加10000张测试图像,图像大小为28x28,单通道,共分10个类,如下图,每三行为一类。
在这里插入图片描述
加载代码:

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="data",			# 存储训练/测试数据的路径
    train=True,				# 指定训练=True或测试数据=False
    download=True, 			# 如果数据不可用,则从网络下载数据到root指定目录中
    transform=ToTensor() 	# 转换特征值,类似的 target_transform 用于转换标签
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

4、数据可视化

FashionMNIST标签是0~9,分别对应10种服装,如下面的labels_map

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
# figsize:指定figure的宽和高,单位为英寸,下面有函数详解
figure = plt.figure(figsize=(8, 8))

# 显示3x3=9张图片
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
	# 获取一个随机索引号
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    # 获取随机索引号对应的图像和标签
    img, label = training_data[sample_idx]
    # 添加子画布,原型add_subplot(nrows, ncols, index, **kwargs)
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    # 不显示坐标,下面有张显示坐标的截图
    plt.axis("off")
    # 处理图像
    plt.imshow(img.squeeze(), cmap="gray")
# 显示画面
plt.show()

在这里插入图片描述
plt.figure函数原型:

figure(num=None, figsize=None, dpi=None, facecolor=None, edgecolor=None, frameon=True)

参数:

num:图像编号或名称,数字为编号 ,字符串为名称
figsize:指定figure的宽和高,单位为英寸;1英寸等于2.5cm 
dpi参数指定绘图对象的分辨率,即每英寸多少个像素,缺省值为80      
facecolor:背景颜色
edgecolor:边框颜色
frameon:是否显示边框

显示坐标的话,是这样紫的
在这里插入图片描述

5、自定义数据

自定义 Dataset 类必须实现三个函数:

__init__:在实例化 Dataset 对象时运行一次
__len__:数据集数据个数
__getitem__:获取指定数据集

下面的自定义数据,描述文件为CSV,以逗号分割的数据,第一列为图片文件名,第二列为图片的标注信息,例如:

20220102001.jpg, car
20220102002.jpg, people
...
import os
# pandas  数据处理,可以处理CSV格式数据
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
    	# csv格式:第一列为img_labels.iloc[idx, 0]为图片文件名称
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        ## csv格式:第二列为img_labels.iloc[idx, 1]为图片对应的标识
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

6、加载自定义数据

from torch.utils.data import DataLoader
train_dataset = CustomImageDataset("train.csv", True)
test_dataset = CustomImageDataset("test.csv", True)
# shuffle:是否打乱顺序,训练时需要打乱,测试时不需要
train_dataloader = DataLoader(train_dataset , batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset , batch_size=64)

相关文章:

  • Windows系统下MySQL8.0版详细安装及配置教程
  • Qt5开发从入门到精通——第五篇四节( 文本编辑器 Easy Word 开发 V1.3详解 )
  • c#-WPF使用类和子类绑定到DataContext
  • 图像隐写,如何在图像中隐藏二维码
  • SPL工业智能:发现时序数据的异常
  • 【Linux】进程概念(万字详解)—— 冯诺依曼体系结构 | 操作系统 | 进程
  • 网络编程套接字-----实现网络间通信
  • 机器学习:详细推导支持向量机SVM原理+Python实现
  • mysql socket文件丢失处理或者mysql.sock被删除
  • 欧拉计划详解第506题:钟摆序列
  • 《Python3 网络爬虫开发实战》:二、HTML消息结构
  • 调试接口小技巧-通过接口调试工具去下载上传文件
  • 【C指针详解】进阶篇
  • 惊奇发现业务移动端在往小程序化发展
  • 啸叫检测的方法:基于DSP的实现
  • SegmentFault for Android 3.0 发布
  • [case10]使用RSQL实现端到端的动态查询
  • Asm.js的简单介绍
  • iOS编译提示和导航提示
  • JavaScript服务器推送技术之 WebSocket
  • JavaScript工作原理(五):深入了解WebSockets,HTTP/2和SSE,以及如何选择
  • JS 面试题总结
  • leetcode46 Permutation 排列组合
  • Meteor的表单提交:Form
  • Next.js之基础概念(二)
  • SegmentFault 2015 Top Rank
  • Spring Cloud(3) - 服务治理: Spring Cloud Eureka
  • vue 配置sass、scss全局变量
  • 浅析微信支付:申请退款、退款回调接口、查询退款
  • 实战:基于Spring Boot快速开发RESTful风格API接口
  • 使用Swoole加速Laravel(正式环境中)
  • 你对linux中grep命令知道多少?
  • ​ 轻量应用服务器:亚马逊云科技打造全球领先的云计算解决方案
  • #在线报价接单​再坚持一下 明天是真的周六.出现货 实单来谈
  • (¥1011)-(一千零一拾一元整)输出
  • (01)ORB-SLAM2源码无死角解析-(56) 闭环线程→计算Sim3:理论推导(1)求解s,t
  • (cljs/run-at (JSVM. :browser) 搭建刚好可用的开发环境!)
  • (二)linux使用docker容器运行mysql
  • (分类)KNN算法- 参数调优
  • (附源码)springboot 个人网页的网站 毕业设计031623
  • (五)网络优化与超参数选择--九五小庞
  • (已解决)vue+element-ui实现个人中心,仿照原神
  • (转)总结使用Unity 3D优化游戏运行性能的经验
  • ***微信公众号支付+微信H5支付+微信扫码支付+小程序支付+APP微信支付解决方案总结...
  • *2 echo、printf、mkdir命令的应用
  • .NET Core 控制台程序读 appsettings.json 、注依赖、配日志、设 IOptions
  • .NET Framework 服务实现监控可观测性最佳实践
  • .NET 设计一套高性能的弱事件机制
  • .Net程序帮助文档制作
  • .NET框架
  • /bin/bash^M: bad interpreter: No such file ordirectory
  • [1127]图形打印 sdutOJ
  • [2008][note]腔内级联拉曼发射的,二极管泵浦多频调Q laser——
  • [Android]创建TabBar
  • [AR Foundation] 人脸检测的流程