当前位置: 首页 > news >正文

数据集的读取和处理

一、定义数据集中图像预处理操作: 

import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.RandomCrop(32, padding=4, padding_mode='reflect'),transforms.RandomHorizontalFlip(),transforms.ToTensor()])

调用torchvision库解决,其中transforms指明对数据集中图片的预处理操作,常见的可选的预处理操作如下:

1.transforms.CenterCrop(size)
(1)将给定的PIL.Image对象进行中心切割,得到给定的size;
(2)其中size可以是tuple类型,即size = (target_height, target_width);
(3)size也可以是int类型,此时切出来的图片的形状是正方形;
2.transforms.RandomCrop(size,padding=None,pad_if_needed=False,fill=0, padding_mode='constant'):
(1)从图片中随机裁剪出尺寸为size的图片
(2)size:所需裁剪图片尺寸
(3)padding:设置填充大小,有如下的三种格式:当为a时上下左右均填充a个像素;当为(a,b)时上下填充b个像素,左右填充a个像素;当为(a,b,c,d)时左、上、右、下分别填充a,b,c,d个像素
(4)pad_if_need:若图像小于设定size则按照填充模式padding_mode进行填充填充模式有4种:constant:像素值由fill设定edge:像素值由图像边缘像素决定reflect:镜像填充,最后一个像素不镜像symmetric:镜像填充,最后一个像素镜像
(5)fill:填充模式为constant时设置填充的像素值
3.transforms.RandomHorizontalFlip():
随机水平翻转给定的PIL.Image对象;
翻转的概率为0.5,即一半的概率翻转,一半的概率不翻转;
4.transforms.RandomResizedCrop(size,scale,ratio,interpolation)
(1)按照随机大小和长宽比裁剪图片
(2)首先根据scale的比例缩放原图,然后根据ratio的长宽比裁剪,最后使用插值法把图片变换为size大小
(3)size:所需裁剪图片尺寸
(4)scale:随机缩放面积比例, 默认(0.08,1)
(5)ratio:随机长宽比,默认(3/4,4/3)
(6)interpolation:插值方法PIL.Image.NEARESTPIL.Image.BILINEARPIL.Image.BICUBIC
5.transforms.FiveCrop(size)
(1)在原始图片的左上右上左下右下中间裁剪size大小的图片
(2)size:所需裁剪图片尺寸
6.transforms.TenCrop(size,vertical_flip=False)
(1)在图像的上下左右以及中心裁剪出尺寸为size的5张图片,并对这5张图片进行水平或者垂直镜像获得10张图片
(2)size:所需裁剪图片尺寸
(3)vertical_flip:是否垂直翻转
7.transforms.RandomHorizontalFlip(p=0.5)
(1)依概率水平(左右)翻转图片
(2)p:翻转概率
8.transforms.RandomVerticalFlip(p=0.5)
(1)依概率垂直(上下)翻转图片
(2)p:翻转概率
9.transforms.RandomRotation(degrees,resample,expand,center)
(1)随机旋转图片
(2)degrees:旋转角度当为a时在(-a,a)之间选择旋转角度当为(a,b)时在(a,b)之间选择旋转角度
(3)resample:重采样方法
(4)expand:是否扩大图片以保持原图信息
(5)center:旋转点设置,默认中心旋转
10.transforms.Pad(padding,fill=0,padding_mode='constant')
(1)对图像边缘进行填充
(2)padding:设置填充大小当为a时上下左右均填充a个像素;当为(a,b)时上下填充b个像素,左右填充a个像素;当为(a,b,c,d)时左、上、右、下分别填充a,b,c,d个像素
(3)padding_mode:填充模式填充模式有4种:constant:像素值由fill设定edge:像素值由图像边缘像素决定reflect:镜像填充,最后一个像素不镜像symmetric:镜像填充,最后一个像素镜像
(4)fill:padding_mode为constant时设置填充的像素值
11.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
(1)调整亮度,对比度,饱和度和色相
(2)brightness:亮度调整因子当为a时从[max(0,1-a),1+a]中随机选择;当为(a,b)时从[a,b]中随机选择;
(3)contrast:对比度参数,同brightness
(4)saturation:饱和度参数,同brightness
(5)hue:色相参数当为a时从[-a,a]中选择参数(0<=a<=0.5);当为(a,b)时从[a,b]中随机选择(-0.5<=a<=b<=0.5)
12.transforms.RandomGrayscale(num_output_channels,p)
(1)依据概率将图片转化为灰度图
(2)num_output_channels:输出通道数,只能设置为1或3;
(3)p:概率;
13.transforms.RandomAffine(degress,translate=None,scale=None,shear=None,resample=False, fillcolor)
(1)对图像进行仿射变换,仿射变换是二维的线性变换,有五种基本原子变换构成,分别是旋转、平移、缩放、错切和翻转
(2)degrees:旋转角度设置
(3)translate:平移区间设置,如(a,b),a设置宽,b设置高;
(4)scale:缩放比例(以面积为单位)
(5)fill_color:填充颜色设置
(6)shear:错切角度设置,有水平错切和垂直错切若为a则仅在x轴错切,错切角度在(-a,a)之间;若为(a,b),则设置a为x轴角度,b设置y轴角度;若为(a,b,c,d),则a和b设置x轴角度,c和d设置y轴角度
(7)resample:重采样方法
14.transforms.RandomErasing(p,scale,ratio,value,inplace)
(1)对图像进行随机遮挡
(2)p:概率
(3)scale:遮挡区域的面积
(4)ratio:遮挡区域的长宽比
(5)value:设置遮挡区域的像素值
注意:这个方法是对张量进行操作的,所以在这个方法前面需要transforms.ToTensor()
15.transforms.Resize(size)
(1)这里的size可以是一个整数,表示将图像的较短边缩放到指定长度,同时保持长宽比
(2)size也可以是一个列表,格式为[width, height],表示将图像的宽度和高度调整为指定的尺寸
(3)在调整大小时图像的长宽比可能会发生改变,因此图像可能会被拉伸或压缩来适应指定的大小
16.transforms.Normalize(mean,std)
(1)其中mean和std均为一个长度为3的一维列表,分别对应图像的三个通道;
(2)该方法简单来说就是将数据按通道进行计算;
(3)每一个通道内的每一个数据减去均值mean再除以方差std,得到归一化后的结果;
17.transforms.ToTensor()
(1)将PIL.Image对象或numpy.ndarray对象转为tensor对象
(2)如果PIL.Image对象属于(L,LA,P,I,F,RGB,YCbCr,RGBA,CMYK,1)中的一种图像类型,或者numpy.ndarray对象格式数据类型是np.uint8
则将[0, 255]的数据转为[0.0, 1.0],也就是说将所有数据除以255进行归一化
(3)将HWC的图像格式转为CHW的tensor格式,经过ToTensor()处理的图像可以直接输入到CNN网络中,不需要再进行reshape
18.transforms.RandomChoice([transforms1, transforms2, transforms3])
(1)从一系列transforms方法中随机挑选一个
19.transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5)
(1)依据概率执行一组transforms操作
20.transforms.RandomOrder([transforms1, transforms2, transforms3])
(1)对一组transforms操作打乱顺序

