当前位置: 首页 > news >正文

利用 Python 的包管理和动态属性获取(`__init__.py` 文件和 `getattr` 函数)特性来实现工厂方法模式

Python 提供了许多灵活的特性,例如包的 __init__.py 文件和 getattr 函数,这些特性可以帮助我们实现工厂方法模式来动态地创建不同类型的数据集实例。

1. 背景介绍

在深度学习项目中,我们通常需要处理多种类型的数据集,例如 COCO、Pascal VOC 和自定义的交通数据集。为了简化和统一数据集的加载过程,我们可以利用 Python 的包管理和动态属性获取特性来实现工厂方法模式。

  • 包的 __init__.py 文件:通过在包的 __init__.py 文件中导入模块,我们可以在初始化包时自动加载所有必要的类和函数。
  • getattr 函数getattr 函数允许我们动态地获取对象的属性或方法,这对于实现工厂方法模式非常有用,因为我们可以根据配置或输入动态地创建对象,而无需在代码中硬编码每种数据集的构建逻辑。

接下来,我们将通过具体的代码示例来展示如何使用这些特性来实现数据集的动态加载。

2. 模块和类的定义

在我们的项目中,数据集类被定义在 datasets 模块中。我们将定义一个 COCODataset 类,并在 datasets 模块的 __init__.py 文件中导入它。需要注意的是,COCODataset 只是众多数据集类中的一种,其他数据集类如 PascalVOCDatasetTrafficDataset 等也可以通过类似的方式定义和使用。

定义 COCODataset

datasets 模块中创建一个名为 coco.py 的文件,并定义 COCODataset 类。这个类继承自 torchvision.datasets.coco.CocoDetection,并添加了一些自定义逻辑。

# datasets/coco.py
import torchvisionclass COCODataset(torchvision.datasets.coco.CocoDetection):def __init__(self, ann_file, root, remove_images_without_annotations, transforms=None):super(COCODataset, self).__init__(root, ann_file)# 自定义逻辑...
  • __init__ 方法COCODataset 类的构造函数接受 ann_file(注释文件路径)、root(图像根目录)、remove_images_without_annotations(是否移除没有注释的图像)和 transforms(图像变换)四个参数。这些参数与后面 DatasetCatalogget 方法返回的 args 对应。
  • 详细实现见附录
导入 COCODataset

datasets 模块的 __init__.py 文件中导入 COCODataset 类。这样可以确保在使用 datasets 模块时,所有数据集类都已加载。

# datasets/__init__.py
from .coco import COCODataset
from .voc import PascalVOCDataset
from .concat_dataset import ConcatDataset
from .traffic_dataset import TrafficDataset
from .carWinBiaoZhi_dataset import CarWinBiaoZhiDataset
from .carWinBiaoZhi_dataset_V2 import CarWinBiaoZhiDatasetV2
from .carWinBiaoZhi_dataset_V2_1 import CarWinBiaoZhiDatasetV2_1
from .GsData import CgTrafficData
from .GsData_xianQuan import CgTrafficDataWithXianQuan
from .GsData_1cls import CgTrafficData1Cls
from .GsData_ForSemi import CgTrafficDataSemi
from .GsRadarData import CgTrafficRadarData__all__ = ["COCODataset", "ConcatDataset", "PascalVOCDataset", "TrafficDataset","CarWinBiaoZhiDataset", "CarWinBiaoZhiDatasetV2", "CarWinBiaoZhiDatasetV2_1", "CgTrafficData", "CgTrafficDataWithXianQuan", "CgTrafficDataSemi", "CgTrafficRadarData", "CgTrafficData1Cls"
]

3. 使用 getattr 动态获取工厂方法

在构建数据集实例时,我们通过 getattr 函数动态获取工厂方法。以下是实现这一过程的核心代码:

# build_dataset.py
from . import datasets as Ddef build_dataset(dataset_list, transforms, dataset_catalog, is_train=True):if not isinstance(dataset_list, (list, tuple)):raise RuntimeError("dataset_list 应该是一个字符串列表,得到的是 {}".format(dataset_list))datasets = []  # 初始化数据集列表for dataset_name in dataset_list:# 从 dataset_catalog 中获取数据集信息data = dataset_catalog.get(dataset_name)# 获取数据集的工厂方法factory = getattr(D, data["factory"])# 获取数据集的参数args = data["args"]# 设置数据集的变换args["transforms"] = transforms# 使用工厂方法创建数据集实例dataset = factory(**args)# 将创建的数据集添加到列表中datasets.append(dataset)# 如果是测试模式,返回数据集列表if not is_train:return datasets# 如果是训练模式,将所有数据集合并为一个数据集dataset = datasets[0]if len(datasets) > 1:dataset = D.ConcatDataset(datasets)return [dataset]

