当前位置: 首页 > news >正文

PyTorch 创建数据集

图片数据和标签数据准备

1.本文所用图片数据在同级文件夹中 ,文件路径为'train/’

在这里插入图片描述

2.标签数据在同级文件,文件路径为'train.csv'

在这里插入图片描述

3。将标签数据提取

train_csv=pd.read_csv('train.csv')

创建继承类

第一步,首先创建数据类对象 此时可以想象为单个数据单元的创建 { 图像,标签}

在这里插入图片描述

继承的是Dataset类 (数据集类)

from torch.utils.data import Dataset
from PIL import Image          //从文件路径中提取图片所需要的函数class Imagedata(Dataset):        //继承Dataset类def __init__(self,df,dir,transform=None):     //往类里传输需要的数据必须在这定义,后面初始化函数才能使用传入的数据,//df表示传入的标签数据,dir表示图像数据文件地址,transform是图像增强的处理操作super().__init__()                      //声明后面操作需要用的数据self.df=df                           self.dir=dirself.transform=transformdef __len__(self):                     //模板函数,没什么卵用return len(self.df)def __getitem__(self, idex):           //将单个数据和标签整合到一块的初始化函数img_id=self.df.iloc[idex,0]        //图片的名称在df文件中,标签也在df的文件中,如下图,为的就是提出图像数据文件中的图片,否则从图片数据文件中一张一张提取出来很难,名称太长img=Image.open(self.dir+img_id)   //拿到了图片的整个完整地址  img=np.array(img)                //Image提取出来的为image类型,需要转换为numpy数组,才能存储到数据集中//上面两行也可以换为cv2.imread(dir),直接读取的数据就可以往里面存,避免了数据转换label=self.df.iloc[idex,1]       //从df中提取对应的标签,就是同一张图像的标签,由idex固定return img,label                 //返回整理好的单个数据单元(图像+标签)

在这里插入图片描述

第二步,创造好了单个数据单元对象,那么需要将多个数据单元整合起来构成一个完整的数据集

先将单个数据单元实现,因为上面的代码为类对象代码,并没有实现

train_dataset=ImageDataset(df=train_csv,dir='train/')  //df为标签文件,dir表示你图像存储的文件地址

得到了单个数据单元,那么开始将数据整合,先调用数据整合函数:

from torch.utils.data import DataLoader

通过数据流来整合

train_data=DataLoader(train_dataset,batch_size=32)    //train_dataset 为单个对象     batch_size为设置几个为一小组,为后面的分组训练做准备

那么最后得到的train_data就是带有图像和标签的数据集,可以验证一下:

for img,label in train_data:print(img,label)

在这里插入图片描述

图像增强技术(降噪,标准化)

上面没有加入图像增强代码,创建数据集时候,可以先将图像增强后再存入数据集,增强的主要目的就是提高训练准确率,标准化可以使图像在神经网络训练的更快,因为图像的数据明显变小,举个例子,由像素[233,221,222]可以直接变为[2.33,2.21,2.22]

如下使图像增强代码,用的使torchvision,每行代码都有注释