二、下载并加载对应的数据集:

# 下载并加载CIFAR-10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,shuffle=True, num_workers=0)
# 下载并加载CIFAR-10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,shuffle=False, num_workers=0)

主要的函数:

1.下载数据集:

1.CIFAR-10/CIFAR-100数据集:
torchvision.datasets.CIFAR10(root,train,transform,target_transform,download)
(1)root(String):数据集的根目录,其中目录cifar-10-batches-py存在或将保存到(如果下载设置为True);
(2)train(bool,可选):如果为True,则下载训练集,否则下载测试集;
(3)transform(可调用,可选):接受PIL图像并返回转换版本的函数/转换;是上一个步骤得到的transform对象
(4)target_transform(可调用,可选):接收目标并对其进行转换的函数/transform;
(5)download(bool,可选):如果为true则从Internet下载数据集并将其放在根目录中;如果数据集已下载则不会再次下载;
2.ImageNet数据集:
torchvision.datasets.ImageNet(root,split,**kwargs)
(1)root(String):ImageNet数据集的根目录;
(2)split(String,可选):数据集分割,可选值为'train'或'val';
(3)transform(可调用,可选):接受PIL图像并返回转换版本的函数/转换;是上一个步骤得到的transform对象;
(4)target_transform(可调用,可选):接收目标并对其进行转换的函数/transform;
(5)loader:加载给定路径的图像的函数;