4. 数据集目录管理 (DatasetCatalog)

为了集中管理数据集的路径和相关信息,我们定义了 DatasetCatalog 类。这个类包含了所有数据集的配置信息,并提供了一个静态方法 get 来获取特定数据集的配置信息。

# paths_catalog.py
import osclass DatasetCatalog(object):DATA_DIR = "/home/Public_DataSets"DATASETS = {"coco_2017_train": {"img_dir": "coco/train2017","ann_file": "coco/annotations/instances_train2017.json"},"voc_2007_train": {"data_dir": "voc/VOC2007","split": "train"},# ... 其他数据集配置 ...}@staticmethoddef get(name):if "coco" in name:data_dir = DatasetCatalog.DATA_DIRattrs = DatasetCatalog.DATASETS[name]args = dict(root=os.path.join(data_dir, attrs["img_dir"]),ann_file=os.path.join(data_dir, attrs["ann_file"]),)return dict(factory="COCODataset",args=args,)elif "voc" in name:data_dir = DatasetCatalog.DATA_DIRattrs = DatasetCatalog.DATASETS[name]args = dict(data_dir=os.path.join(data_dir, attrs["data_dir"]),split=attrs["split"],)return dict(factory="PascalVOCDataset",args=args,)# ... 其他数据集配置 ...raise RuntimeError("Dataset not available: {}".format(name))
说明

get 方法中,我们根据数据集名称动态生成配置字典。例如,对于 COCO 数据集:

return dict(factory="COCODataset",args=args,
)
  • factory:指定数据集类的名称,在后续步骤中用于动态获取工厂方法。
  • args:包含构建数据集实例所需的参数。

5. COCO 数据集的举例说明

假设我们有一个名为 "coco_2017_train" 的数据集,我们希望使用 DatasetCatalog 和工厂方法来加载这个数据集。以下是具体的步骤:

  1. 定义数据集配置

    # paths_catalog.py 中的 DATASETS 字典
    DATASETS = {"coco_2017_train": {"img_dir": "coco/train2017","ann_file": "coco/annotations/instances_train2017.json"},# ... 其他数据集配置 ...
    }
    
  2. 获取数据集配置

    data = DatasetCatalog.get("coco_2017_train")
    
  3. 动态获取工厂方法

    factory = getattr(D, data["factory"])
    
  4. 创建数据集实例

    args = data["args"]
    args["transforms"] = some_transform_function  # 假设我们有一个变换函数
    dataset = factory(**args)
    

通过这种方式,我们可以动态地加载 COCO 数据集,而无需硬编码每种数据集的构建逻辑。这种设计模式提高了代码的灵活性和可维护性,使得数据集的管理和加载更加方便。

附录: COCODataset 类完整实现
# datasets/coco.py
import torch
import torchvision
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
from maskrcnn_benchmark.structures.keypoint import PersonKeypointsmin_keypoints_per_image = 10def has_valid_annotation(anno):if len(anno) == 0:return Falseif all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno):return Falseif "keypoints" not in anno[0]:return Trueif sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) >= min_keypoints_per_image:return Truereturn Falseclass COCODataset(torchvision.datasets.coco.CocoDetection):def __init__(self, ann_file, root, remove_images_without_annotations, transforms=None):super(COCODataset, self).__init__(root, ann_file)self.ids = sorted(self.ids)if remove_images_without_annotations:self.ids = [img_id for img_id in self.ids if has_valid_annotation(self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)))]self.categories = {cat['id']: cat['name'] for cat in self.coco.cats.values()}self.json_category_id_to_contiguous_id = {v: i + 1 for i, v in enumerate(self.coco.getCatIds())}self.contiguous_category_id_to_json_id = {v: k for k, v in self.json_category_id_to_contiguous_id.items()}self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}self._transforms = transformsdef __getitem__(self, idx):img, anno = super(COCODataset, self).__getitem__(idx)anno = [obj for obj in anno if obj["iscrowd"] == 0]boxes = [obj["bbox"] for obj in anno]boxes = torch.as_tensor(boxes).reshape(-1, 4)target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")classes = torch.tensor([self.json_category_id_to_contiguous_id[obj["category_id"]] for obj in anno])target.add_field("labels", classes)if anno and "segmentation" in anno[0]:masks = SegmentationMask([obj["segmentation"] for obj in anno], img.size, mode='poly')target.add_field("masks", masks)if anno and "keypoints" in anno[0]:keypoints = PersonKeypoints([obj["keypoints"] for obj in anno], img.size)target.add_field("keypoints", keypoints)target = target.clip_to_image(remove_empty=True)if self._transforms is not None:img, target = self._transforms(img, target)return img, target, idxdef get_img_info(self, index):return self.coco.imgs[self.id_to_img_map[index]]
  • __init__ 方法:初始化数据集,加载注释,过滤无效注释,并设置类别和图像映射。
  • __getitem__ 方法:获取指定索引的图像和注释,应用可选的变换,并返回图像、目标和索引。

