当前位置: 首页 > news >正文

rcnn代码实现_Detectron2代码学习2 -- 检测模型实现

合集目录:

  1. 庞子奇:Detectron2 代码学习 1 -- 整体结构

2. Detectron2 代码学习 2 -- 检测模型实现 (本篇)

3. 庞子奇:Detectron2代码学习3 -- 数据加载

4. (计划中) 模型实现细节与其它设计

2. Detectron2Model部分

在这部分中我们将沿着刚才分析的训练结构,尝试分析如何构建detectron2的模型。为了实现这一点,我们会首先介绍detectron2中使用的registry机制,之后进一步分析

2.1 Registry机制与build_model

Trainer中初始化模型调用的接口是build_model函数,通过modeling/__init__.py可以知道它是在modeling/meta_arch/build.py中定义的。但是在阅读build.py的过程中,我们发现它使用了一个叫做Registry的东西——那么什么是Registry呢?

Registry机制来自于FaceBook计算机视觉研究组的常用函数库fvcore,其中Registry的源代码和解读可见registry。它的主要作用是提供了用字符串调用类方法的接口,具体的函数是registry中的get方法。个人感觉registry.get()非常像一个getattr方法,它能够通过字符串访问类的方法。

接下来我们通过build_model的具体的用例来体会一下Registry机制的使用。例如我们有一个具体的网络结构,定义在modeling/meta_arch/rcnn.py中的GeneralizedRCNN。注意在它的定义上方有一个修饰器@META_ARCH_REGISTRY.register(),意思就是把GeneralizedRCNN注册到META_ARCH_REGISTRY中。那么在modeling/meta_arch/build.py中的model = META_ARCH_REGISTRY.get(meta_arch)(cfg)中,只要我们设置meta_archGeneralizedRCNN,那么META_ARCH_REGISTRY.get(meta_arch)就调用了GeneralizedRCNN方法,也就因此初始化了模型。

最后我们回顾和总结一下registry的几个要点

  • 创建一个registry,设置可能的调用的类的名字
  • 对于想加入到registry中的类,在定义的时候通过修饰器指定
  • 调用类方法的时候通过get(name)
  • 通过fvcore中的源代码,我们可以知道如何写python修饰器。

由此,我们弄清楚了build_model的实现和registry的机制。

2.2 RCNN结构模型

在分析模型建构的过程中,我们主要看经典模型Faster-RCNN,在modeling/meta_arch/rcnn.pyFaster-RCNN的结构主要包含如下几个部分,而这几个部分也是GeneralizedRCNN在初始化部分的输入:

  • backbone:从图片提取特征表示的卷积神经网络结构,比如说ResNet。
  • proposal_generator:从图片的特征预测“哪里可能有物体”
  • roi_head:以proposal_generator部分预测的有物体区域为基础,预测物体的类别和检测框坐标

__init__功能相似的是from_config函数,它用config文件初始化模型,通过这个函数我们知道了实际构造backbone、proposal_generator和roi_head的函数分别都是来自对应文件夹的build_xxx函数。

在这之后我们继续考察RCNN模型处理函数的过程,我们主要关注forward函数,主要包括如下步骤:

  • 利用preprocess_image对输入的图片做初始化,特别是其中会对图片做归一化(normalization),并将图片放到指定的设备(device,也就是gpu)上。
  • self.backbone(images.tensor)一句,我们可以得到这个图片经过卷积神经网络处理得到的feature
  • 如果我们没有告诉rcnn已经被发现的物体有哪些(detected_instances is None),它就需要自己寻找哪些位置可能有物体出现,并通过self.proposal_generator(images, features)得到可能出现物体的位置proposals
  • 得到了可能出现物体的区域,我们就用self.roi_heads(images, features, proposals),结合feature的信息对roi区域进行处理,得到每个区域的结果。值得注意的是,如果我们提供了目标结果(gt_instances),那么roi_head计算的就是Loss了。

综上,我们知道了RCNN的架构和运行方式,接下来我们会针对它的组件,backbone、proposal_generator和roi_head分别进行研究。

2.3 backbone

