当前位置: 首页 > news >正文

深度学习--数据处理dataloader介绍及代码分析

dataloader介绍

  • dataloader
    • 概述
    • collate_fn
      • 主要作用
      • 在代码中的使用
  • 代码详解
    • 代码解释
      • __init__函数
      • collate_fn
      • 详细说明
  • 完整代码

dataloader

概述

参考博客
DataLoader是深度学习中重要的数据处理工具之一,旨在有效加载、处理和管理大规模数据集,用于训练和测试机器学习和深度学习模型。
DataLoader是一个用于批量加载数据的工具,它可以将数据集分成多个小批量(mini-batch),并逐个加载,以适应模型训练的需要。
DataLoader主要用于两个关键任务:数据加载和批次处理

  • 数据加载:DataLoader可以从不同来源加载数据,如硬盘上的文件、数据库、网络等。它能够自动将数据集划分为小批次,从而减小内存需求,确保数据的高效加载。
  • 数据批次处理:每个批次由多个样本组成,可以并行地进行数据预处理和数据增强。这有助于提高模型训练的效率,同时确保每个批次的数据都经过适当的处理。

collate_fn

collate_fn 是一个自定义函数,用于在 PyTorch 的 DataLoader 中定义如何将单个样本组合成一个批次(batch)。具体来说,collate_fn 函数会在每次从 DataLoader 中取出一个批次的数据时被调用,用于对数据进行整理和转换。

主要作用

collate_fn:返回值为最终构建的batch数据;在这一步中处理dataset的数据,将其调整成期望的数据格式。
将一个批次的数据样本整理成适合模型输入的格式,特别是将数据转换为 PyTorch 张量(Tensor),以便于后续的模型训练和推理。

  • 自定义数据堆叠:将单个样本组合成一个批次,处理数据的不同形状或类型。
  • 数据转换:在批次数据组成之前进行必要的转换操作,例如数据类型转换、数据增强等。

在代码中的使用

在本代码中,unet_dataset_collate 函数就是一个 collate_fn 函数。它的作用是将一个批次的数据样本(图像、PNG 数据和分割标签)整理成适合模型输入的格式。具体步骤包括将数据从列表转换为 NumPy 数组,再转换为 PyTorch 张量。

代码详解

# DataLoader中collate_fn使用
def unet_dataset_collate(batch):images      = []pngs        = []seg_labels  = []for img, png, labels in batch:images.append(img)pngs.append(png)seg_labels.append(labels)images      = torch.from_numpy(np.array(images)).type(torch.FloatTensor)pngs        = torch.from_numpy(np.array(pngs)).long()seg_labels  = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)return images, pngs, seg_labels

这段代码定义了一个名为 unet_dataset_collate 的函数,用于在 PyTorch 的 DataLoader 中自定义批处理方式。函数将一个批次的数据样本(batch)转换为适合模型输入的格式。

代码解释

__init__函数

在 DataLoader 中,init 函数的主要作用是初始化数据集对象,并为后续的数据加载和处理做好准备。
UnetDataset 类的 init 函数在 DataLoader 中的作用包括:

  • 数据集初始化:通过传入的参数(如 annotation_lines、input_shape 等)初始化数据集对象,使其包含所有必要的信息。
  • 数据预处理:在初始化过程中,可以对数据进行预处理,如归一化、裁剪等,以便后续的模型训练。
  • 数据分割:将数据集分割成训练集和验证集(通过 train 参数),以便在训练过程中进行模型评估。
  • 路径管理:通过 dataset_path 参数指定数据集的存储路径,方便数据的加载和管理。
# UnetDataset 类的初始化方法,接受五个参数:annotation_lines、input_shape、num_classes、train 和 dataset_path。def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
# super() 函数用于调用父类的初始化方法。在这里,它调用了 UnetDataset 类的父类的 __init__ 方法,确保父类的初始化逻辑也被执行。这对于继承自其他类的类非常重要。super(UnetDataset, self).__init__()
# self 代表类的实例。self.annotation_lines 将传入的 annotation_lines 参数赋值给实例属性 annotation_linesself.annotation_lines   = annotation_linesself.length             = len(annotation_lines)self.input_shape        = input_shapeself.num_classes        = num_classesself.train              = trainself.dataset_path       = dataset_path

解释 super 和 self

  • super
    super() 函数用于调用父类的方法。在多重继承的情况下,它确保正确调用父类的方法,避免重复调用。这里,它调用了 UnetDataset 类的父类的 init 方法。
  • self
    self 是类的实例的引用。它用于访问类的属性和方法。在类的方法中,self 必须作为第一个参数传递,以便方法能够访问实例的属性和其他方法。

collate_fn

