当前位置: 首页 > news >正文

MMDetection系列 | 5. MMDetection运行配置介绍


如有错误,恳请指出。


开门见山,基于mmdet的官方文档直接介绍如何进行我们的运行配置。个人觉得,继承于default_runtime.py这个文件之后,主要需要自己稍微更改下的配置主要有7个,分别是:优化器配置、学习率配置、工作流程配置、检查点配置、日志配置、评估配置、训练设置。具体的配置流程如下所示。

如果需要其他钩子函数的实现与配置,具体可以查看参考资料1.

文章目录

  • 1. 优化器配置
  • 2. 学习率配置
  • 3. 工作流程配置
  • 4. 检查点配置
  • 5. 日志配置
  • 6. 评估配置
  • 7. 训练设置

1. 优化器配置

optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
  • 使用梯度剪辑来稳定训练
optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))

其中,_delete_=True将用新键替换backbone字段中的所有旧键


2. 学习率配置

lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[8, 11])       # 表示初始学习率在第8和11个epoch衰减10倍

还有其他的配置方案:

  • Poly schedule
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
  • ConsineAnnealing schedule
lr_config = dict(
    policy='CosineAnnealing',
    warmup='linear',
    warmup_iters=1000,
    warmup_ratio=1.0 / 10,
    min_lr_ratio=1e-5)
  • 使用动量调度加速模型收敛

支持动量调度器根据学习率修改模型的动量,这可以使模型以更快的方式收敛。Momentum 调度器通常与 LR 调度器一起使用

lr_config = dict(
    policy='cyclic',
    target_ratio=(10, 1e-4),
    cyclic_times=1,
    step_ratio_up=0.4,
)
momentum_config = dict(
    policy='cyclic',
    target_ratio=(0.85 / 0.95, 1),
    cyclic_times=1,
    step_ratio_up=0.4,
)

3. 工作流程配置

工作流是 (phase, epochs) 的列表,用于指定运行顺序和时期。默认情况下,它设置为:

workflow = [('train', 1)]

这意味着运行 1 个 epoch 进行训练。有时用户可能想要检查验证集上模型的一些指标(例如损失、准确性)。在这种情况下,我们可以将工作流设置为

[('train', 1), ('val', 1)]

这样 1 个 epoch 的训练和 1 个 epoch 的验证将被迭代运行。而验证集的损失同样会被计算出来。如果想先进行验证,再进行训练,还可以设置如下:

[('val', 1), ('train', n)]

这样设置表示先对验证集进行验证与损失计算,再进行n个epoch的计算。


4. 检查点配置

checkpoint_config = dict(interval=20)          # 20个epoch保存一次权重

参数说明见https://mmcv.readthedocs.io/en/latest/api.html#mmcv.runner.CheckpointHook

CLASSmmcv.runner.CheckpointHook(interval: int = - 1, by_epoch: bool = True, save_optimizer: bool = True, out_dir: Optional[str] = None, max_keep_ckpts: int = - 1, save_last: bool = True, sync_buffer: bool = False, file_client_args: Optional[dict] = None, **kwargs)

  • interval (int) – The saving period. If by_epoch=True, interval indicates epochs, otherwise it indicates iterations. Default: -1, which means “never”.
  • by_epoch (bool) – Saving checkpoints by epoch or by iteration. Default: True.
  • save_optimizer (bool) – Whether to save optimizer state_dict in the checkpoint. It is usually used for resuming experiments. Default: True.
  • out_dir (str, optional) – The root directory to save checkpoints. If not specified, runner.work_dir will be used by default. If specified, the out_dir will be the concatenation of out_dir and the last level directory of runner.work_dir. Changed in version 1.3.16.
  • max_keep_ckpts (int, optional) – The maximum checkpoints to keep. In some cases we want only the latest few checkpoints and would like to delete old ones to save the disk space. Default: -1, which means unlimited.
  • save_last (bool, optional) – Whether to force the last checkpoint to be saved regardless of interval. Default: True.
  • sync_buffer (bool, optional) – Whether to synchronize buffers in different gpus. Default: False.
  • file_client_args (dict, optional) – Arguments to instantiate a FileClient. See mmcv.fileio.FileClient for details. Default: None. New in version 1.3.16.

5. 日志配置

包装多个记录器log_config挂钩并允许设置间隔。现在 MMCV 支持WandbLoggerHookMlflowLoggerHookTensorboardLoggerHook.

log_config = dict(
    interval=50,    # 每500个迭代就打印一次训练信息
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])

参数说明见https://mmcv.readthedocs.io/en/latest/api.html#mmcv.runner.EvalHook