backbone的类仍然是采用registry的模式,这样方便后继的人对于库进行修改。对于backbone的分析其实不用非常具体,因为BackBone就是最普通的卷积神经网络。但是在detectron2的具体实现中,有下面几点可以简单提一下:

  • detectron2的backbone均基于modeling/backbone/backbone.py中的Backbone基类,它在nn.Module上做了简单的包装,主要是针对物体检测和实例分割的特殊需求,比如说:
    • size_divisibility:因为卷积神经网络涉及到降采样,所以会对输入数据产生要求。例如降采样8倍的网络,其输入的长宽也必须是8的倍数。
    • output_shape:返回输出的Feature Map的形状是怎样的,方便后续的proposal_generator和roi_head进行处理。
  • ResNet的实现为例,它的实现方法和torchvision中对ResNet的实现方式非常像(或者说基本是一样的),只是在build_resnet中制作了生成resnet模型的统一接口。

2.4 proposal_generator

backbone相同,proposal_generator也是通过build_proposal_generator的接口创建“从图片的特征图提出可能产生物体的区域”的网络。暂时detectron2实现了两种proposal_generator,一种是Region Proposal Generator,也就是代码库中的rpn;另一种是Rotated Region Proposal Network,也就是rrpn。在下面,我们主要分析RPN的实现。

2.4.1 RPN Head

RPN作为proposal_generator的一种并不是最细分的抽象层次,最细分的抽象层次是RPN Head,也就是RPN用怎样的结构处理输入的特征图。这也就是在modeling/proposal_generator/rpn.pyRPN_HEAD_REGISTRY的由来。

在这里,detectron2为我们实现了一个典型的RPN Head——StandardRPNHead。它首先用一个3x3的卷积处理输入的特征图,之后分别用两个1x1的卷积处理得到:objectness_logits,有多大可能是一个物体;anchor_deltas:如何生成一个合理的锚定框(Anchor Box)。在它的输出中,objectness_logits是一个的张量,其中是每个Batch的图片个数,是在每个空间位置上设定的锚定框个数;anchor_deltasobjectness_logits是一一对应的,只不过它的输出变成了,增加了对锚定框位置的描述。

2.4.2 RPNforward过程

尽管在上一段中我们清楚了RPN Head是如何构成的、它的工作原理是怎样的,但是我们对于RPN最关键的部分:它是如何预测锚定框的、它又是如何训练的仍然一无所知。为了实现这个目的,我们需要仔细研究RPN的实现,最主要的就是它forward的过程。

RPN的输入是图片的特征,它通过如下步骤进行处理:

  • 首先用anchor_generator生成了anchor的所有可能位置。我们会在稍后详细分析anchor_generator的实现,在这里我们可以把它的输出暂时理解为“一个列表,列表包含了几何意义上的所有可能锚定框”。
  • 其次通过刚刚分析过的rpn_head为每个anchor的位置预测了存在物体的可能性和框的可能位置。
  • 如果在输入中提供了真实的锚定框信息,也就是gt_instances不是None,那么会通过self.label_and_sample_anchors制作属于“锚定框”的数据集,并由self.losses计算RPN部分的损失函数,以训练RPN Head
  • 通过self.predict_proposals可以通过之前得到的每个位置的:是否有物体、回归的框位置,得到可能存在物体的区域,也就是proposals,至此,实现了RPN处理图片的forward过程。

接下来,我们关注下面几个部分的实现,分别是self.label_and_sample_anchorsself.lossesself.predict_proposals

2.4.3 label_and_sample_anchors:处理真实框信息

RPN Head的目标是对anchor_generator中提出的粗糙的锚定框进行分类,选取出真正可能存在物体的框,放弃不存在物体的框。那么为了训练RPN Head实现这一功能,我们就需要通过已知的物体框信息训练RPN Head对粗糙框进行分类:哪些是正样本(包含物体)、哪些是负样本(不包含物体)等等。而label_and_sample_anchors函数的目的就是制作一个训练RPN Head处理这类问题的数据集。