# DataLoader中collate_fn使用
# 函数定义:net_dataset_collate(batch):定义了一个函数,接收一个批次的数据样本batch。
def unet_dataset_collate(batch):
# 初始化列表:
# images = []:用于存储所有图像数据。
# pngs = []:用于存储所有 PNG 格式的数据。
# seg_labels = []:用于存储所有分割标签数据images      = []pngs        = []seg_labels  = []
# 遍历批次数据:
# 遍历批次中的每个样本,假设每个样本包含图像、PNG 数据和分割标签。
# images.append(img):将图像数据添加到 images 列表中。
# pngs.append(png):将 PNG 数据添加到 pngs 列表中。
# seg_labels.append(labels):将分割标签数据添加到 seg_labels 列表中。for img, png, labels in batch:images.append(img)pngs.append(png)seg_labels.append(labels)
#转换数据类型:
# 将 images 列表转换为 NumPy 数组,再转换为 PyTorch 的 FloatTensor 类型。
# 将 pngs 列表转换为 NumPy 数组,再转换为 PyTorch 的 LongTensor 类型。
# 将 seg_labels 列表转换为 NumPy 数组,再转换为 PyTorch 的 FloatTensor 类型。images      = torch.from_numpy(np.array(images)).type(torch.FloatTensor)pngs        = torch.from_numpy(np.array(pngs)).long()seg_labels  = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
# 返回结果:
# 返回处理后的图像数据、PNG 数据和分割标签数据。return images, pngs, seg_labels

详细说明

  1. 函数定义

    • unet_dataset_collate(batch):定义了一个函数,接收一个批次的数据样本 batch
  2. 初始化列表

    • images = []:用于存储所有图像数据。
    • pngs = []:用于存储所有 PNG 格式的数据。
    • seg_labels = []:用于存储所有分割标签数据。
  3. 遍历批次数据

    • for img, png, labels in batch::遍历批次中的每个样本,假设每个样本包含图像、PNG 数据和分割标签。
    • images.append(img):将图像数据添加到 images 列表中。
    • pngs.append(png):将 PNG 数据添加到 pngs 列表中。
    • seg_labels.append(labels):将分割标签数据添加到 seg_labels 列表中。
  4. 转换数据类型

    • images = torch.from_numpy(np.array(images)).type(torch.FloatTensor):将 images 列表转换为 NumPy 数组,再转换为 PyTorch 的 FloatTensor 类型。
    • pngs = torch.from_numpy(np.array(pngs)).long():将 pngs 列表转换为 NumPy 数组,再转换为 PyTorch 的 LongTensor 类型。
    • seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor):将 seg_labels 列表转换为 NumPy 数组,再转换为 PyTorch 的 FloatTensor 类型。
  5. 返回结果

    • return images, pngs, seg_labels:返回处理后的图像数据、PNG 数据和分割标签数据。

完整代码