相关文章:

  • RHEL8 配置epel源
  • 深入探讨C语言中的高级指针操作
  • 生产环境中MapReduce的最佳实践
  • 微信小程序在不同移动设备上的差异导致原因
  • Startup-SBOM:一款针对RPM和APT数据库的逆向安全工具
  • Jenkins使用Publish Over SSH插件远程部署程序到阿里云服务器
  • vue3+ts+vite+pinia+element-plus搭建一个项目
  • 使用Docker-compose一键部署Wordpress平台
  • Bean对象生命周期流程图
  • Compose(2)声明式UI
  • 简简单单用用perf
  • Shell运算符
  • CDD数据库文件制作(五)——服务配置(0x19_DTC Code)
  • 基于深度学习的图像特征优化识别复杂环境中的果蔬【多种模型切换】
  • leetcode 41-50(2024.08.19)
  • axios 和 cookie 的那些事
  • Babel配置的不完全指南
  • Docker 笔记(1):介绍、镜像、容器及其基本操作
  • ES6系列(二)变量的解构赋值
  • idea + plantuml 画流程图
  • Koa2 之文件上传下载
  • Laravel 实践之路: 数据库迁移与数据填充
  • laravel5.5 视图共享数据
  • Leetcode 27 Remove Element
  • Mac 鼠须管 Rime 输入法 安装五笔输入法 教程
  • Netty源码解析1-Buffer
  • Vue小说阅读器(仿追书神器)
  • 多线程 start 和 run 方法到底有什么区别?
  • 分类模型——Logistics Regression
  • 关于 Cirru Editor 存储格式
  • 前端设计模式
  • 人脸识别最新开发经验demo
  • 思否第一天
  • elasticsearch-head插件安装
  • Linux权限管理(week1_day5)--技术流ken
  • ​LeetCode解法汇总1276. 不浪费原料的汉堡制作方案
  • ​ubuntu下安装kvm虚拟机
  • ​你们这样子,耽误我的工作进度怎么办?
  • ​软考-高级-信息系统项目管理师教程 第四版【第19章-配置与变更管理-思维导图】​
  • ###C语言程序设计-----C语言学习(3)#
  • $分析了六十多年间100万字的政府工作报告,我看到了这样的变迁
  • (Charles)如何抓取手机http的报文
  • (Forward) Music Player: From UI Proposal to Code
  • (TipsTricks)用客户端模板精简JavaScript代码
  • (创新)基于VMD-CNN-BiLSTM的电力负荷预测—代码+数据
  • (分享)一个图片添加水印的小demo的页面,可自定义样式
  • (三)centos7案例实战—vmware虚拟机硬盘挂载与卸载
  • (四)模仿学习-完成后台管理页面查询
  • (未解决)jmeter报错之“请在微信客户端打开链接”
  • (转)memcache、redis缓存
  • .bat批处理(四):路径相关%cd%和%~dp0的区别
  • .Net 6.0--通用帮助类--FileHelper
  • .net 8 发布了,试下微软最近强推的MAUI
  • .NET CF命令行调试器MDbg入门(四) Attaching to Processes
  • .Net Core和.Net Standard直观理解