CLASSmmcv.runner.LoggerHook(interval: int = 10, ignore_last: bool = True, reset_flag: bool = False, by_epoch: bool = True)[SOURCE]

  • interval (int) – Logging interval (every k iterations). Default 10.
  • ignore_last (bool) – Ignore the log of last iterations in each epoch if less than interval. Default True.
  • reset_flag (bool) – Whether to clear the output buffer after logging. Default False.
  • by_epoch (bool) – Whether EpochBasedRunner is used. Default True.

6. 评估配置

配置的evaluation将用于初始化EvalHook. 除了 key interval,其他参数如metric将传递给dataset.evaluate()
evaluation = dict(interval=1, metric=‘bbox’)

参数说明https://mmcv.readthedocs.io/en/latest/api.html?highlight=EpochBasedRunner#mmcv.runner.EpochBasedRunner

mmcv.runner.EvalHook(dataloader: torch.utils.data.dataloader.DataLoader, start: Optional[int] = None, interval: int = 1, by_epoch: bool = True, save_best: Optional[str] = None, rule: Optional[str] = None, test_fn: Optional[Callable] = None, greater_keys: Optional[List[str]] = None, less_keys: Optional[List[str]] = None, out_dir: Optional[str] = None, file_client_args: Optional[dict] = None, **eval_kwargs)

  • dataloader (DataLoader) – A PyTorch dataloader, whose dataset has implemented evaluate function.
  • start (int | None, optional) – Evaluation starting epoch. It enables evaluation before the training starts if start <= the resuming epoch. If None, whether to evaluate is merely decided by interval. Default: None.
  • interval (int) – Evaluation interval. Default: 1.
  • by_epoch (bool) – Determine perform evaluation by epoch or by iteration. If set to True, it will perform by epoch. Otherwise, by iteration. Default: True.
  • save_best (str, optional) – If a metric is specified, it would measure the best checkpoint during evaluation. The information about best checkpoint would be saved in runner.meta[‘hook_msgs’] to keep best score value and best checkpoint path, which will be also loaded when resume checkpoint. Options are the evaluation metrics on the test dataset. e.g., bbox_mAP, segm_mAP for bbox detection and instance segmentation. AR@100 for proposal recall. If save_best is auto, the first key of the returned OrderedDict result will be used. Default: None.
  • rule (str | None, optional) – Comparison rule for best score. If set to None, it will infer a reasonable rule. Keys such as ‘acc’, ‘top’ .etc will be inferred by ‘greater’ rule. Keys contain ‘loss’ will be inferred by ‘less’ rule. Options are ‘greater’, ‘less’, None. Default: None.
  • test_fn (callable, optional) – test a model with samples from a dataloader, and return the test results. If None, the default test function mmcv.engine.single_gpu_test will be used. (default: None)
  • greater_keys (List[str] | None, optional) – Metric keys that will be inferred by ‘greater’ comparison rule. If None, _default_greater_keys will be used. (default: None)
  • less_keys (List[str] | None, optional) – Metric keys that will be inferred by ‘less’ comparison rule. If None, _default_less_keys will be used. (default: None)
  • out_dir (str, optional) – The root directory to save checkpoints. If not specified, runner.work_dir will be used by default. If specified, the out_dir will be the concatenation of out_dir and the last level directory of runner.work_dir. New in version 1.3.16.
  • file_client_args (dict) – Arguments to instantiate a FileClient. See mmcv.fileio.FileClient for details. Default: None. New in version 1.3.16.
  • **eval_kwargs – Evaluation arguments fed into the evaluate function of the dataset.

7. 训练设置

runner = dict(type='EpochBasedRunner', max_epochs=150)   # 设置模型训练多少次

参数说明https://mmcv.readthedocs.io/en/latest/api.html#mmcv.runner.EpochBasedRunner

mmcv.runner.EpochBasedRunner(model: torch.nn.modules.module.Module, batch_processor: Optional[Callable] = None, optimizer: Optional[Union[Dict, torch.optim.optimizer.Optimizer]] = None, work_dir: Optional[str] = None, logger: Optional[logging.Logger] = None, meta: Optional[Dict] = None, max_iters: Optional[int] = None, max_epochs: Optional[int] = None)


总结:

一般来说,我们写配置文件都会继承default_runtime.py这个文件

_base_ = [
    '../_base_/default_runtime.py'
]

这个文件的内容如下所示:

checkpoint_config = dict(interval=5)    # 每5个epoch保存一次权重
# yapf:disable
log_config = dict(
    interval=50,    # 每500个迭代就打印一次训练信息
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None        # 加载权重文件
resume_from = None
workflow = [('train', 1)]

# disable opencv multithreading to avoid system being overloaded
opencv_num_threads = 0
# set multi-process start method as `fork` to speed up the training
mp_start_method = 'fork'

# Default setting for scaling LR automatically
#   - `enable` means enable scaling LR automatically
#       or not by default.
#   - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)

一般不需要更改太多的内容,可以时代的更改log_config进行合理的打印训练信息,还有设置checkpoint_config进行合理的保存权重文件,其他的设置按默认即可。

下面展示我继承了default_runtime.py后更改的内容,其实就是更改了以上我所介绍的七点内容:

_base_ = [
    '../_base_/default_runtime.py'
]

......
# optimizer
optimizer = dict(   # 设置使用AdamW优化器(默认使用的是SGD)
    type='AdamW',
    lr=0.0001,
    weight_decay=0.0001,
    paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))

evaluation = dict(interval=5, metric='bbox')   # 5个epoch验证一次
optimizer_config = dict(grad_clip=dict(max_norm=0.1, norm_type=2))  # 设置梯度裁剪(default_runtime.py中默认为None)
checkpoint_config = dict(interval=20)          # 20个epoch保存一次权重
log_config = dict(interval=50,     # 每50次迭代训练就打印一次信息(注意是迭代而不是epoch)
                  hooks=[dict(type='TextLoggerHook')])

# learning policy
lr_config = dict(policy='step', step=[100])              # 学习率在100个epoch进行衰减
runner = dict(type='EpochBasedRunner', max_epochs=150)   # 训练150个epoch

参考资料:

1. Customize Runtime Settings

2. mmcv官方文档

相关文章:

  • java实现顺序表
  • 【英语:基础进阶_核心词汇扩充】E5.常见词根拓词
  • 命令执行漏洞——系统命令执行
  • 【数据结构与算法】List接口栈队列
  • 将cookie字符串转成editthiscookie插件的json格式
  • SpringAOP总结
  • python--数据容器--列表
  • Roson的Qt之旅 #119 QNetworkAddressEntry详细介绍
  • Mybatis -- 使用
  • C语言双链表,循环链表,静态链表讲解(王道版)
  • 比较zab、paxos和raft的算法的异同
  • Python Argparse 库讲解特别好的
  • C++~从编译链接的过程看为什么C++支持重载?externC有什么用?
  • App移动端测试【10】Monkey自定义脚本案例
  • springboot 整合dubbo3开发rest应用
  • IE9 : DOM Exception: INVALID_CHARACTER_ERR (5)
  • [ 一起学React系列 -- 8 ] React中的文件上传
  • Effective Java 笔记(一)
  • in typeof instanceof ===这些运算符有什么作用
  • IndexedDB
  • MySQL常见的两种存储引擎:MyISAM与InnoDB的爱恨情仇
  • mysql外键的使用
  • PhantomJS 安装
  • ReactNative开发常用的三方模块
  • redis学习笔记(三):列表、集合、有序集合
  • 不用申请服务号就可以开发微信支付/支付宝/QQ钱包支付!附:直接可用的代码+demo...
  • 从tcpdump抓包看TCP/IP协议
  • 从零开始在ubuntu上搭建node开发环境
  • 动态魔术使用DBMS_SQL
  • 前端面试之闭包
  • 我与Jetbrains的这些年
  • 移动互联网+智能运营体系搭建=你家有金矿啊!
  • 智能合约开发环境搭建及Hello World合约
  • 关于Android全面屏虚拟导航栏的适配总结
  • # .NET Framework中使用命名管道进行进程间通信
  • #define用法
  • #if和#ifdef区别
  • #include
  • ( 10 )MySQL中的外键
  • (C语言)求出1,2,5三个数不同个数组合为100的组合个数
  • (Redis使用系列) Springboot 使用redis实现接口Api限流 十
  • (七)Knockout 创建自定义绑定
  • (一)基于IDEA的JAVA基础10
  • (转)GCC在C语言中内嵌汇编 asm __volatile__
  • (转)关于如何学好游戏3D引擎编程的一些经验
  • **PHP分步表单提交思路(分页表单提交)
  • **PyTorch月学习计划 - 第一周;第6-7天: 自动梯度(Autograd)**
  • .\OBJ\test1.axf: Error: L6230W: Ignoring --entry command. Cannot find argumen 'Reset_Handler'
  • .form文件_SSM框架文件上传篇
  • .NET 回调、接口回调、 委托
  • .NET/C# 反射的的性能数据,以及高性能开发建议(反射获取 Attribute 和反射调用方法)
  • .NET/C# 利用 Walterlv.WeakEvents 高性能地中转一个自定义的弱事件(可让任意 CLR 事件成为弱事件)
  • .NET成年了,然后呢?
  • ::前边啥也没有
  • @AutoConfigurationPackage的使用