import osimport cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Datasetfrom utils.utils import cvtColor, preprocess_inputclass UnetDataset(Dataset):def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):super(UnetDataset, self).__init__()self.annotation_lines   = annotation_linesself.length             = len(annotation_lines)self.input_shape        = input_shapeself.num_classes        = num_classesself.train              = trainself.dataset_path       = dataset_pathdef __len__(self):return self.lengthdef __getitem__(self, index):annotation_line = self.annotation_lines[index]name            = annotation_line.split()[0]#-------------------------------##   从文件中读取图像#-------------------------------#jpg         = Image.open(os.path.join(os.path.join(self.dataset_path, "JPEGImages"), name + ".jpg"))png         = Image.open(os.path.join(os.path.join(self.dataset_path, "SegmentationClass"), name + ".png"))#-------------------------------##   数据增强#-------------------------------#jpg, png    = self.get_random_data(jpg, png, self.input_shape, random = self.train)jpg         = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])png         = np.array(png)png[png >= self.num_classes] = self.num_classes#-------------------------------------------------------##   转化成one_hot的形式#   在这里需要+1是因为voc数据集有些标签具有白边部分#   我们需要将白边部分进行忽略,+1的目的是方便忽略。#-------------------------------------------------------#seg_labels  = np.eye(self.num_classes + 1)[png.reshape([-1])]seg_labels  = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))return jpg, png, seg_labelsdef rand(self, a=0, b=1):return np.random.rand() * (b - a) + adef get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True):image   = cvtColor(image)label   = Image.fromarray(np.array(label))#------------------------------##   获得图像的高宽与目标高宽#------------------------------#iw, ih  = image.sizeh, w    = input_shapeif not random:iw, ih  = image.sizescale   = min(w/iw, h/ih)nw      = int(iw*scale)nh      = int(ih*scale)image       = image.resize((nw,nh), Image.BICUBIC)new_image   = Image.new('RGB', [w, h], (128,128,128))new_image.paste(image, ((w-nw)//2, (h-nh)//2))label       = label.resize((nw,nh), Image.NEAREST)new_label   = Image.new('L', [w, h], (0))new_label.paste(label, ((w-nw)//2, (h-nh)//2))return new_image, new_label#------------------------------------------##   对图像进行缩放并且进行长和宽的扭曲#------------------------------------------#new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)scale = self.rand(0.25, 2)if new_ar < 1:nh = int(scale*h)nw = int(nh*new_ar)else:nw = int(scale*w)nh = int(nw/new_ar)image = image.resize((nw,nh), Image.BICUBIC)label = label.resize((nw,nh), Image.NEAREST)#------------------------------------------##   翻转图像#------------------------------------------#flip = self.rand()<.5if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)label = label.transpose(Image.FLIP_LEFT_RIGHT)#------------------------------------------##   将图像多余的部分加上灰条#------------------------------------------#dx = int(self.rand(0, w-nw))dy = int(self.rand(0, h-nh))new_image = Image.new('RGB', (w,h), (128,128,128))new_label = Image.new('L', (w,h), (0))new_image.paste(image, (dx, dy))new_label.paste(label, (dx, dy))image = new_imagelabel = new_labelimage_data      = np.array(image, np.uint8)#---------------------------------##   对图像进行色域变换#   计算色域变换的参数#---------------------------------#r               = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1#---------------------------------##   将图像转到HSV上#---------------------------------#hue, sat, val   = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))dtype           = image_data.dtype#---------------------------------##   应用变换#---------------------------------#x       = np.arange(0, 256, dtype=r.dtype)lut_hue = ((x * r[0]) % 180).astype(dtype)lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)lut_val = np.clip(x * r[2], 0, 255).astype(dtype)image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)return image_data, label# DataLoader中collate_fn使用
def unet_dataset_collate(batch):images      = []pngs        = []seg_labels  = []for img, png, labels in batch:images.append(img)pngs.append(png)seg_labels.append(labels)images      = torch.from_numpy(np.array(images)).type(torch.FloatTensor)pngs        = torch.from_numpy(np.array(pngs)).long()seg_labels  = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)return images, pngs, seg_labels

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 【C#】一个项目移动了位置,或者换到其他电脑上,编译报错 Files 的值“IGEF,解决方法
  • Elasticsearch 地理查询:高效探索空间数据
  • openstack使用笔记
  • antdv和element表格,假分页+表格高度处理mixins
  • springboot瑜伽课约课小程序-计算机毕业设计源码87936
  • 【数据结构与算法 | 力扣+二叉搜索树篇】力扣450, 98
  • C++中的::
  • 告别DockerHub 镜像下载难题:掌握高效下载策略,畅享无缝开发体验
  • 【Python深度学习】如何实现将将时间序列转换为图像的功能
  • 基于python的电商水果超市的设计与实现
  • 手机游戏录屏软件哪个好,3款软件搞定游戏录屏
  • Golang | Leetcode Golang题解之第327题区间和的个数
  • 数据库系统 第2节 数据库语言
  • 一篇文章教会你 LVS———NAT模式和DR模式部署配置
  • 【ES6】使用Set和Map进行全组合判断
  • JS 中的深拷贝与浅拷贝
  • iOS | NSProxy
  • iOS小技巧之UIImagePickerController实现头像选择
  • Java,console输出实时的转向GUI textbox
  • JavaScript 是如何工作的:WebRTC 和对等网络的机制!
  • java中的hashCode
  • Protobuf3语言指南
  • React16时代,该用什么姿势写 React ?
  • spring + angular 实现导出excel
  • TypeScript实现数据结构(一)栈,队列,链表
  • 基于axios的vue插件,让http请求更简单
  • 实现菜单下拉伸展折叠效果demo
  • 微信小程序上拉加载:onReachBottom详解+设置触发距离
  • 协程
  • 栈实现走出迷宫(C++)
  • 这几个编码小技巧将令你 PHP 代码更加简洁
  • CMake 入门1/5:基于阿里云 ECS搭建体验环境
  • Play Store发现SimBad恶意软件,1.5亿Android用户成受害者 ...
  • 树莓派用上kodexplorer也能玩成私有网盘
  • 通过调用文摘列表API获取文摘
  • ​LeetCode解法汇总2304. 网格中的最小路径代价
  • # Swust 12th acm 邀请赛# [ A ] A+B problem [题解]
  • #QT(QCharts绘制曲线)
  • (11)MATLAB PCA+SVM 人脸识别
  • (2024,LoRA,全量微调,低秩,强正则化,缓解遗忘,多样性)LoRA 学习更少,遗忘更少
  • (php伪随机数生成)[GWCTF 2019]枯燥的抽奖
  • (TOJ2804)Even? Odd?
  • (八)Spring源码解析:Spring MVC
  • (大众金融)SQL server面试题(1)-总销售量最少的3个型号的车及其总销售量
  • (二) 初入MySQL 【数据库管理】
  • (全注解开发)学习Spring-MVC的第三天
  • (三) diretfbrc详解
  • (数据结构)顺序表的定义
  • (详细文档!)javaswing图书管理系统+mysql数据库
  • (一)认识微服务
  • (一)项目实践-利用Appdesigner制作目标跟踪仿真软件
  • (转)菜鸟学数据库(三)——存储过程
  • (转载)虚函数剖析
  • *(长期更新)软考网络工程师学习笔记——Section 22 无线局域网
  • .DFS.