from torchvision import transformstransform_train = transforms.Compose([transforms.ToTensor(),        //将图像变为Tensor张量,并将图像像素由255-0变为1-0,压缩,并将图像的维度从 (H x W x C) 转换为 (C x H x W)transforms.Pad(32, padding_mode='symmetric')   //表示在图像的四周各填充 32 个像素。transforms.RandomHorizontalFlip(),    //以一定的概率对图像进行随机水平翻转。这有助于增加数据的多样性,提高模型的泛化能力。防止拟合transforms.RandomVerticalFlip(),      //以一定的概率对图像进行随机垂直翻转。同样是为了增加数据多样性transforms.RandomRotation(10),       //以一定的概率对图像进行随机旋转,旋转角度在 -1010 度之间。增加数据的多样性transforms.Normalize((0.485, 0.456, 0.406),     //指定每个通道的均值。通常是在 ImageNet 数据集上计算得到的均值。(0.229, 0.224, 0.225))])   //指定每个通道的标准差。也是在 ImageNet 数据集上计算得到的标准差。

那么在数据单元创建的时候加入,以下是完整代码:

from torch.utils.data import Datasetclass ImageDataset(Dataset):def __init__(self, df, dir, transform=None): super().__init__()self.df = dfself.dir = dirself.transform = transformdef __len__(self):return len(self.df)def __getitem__(self, idx):img_id = self.df.iloc[idx,0]img_path = self.dir + img_idimage = cv2.imread(img_path)            //这里用了cv2直接读取图片,避免了转换numpyimage = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)   //opencv里的数据增强label = self.df.iloc[idx,1]if self.transform is not None:image = self.transform(image)return image, label-----------------------图像增强技术------------------------
from torchvision import transforms
transform_train = transforms.Compose([transforms.ToTensor(),transforms.Pad(32, padding_mode='symmetric'),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomRotation(10),transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))])
transform_test = transforms.Compose([transforms.ToTensor(),transforms.Pad(32, padding_mode='symmetric'),transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))])from torch.utils.data import DataLoader
dataset_train = ImageDataset(df=train_df, img_dir='train/',transform=transform_train)
loader_train = DataLoader(dataset=dataset_train, batch_size=32, shuffle=True)

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 《论系统安全架构设计及其应用》写作框架,软考高级系统架构设计师
  • 面经学习(hbkj实习)
  • 如何在Mac中修改pip的镜像源
  • 【MySQL】批量插入数据造数-存储过程
  • 在Windows系统上部署PPTist并实现远程访问
  • IntelliJ IDEA下载安装
  • 01 Shell Script概述
  • HTTP 三、http在springboot中得应用
  • 好看的个人导航页面html源码
  • 使用Fign进行客户端远程调用和SpringFormEncoder的使用
  • Docker Container 常用命令
  • 新型PyPI攻击技术可能导致超2.2万软件包被劫持
  • 服务器/linux上登录huggingface网站
  • [UVM]5.config机制 report 消息管理
  • docker装大米cms(damicms)各种cms可用相同办法
  • [ 一起学React系列 -- 8 ] React中的文件上传
  • css的样式优先级
  • Java 网络编程(2):UDP 的使用
  • Promise初体验
  • Redis 中的布隆过滤器
  • vuex 笔记整理
  • 分布式熔断降级平台aegis
  • 分享几个不错的工具
  • 干货 | 以太坊Mist负责人教你建立无服务器应用
  • 构建工具 - 收藏集 - 掘金
  • 汉诺塔算法
  • 将 Measurements 和 Units 应用到物理学
  • 买一台 iPhone X,还是创建一家未来的独角兽?
  • 使用iElevator.js模拟segmentfault的文章标题导航
  • 吐槽Javascript系列二:数组中的splice和slice方法
  • 用jquery写贪吃蛇
  • 白色的风信子
  • ​经​纬​恒​润​二​面​​三​七​互​娱​一​面​​元​象​二​面​
  • #LLM入门|Prompt#1.7_文本拓展_Expanding
  • (1)(1.13) SiK无线电高级配置(五)
  • (Java数据结构)ArrayList
  • (Redis使用系列) Springboot 使用redis实现接口Api限流 十
  • (八)Flask之app.route装饰器函数的参数
  • (纯JS)图片裁剪
  • (附源码)计算机毕业设计ssm高校《大学语文》课程作业在线管理系统
  • (数据大屏)(Hadoop)基于SSM框架的学院校友管理系统的设计与实现+文档
  • (四)图像的%2线性拉伸
  • (转)我也是一只IT小小鸟
  • . NET自动找可写目录
  • .NET 4.0中使用内存映射文件实现进程通讯
  • .Net CF下精确的计时器
  • .NET CLR Hosting 简介
  • .NET Micro Framework初体验(二)
  • .Net MVC4 上传大文件,并保存表单
  • .NET 设计模式初探
  • .NET 中 GetHashCode 的哈希值有多大概率会相同(哈希碰撞)
  • .netcore 获取appsettings
  • .NET版Word处理控件Aspose.words功能演示:在ASP.NET MVC中创建MS Word编辑器
  • .NET程序集编辑器/调试器 dnSpy 使用介绍
  • .net开发引用程序集提示没有强名称的解决办法