2.加载数据集得到DataLoader对象: 

torch.utils.data.DataLoader(dataset,batch_size,shuffle,sampler,batch_sampler,num_workers,collate_fn,pin_memory,drop_last,timeout,worker_init_fn)
(1)dataset(Dataset):加载的数据集;
(2)batch_size(int,optional);每一次处理加载多少数据,即一个batch的大小
(3)shuffle(bool,optional):True表示每次个epoch遍历数据集都要重新打乱数据,默认为False;
(4)sampler(Sampler or Iterable,optional):定义采样的策略,如果定义了此参数那么shuffle参数必须为False;
(5)batch_sampler(Sampler or Iterable,optional):同sampler一样,但每次返回数据的索引;
(6)num_workers(int,optional):指定用于数据加载的子进程数,可以加快数据加载速度;默认为0表示用主进程加载;
(7)collate_fn(Callable,optional):批处理函数,用于将多个样本合并成一个批次,例如将多个张量拼接在一起构建mini-batch;
(8)pin_memory(bool,optional):True表示在返回张量之前将张量复制到CUDA固定的内存中,加快GPU传输速度;
(9)drop_last(bool, optional):True表示可删除最后一个不完整的批次,默认为False,如果数据集的大小不能被批次大小整除则最后一个批次小于batch_size;
(10)timeout(numeric,optional):非负数,表示worker收集批次数据的超时时间,默认为0
(11)worker_init_fn (Callable, optional):如果非None则在种子设定之后和数据加载之前将以worker id([0,num_workers-1]中的int)作为输入对每个worker子进程调用此函数
(12)multiprocessing_context(str or multiprocessing.context.BaseContext, optional):如果为None则将使用操作系统的默认多处理上下文;
(13)generator(torch.Generator,optional):如果非None则RandomSampler将使用此RNG来生成随机索引,并进行多进程处理以为workers生成base_seed;
(14)prefetch_factor(int, optional, keyword-only arg):每个worker预先装载的批次数,2表示在所有工作线程中总共预取2*num_workers批次;
(15)persistent_workers(bool, optional):True表示不会在数据集使用一次后关闭工作进程;这允许保持 worker实例处于活动状态(默认值:False)
(16)pin_memory_device(str, optional):如果pin_memory为True该参数表示pin_memory所指向的设备

说明:

a.sampler和batch_sampler都为None:batch_sampler使用Pytorch实现的批采样,而sampler分为两种情况:

shuffle=True:sampler使用随机采样;

shuffle=False:sampler使用顺序采样

b.自定义了batch_sampler,那么batch_size,shuffle,sampler,drop_last必须都是默认值;

c.自定义了sampler,此时batch_sampler不能再指定,且shuffle必须为False;

3.获取sampler对象:

1.torch.utils.data.SequentialSampler(data_source):
(1)顺序采样,用于获取数据索引,按照顺序返回数据集索引;
(2)data_source:可迭代对象,可以为数据集或列表;
2.torch.utils.data.RandomSampler(data_source,num_samples,replacement):
(1)用于获取打乱的数据索引,乱序返回数据集索引;
(2)data_source:可迭代对象,可以为数据集或列表;
(3)num_samples:指定采样的数量,默认是全部;
(4)replacement:若为True则表示可以重复采样,即同一个样本可以重复采样;
3.torch.utils.data.BatchSampler(sampler,batch_size,drop_last):
(1)批采样,将sampler采样得到的单个的索引值进行合并,当数量等于一个batch大小后就将这一批的索引值返回;
(2)sampler:上述两种采样器,即SequentialSampler或RandomSampler;
(3)batch_size:batch的大小;
(4)drop_last:可以为True或False,drop_last为True时如果采样得到的数据个数小于batch_size则抛弃本个batch的数据;
4.torch.utils.data.SubsetRandomSampler(indices):
(1)子集随机采样,与上面返回数据的索引不同,这里返回的是对应索引的数据本身
(2)indices:数据集索引
5.torch.utils.data.WeightedRandomSampler(weights,num_samples,replacement):
(1)加权随机采样,这里返回的是对应索引的数据本身;
(2)weights:采样到该索引的权重;
(3)num_samples:指定采样的数量;
(4)replacement:若为True则表示可以重复采样,即同一个样本可以重复采样;