算法的主体思路是通过我们提出的框和实际框之间的IOU判定是正样本还是负样本。因此:

  • 首先计算每个提出框与真实框之间的IOU,存储到match_quality_matrix
  • 其次,通过modeling/matcher.py/Matcher为这些框之间打Label。这部分的操作就是,如果IOU低于某个阈值(例如0.3),那么Label为0,代表负样本;超过某个阈值(例如0.7),那么Label为1,代表正样本;在两个阈值之间是难以判断的情况,为了避免对训练造成混淆,设置Label为-1,代表训练时可以忽略的样本。
  • 根据每张图片中正样本的比例选取一部分正样本、负样本参与训练。其中正样本不应超过一个预设的比例(positive_fraction),剩余的样本数量用负样本补齐。
  • 返回的结果有两部分,第一部分是对应每个提出的锚定框的真实框坐标值,第二部分是这组对应关系的标签(Label)是怎样的。

这样,我们就得到了可以训练RPN Head的数据。

2.4.4 losses:计算损失函数

通过观察losses的参数列表,我们就可以对它的逻辑略知一二。它的主要思路应该是通过objectness_logitsgt_label(回忆一下,在2.4.3中我们介绍了gt_label的含义)使得网络可以分辨哪片区域是有物体的,通过deltas和真实物体的框gt_bboxes使得网络能够得到物体的详细位置。

在实际的实现中也基本是按照这两部分进行的。“是否有物体”的判定这明显是一个二分类问题,所以通过binary_cross_entropy就可以训练。

针对回归物体位置的部分,在losses中提供了两种选择:

  • 第一种是采用smooth_l1_loss,它的直观就是直接求预测值和实际值之间的差别作为Loss,它的做法就是直接求预测的框位置的Delta和真实框的Delta之间的差别。
  • 第二种方式是采用Giou Loss,它的做法和第一种做法,即计算delta的差距,是相反的。它的直观是预测的Delta产生的框效果是怎样的,所以它会首先把预测的Delta转化成具体框的位置,之后与真正的框位置计算iou,那么iou越大说明预测的位置越准确。

综上,我们了解了在RPN部分的Loss计算。

2.4.5 predict_proposal:提取Proposal

在这部分中,我们将来到RPN的最后一步,即如何得到物体的Proposal,利用用之前预测的objectness——是否有物体,和deltas——物体的位置?在这里我们重申一下得到Proposal的目的:为了细致地判断框内物体的类别和位置。

predict_proposal中实现了提取proposal的全过程,主要包含如下步骤:

  • 通过_decode_proposal将预测的deltas转化为实际的框位置
  • 通过find_top_rpn_proposals找到其中最合适的预测结果

在这里,我们只对find_top_rpn_proposals进行讨论,它的实现在modeling/proposal_generator/proposal_utils.py。这个算法实际上就是实现了一个简单的筛选过程。它首先按照objectness抽取出分数最高的部分Proposal,值得注意的是,这部分的Proposal是足够多的,用机器学习的术语来解释就是“有很高的Recall”。那么这些Proposal中又如何进一步筛选呢?

在筛选的过程中,最重要的就是NMS操作。在之前的Proposal中,存在很多框严重重叠的现象,这是因为在最开始提出Proposal的过程中,在同一个位置会有多个候选框出现,而这些候选框实际上代表的都是同一个物体。因此NMS操作应运而生,它的目的就是通过计算框之间的iou值,筛选掉这些实际上代表了同一物体的候选框。这部分的具体实现在layers/nms.py。在后面我们会对它进行具体分析。

经过上面的操作,我们就可以预测合理的Proposal,也就因此实现了对于proposal_generator模块的理解。

2.4.6 小结与讨论

在这部分中我们不关心detectron2实现的具体细节,而是针对它的设计思想、设计逻辑进行一些总结和讨论。在detectron2的实现中,它首先将RPN本身的输入输出要弄清楚:输入是来自Backbone的特征表示,输出是哪些地点存在物体的Proposal

其次,它将RPN按照功能分拆成如下多个部分,按照“从粗糙到细致”的顺序逐步完成了对Proposal的提取,从最开始的只要是一个位置就看做一个Proposal,到最终利用分数和IOU筛选出合理的Proposal。跳出RPN本身的实现,如果我们不考虑后面逐步筛选的步骤,那么只要我们有足够多的计算资源,其实是不会影响后面的训练效果的——因为我们完全可以把那些不包含物体的框设置成负样本不予考虑。这样的角度给我们的启示就是,其实我们完全可以重新对RPN的操作进行设计,增加或者减少对于Proposal的筛选,当然也可以改变筛选的方式,比如NMS的做法。不过无论我们怎么修改,其实都是在按照detectron2一种“链式筛选”的结构在做。

