基于detectron2框架的深度学习模型载入自定义数据集
基于detectron2框架的深度学习模型载入自定义数据集
一、前言
最近在做微光目标检测的研究工作,使用了Rank_DETR;这个模型是基于detrex框架,而detrex框架又是基于detectron2的。找了一圈没找到载入数据集的地方,后面查阅了资料得知要用API进行注册。
二、步骤
-
注册数据集:
在脚本中,我们首先要注册数据集。Detectron2 提供了多种注册数据集的方式,常用的是register_coco_instances
,用于 COCO 格式的数据集。您可以在脚本的开头或配置文件中添加如下代码来注册您的数据集:from detectron2.data.datasets import register_coco_instancesregister_coco_instances("my_dataset_train", {}, "path/to/train_annotations.json", "path/to/train_images/") register_coco_instances("my_dataset_val", {}, "path/to/val_annotations.json", "path/to/val_images/")
"my_dataset_train"
和"my_dataset_val"
是数据集的名称,您可以按需更改。path/to/train_annotations.json
和path/to/val_annotations.json
分别是训练和验证数据集的 COCO 格式标注文件路径。path/to/train_images/
和path/to/val_images/
是训练和验证图像的路径。
-
在配置文件中引用数据集:
在您使用的配置文件中,需要确保数据加载器 (dataloader
) 中引用了您刚才注册的数据集。通常,您需要修改以下内容:cfg.dataloader.train.dataset.names = "my_dataset_train" cfg.dataloader.test.dataset.names = "my_dataset_val"
这确保了训练和验证时使用的是您自定义的数据集。
三、示例代码集成
如果您已经在脚本中集成了以上步骤,代码可能如下所示:
def main(args):cfg = LazyConfig.load(args.config_file)cfg = LazyConfig.apply_overrides(cfg, args.opts)default_setup(cfg, args)register_coco_instances("exdark_train", {},"/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/annotations/instances_train2017.json","/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/train2017")register_coco_instances("exdark_test", {},"/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/annotations/instances_val2017.json","/liushuai2/PCP/datasets/Exdark-MAE/OwnerToCOCO/val2017")cfg.dataloader.train.dataset.names = "exdark_train"cfg.dataloader.test.dataset.names = "exdark_test"if args.eval_only:model = instantiate(cfg.model)model.to(cfg.train.device)model = create_ddp_model(model)DetectionCheckpointer(model).load(cfg.train.init_checkpoint)print(do_test(cfg, model))else:do_train(args, cfg)if __name__ == "__main__":parser = default_argument_parser()parser.add_argument("--use_wandb", action="store_true", help="Whether to use wandb.")parser.add_argument("--wandb_key", type=str, help="Wandb API key.")args = parser.parse_args()if args.use_wandb:wandb.login(key=args.wandb_key)launch(main,args.num_gpus,num_machines=args.num_machines,machine_rank=args.machine_rank,dist_url=args.dist_url,args=(args,),)