相关文章:

  • 【微机原理及接口技术】可编程计数器/定时器8253
  • C++标准模板(STL)- C 内存管理库 - 分配并清零内存 (std::calloc)
  • 怎么从视频中提取音频?这里有三种提取妙招
  • 19 - grace数据处理 - 补充 - 地下水储量计算过程分解 - 冰后回弹(GIA)改正
  • 代码随想录算法训练营第22天(py)| 二叉树 | 669. 修剪二叉搜索树、108.将有序数组转换为二叉搜索树、538.把二叉搜索树转换为累加树
  • Golang项目代码组织架构实践
  • 第一节:Redis的数据类型和基本操作
  • IPFoxy Tips:海外代理IP适用的8个跨境出海业务
  • C#多线程同步lock、Mutex
  • 如何配置才能连接远程服务器上的 redis server ?
  • 继承基础实战
  • 网站工作原理
  • 四数之和-力扣
  • qmt量化交易策略小白学习笔记第4期【qmt如何获取获取行情数据--内置python使用方法】
  • 爬虫案例(读书网)
  • 【JavaScript】通过闭包创建具有私有属性的实例对象
  • AHK 中 = 和 == 等比较运算符的用法
  • Angular 响应式表单之下拉框
  • GDB 调试 Mysql 实战(三)优先队列排序算法中的行记录长度统计是怎么来的(上)...
  • Github访问慢解决办法
  • Javascript设计模式学习之Observer(观察者)模式
  • JDK 6和JDK 7中的substring()方法
  • jquery ajax学习笔记
  • js学习笔记
  • js中forEach回调同异步问题
  • 服务器从安装到部署全过程(二)
  • 关于extract.autodesk.io的一些说明
  • 和 || 运算
  • 坑!为什么View.startAnimation不起作用?
  • 浏览器缓存机制分析
  • 那些年我们用过的显示性能指标
  • 前端面试之CSS3新特性
  • 如何编写一个可升级的智能合约
  • 如何正确配置 Ubuntu 14.04 服务器?
  • 腾讯大梁:DevOps最后一棒,有效构建海量运营的持续反馈能力
  • 为视图添加丝滑的水波纹
  • 用mpvue开发微信小程序
  • MiKTeX could not find the script engine ‘perl.exe‘ which is required to execute ‘latexmk‘.
  • 如何在招聘中考核.NET架构师
  • #QT(TCP网络编程-服务端)
  • #WEB前端(HTML属性)
  • $LayoutParams cannot be cast to android.widget.RelativeLayout$LayoutParams
  • (3)Dubbo启动时qos-server can not bind localhost22222错误解决
  • (32位汇编 五)mov/add/sub/and/or/xor/not
  • (4)logging(日志模块)
  • (4)通过调用hadoop的java api实现本地文件上传到hadoop文件系统上
  • (cljs/run-at (JSVM. :browser) 搭建刚好可用的开发环境!)
  • (附源码)springboot优课在线教学系统 毕业设计 081251
  • (附源码)计算机毕业设计ssm本地美食推荐平台
  • (经验分享)作为一名普通本科计算机专业学生,我大学四年到底走了多少弯路
  • (六)软件测试分工
  • (生成器)yield与(迭代器)generator
  • (十)DDRC架构组成、效率Efficiency及功能实现
  • (五)大数据实战——使用模板虚拟机实现hadoop集群虚拟机克隆及网络相关配置
  • (一)SvelteKit教程:hello world