最后,毫无疑问,detectron2对于RPN在这样的实现方式给予了编程人员很大的修改自由度和简洁的抽象层次。尽管我们可能只有完整地了解了RPN的流程才会想到这样的实现方法,但是我们可以在未来设计算法的时候抓住“链式筛选”这样“由粗糙到细致”的算法模式,设计自己的接口和实现方式。

2.5 roi_heads

2.5.1 结构简介

回顾rcnn的结构,对roi_heads的调用出现的形式是results, _ = self.roi_heads(images, features, proposals, None),也就是roi_heads处理Proposal,得到最终检测的结果;如果我们已经知道了proposal是怎样的,那么通过roi_heads_forward_with_given_boxes也可以实现检测。在了解了roi_heads的功能之后,我们来看它的实现,基本都是在modeling/roi_heads/里面。

roi_head__init__.py中可以看到detectron2提供了对多种任务的roi_head支持,例如box、mask和key_point,分别针对物体检测、实例分割和人体姿态估计。在我们的分析中,仅以物体检测的Box_head为例进行分析。构建roi_head的函数是在modeling/roi_heads/roi_heads.pybuild_roi_heads。通过实现可以看到它也是通过REGISTRY方式进行构建。通过查阅config/defaults.py文件,我们知道了对于默认的Faster-RCNN使用的是Res5ROIHeads,所以我们重点研究它的实现。

2.5.2 ROIHeads基类:如何提供对ROI的训练数据

在注释中作者解释了ROIHeads共有的逻辑:

  • 在训练部分,在Proposal和实际框之间进行匹配,同时对Proposal进行采样
  • 对Proposal的一片区域进行处理,得到该区域的特征
  • 利用Feature针对我们的任务(检测/分割/...)进行预测

ROIHeads基类中并没有对forward的逻辑进行实现,因为它与具体的任务相关,在BoxHeadMaskHead等模块内部自己实现。在基类中具体实现的是如何通过Ground Truth的样本和Proposal样本的列表创造对ROIHead的训练数据。通过对Res5ROIHeads的阅读可以发现实际得到使用的函数是label_and_sample_proposals

首先,通过add_ground_truth_to_proposals可以提升Proposal的效果,特别是在训练刚刚开始的时候,RPN提供的Proposal可能质量很差,那么利用Ground Truth的数据可以保证正样本的存在。

其次,我们将Proposal和实际的框之间进行匹配,利用的函数是我们在RPNlabel_and_sample_anchors中介绍的一样的方法。我们会计算每个Proposal与真实的框之间的IOU矩阵,通过proposal_matcher得到两两之间匹配的结果。进一步地,通过sample_proposals我们从所有的Proposal中提取出参与训练的Proposal。至此就基本实现了获取训练ROI的Proposal的步骤。

最后我们需要看一下sample_proposals的实现中需要注意的一个点。它利用subsample_labels从proposals中提出一定数量的样本训练,其中正样本不能超过一定的比例。关于负样本究竟具体会对训练结果产生怎样的影响,可以参考RCNN实验里的讨论。

2.5.3 Res5ROIHeads

Res5ROIHeads是默认使用的ROIHead,在它的forward函数中实现了对Proposal进行框预测和类别预测的过程。它(训练的)的主要流程是先通过对Proposal的采样制造训练数据。其次通过shared_roi_transform提取每个Proposal框中的特征表示。最后通过box_predictor得到每个Proposal内物体的类别和具体的框坐标。

其中,我们之前已经分析了如何在Proposal中采样得到训练样本,因此不再讨论。shared_roi_transform中主要用到了Roi_AlignRoi_Pool等操作,它们的重点在于这些模块的实现,因此我们在后面再进行讨论。在这里我们只需要知道ROIPool的功能就是:对于给定区域内的图片特征(Feature Map),我们将其变成形状一定的这个Proposal的特征表示(Feature Vector)。在最后一步中,detectron2通过box_predictor利用每个Proposal内的特征表示预测物体的类别和准确的框坐标。在下面,我们简要分析box_predictor的工作过程。

box_predictor的主要流程是现在modeling/roi_heads/fast_rcnn.py/FastRCNNOutputLayers中。通过对__init__.py的阅读,我们可以发现box_predictor的核心在于两个线性层:cls_scorebox_pred,分别负责预测物体的类别和回归框的坐标。它的工作流程在forward函数中其实相当简单,只需要把特征表示分别送入到cls_scorebox_pred中,两个线性层就会分别把Logits和预测的框的Deltas输出出来。

综上,我们也基本了解了roi_heads的工作过程。

3. 成库

相关文章:

  • ransac算法_无人驾驶算法学习(一):激光里程计之帧间匹配算法
  • java安装步骤_jmeter安装及环境配置(一)
  • python xlsx读写_Python Excel文件的读写操作(xlwt xlrd xlsxwriter)
  • python 折线图标签_matplotlib 曲线图 和 折线图 plt.plot()实例
  • matlab eig函数_心心念念的matlab基础及入门来啦!
  • python与html结合_Python在字符串中处理html和xml的方法
  • 怎么下载安装python_【转】如何下载安装python
  • eclipse maven打包jar_Maven 异常信息:jar包缺失或损坏,导致编译、打包错误
  • python indexerror_Python 未超索引情况下 显示 IndexError
  • calendar round_java实战项目常用类,Date、Calendar、BigDecimal、Math、UUID
  • 新代系统plc梯形图说明书_PLC现场实例电气原理图及编程
  • python注释是什么意思_python注释是什么意思
  • 如何和后台接触的_民熔小课堂|跌落式熔断器该如何检修?点进来告诉你答案!...
  • python rgb库_Python实现RGB与HSI颜色空间的互换方式
  • python删除对象引用_Python:删除自引用对象
  • 《Javascript高级程序设计 (第三版)》第五章 引用类型
  • 【技术性】Search知识
  • 2019.2.20 c++ 知识梳理
  • axios 和 cookie 的那些事
  • C++回声服务器_9-epoll边缘触发模式版本服务器
  • CentOS7简单部署NFS
  • ComponentOne 2017 V2版本正式发布
  • Fundebug计费标准解释:事件数是如何定义的?
  • Go 语言编译器的 //go: 详解
  • Hibernate【inverse和cascade属性】知识要点
  • Intervention/image 图片处理扩展包的安装和使用
  • MYSQL如何对数据进行自动化升级--以如果某数据表存在并且某字段不存在时则执行更新操作为例...
  • PHP面试之三:MySQL数据库
  • Redis 懒删除(lazy free)简史
  • 阿里云爬虫风险管理产品商业化,为云端流量保驾护航
  • 百度贴吧爬虫node+vue baidu_tieba_crawler
  • 力扣(LeetCode)21
  • 排序(1):冒泡排序
  • 设计模式(12)迭代器模式(讲解+应用)
  • 使用agvtool更改app version/build
  • # Swust 12th acm 邀请赛# [ E ] 01 String [题解]
  • #NOIP 2014# day.1 T3 飞扬的小鸟 bird
  • #调用传感器数据_Flink使用函数之监控传感器温度上升提醒
  • $ is not function   和JQUERY 命名 冲突的解说 Jquer问题 (
  • (2)Java 简介
  • (Git) gitignore基础使用
  • (附源码)计算机毕业设计大学生兼职系统
  • (汇总)os模块以及shutil模块对文件的操作
  • (紀錄)[ASP.NET MVC][jQuery]-2 純手工打造屬於自己的 jQuery GridView (含完整程式碼下載)...
  • (力扣记录)235. 二叉搜索树的最近公共祖先
  • (免费领源码)Python#MySQL图书馆管理系统071718-计算机毕业设计项目选题推荐
  • (转)JAVA中的堆栈
  • (转)关于如何学好游戏3D引擎编程的一些经验
  • .cfg\.dat\.mak(持续补充)
  • .Family_物联网
  • .NET Core日志内容详解,详解不同日志级别的区别和有关日志记录的实用工具和第三方库详解与示例
  • .NET MVC第三章、三种传值方式
  • .NET面试题解析(11)-SQL语言基础及数据库基本原理
  • .stream().map与.stream().flatMap的使用
  • /ThinkPHP/Library/Think/Storage/Driver/File.class.php  LINE: 48