当前位置: 首页 > news >正文

Pytorch机器学习(八)—— YOLOV5中NMS非极大值抑制与DIOU-NMS等改进

Pytorch机器学习(八)—— YOLOV5中NMS非极大值抑制与DIOU-NMS等改进

目录

Pytorch机器学习(八)—— YOLOV5中NMS非极大值抑制与DIOU-NMS等改进

前言

一、NMS非极大值抑制算法

二、Hard-NMS非极大值代码

三、DIOU-NMS

 四、soft-NMS


前言

在目标检测的预测阶段时,会输出许多候选的anchor box,其中有很多是明显重叠的预测边界框都围绕着同一个目标,这时候我就可以使用NMS来合并同一目标的类似边界框,或者说是保留这些边界框中最好的一个。

如果对IOU等知识不了解的可以看我上篇博客Pytorch机器学习(五)——目标检测中的损失函数(l2,IOU,GIOU,DIOU, CIOU)


一、NMS非极大值抑制算法

我们先看一下NMS的直观理解,左图为两个ground truth的bbox,右图为我自己模拟网络输出的预测框。

 而下图则是我使用Pytorch官方提供的NMS实现的非极大值抑制,可以看到经过NMS后预测框保留了效果最好的,去除了冗余的预测框。

 

 下面来讲讲NMS算法的流程,其实也是十分简单的

        一.从所有候选框中选取置信度最高的预测边界框B1作为基准,然后将所有与B1的IOU超过预定阈值的其他边界框移除。

(这时所有边界框中B1为置信度最高的边界框且没有和其太过相似的边界框——非极大值置信度的边界框被抑制了

        二.从所有候选框中选取置信度第二高的边界框B2作为一个基准,将所有与B2的IOU超过预定阈值的其他边界框移除。

        三.重复上述操作,直到所有预测框都被当做基准——这时候没有一对边界框过于相似

二、Hard-NMS非极大值代码

在YOLOV5的源码当中,作者是直接调用了Pytorch官方的NMS的API

在general.py中的non_max_suppression函数中

"""
其中boxes为Nx4的tensor,N为框的数量,4则为x1 y1 x2 y2
socres为N维的tensor,表示每个框的置信度
iou_thres则为上面算法中的IOU阈值
返回值为一个去除了过于相似框后的,根据置信度降序排列的列表,我们就可以根据此列表输出预测框
"""
i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS

为了便于后续其他NMS的改进,这里我们也自己写一个NMS算法,这里借鉴了沐神的代码b站链接,大家可以直接在YOLOV5中把上面的torchvision.ops.nms更改为下面的NMS函数

def NMS(boxes, scores, iou_thres, GIoU=False, DIoU=False, CIoU=False):
    """

    :param boxes:  (Tensor[N, 4])): are expected to be in ``(x1, y1, x2, y2)
    :param scores: (Tensor[N]): scores for each one of the boxes
    :param iou_thres: discards all overlapping boxes with IoU > iou_threshold
    :return:keep (Tensor): int64 tensor with the indices
            of the elements that have been kept
            by NMS, sorted in decreasing order of scores
    """
    # 按conf从大到小排序
    B = torch.argsort(scores, dim=-1, descending=True)
    keep = []
    while B.numel() > 0:
        # 取出置信度最高的
        index = B[0]
        keep.append(index)
        if B.numel() == 1: break
        # 计算iou,根据需求可选择GIOU,DIOU,CIOU
        iou = bbox_iou(boxes[index, :], boxes[B[1:], :], GIoU=GIoU, DIoU=DIoU, CIoU=CIoU)
        # 找到符合阈值的下标
        inds = torch.nonzero(iou <= iou_thres).reshape(-1)
        B = B[inds + 1]
    return torch.tensor(keep)

这里的计算IOU的函数——bbox_iou则是直接引用了YOLOV5中的代码,其简洁的集成了对与GIOU,DIOU,CIOU的计算。

def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-9):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.T

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # transform from xywh to xyxy
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
    union = w1 * h1 + w2 * h2 - inter + eps

    iou = inter / union
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
                    (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / ((1 + eps) - iou + v)
                return iou - (rho2 / c2 + v * alpha)  # CIoU
        else:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + eps  # convex area
            return iou - (c_area - union) / c_area  # GIoU
    else:
        return iou  # IoU

三、DIOU-NMS

其实DIOU-NMS就是把我上面说的NMS算法中的IOU阈值改为DIOU,将NMS代码中的DIOU设置为True即可。

根据DIOU的论文,如果只是单纯的使用NMS,即是使用IOU作为阈值去筛掉其他预测框时,当两个物体过于接近时,很有可能另外一个物体的预测框就被滤除了。

就像下图中的摩托。使用DIOU-NMS可以一定程度上提升对于靠近的物体的检测。

 四、soft-NMS

网上还有一种soft-NMS的算法,其思想就是传统的NMS,如果只通过IOU值就将其他的框直接去掉,有可能会不妥,于是就引入了soft-NMS。

具体流程就是我们把NMS算法中去除其他边界框改成,修改其他边界框的置信度。

以下引一个博主的图

 其中的f()函数,现在都是使用的高斯函数

si即为置信度,M为当前最大置信度的边界框,bi为其他边界框

网上对此的效果看法也是褒贬不一,我自己也没有试过,但从直觉来说,我个人觉得效果不会有很大的提升,如果感兴趣的可以自己试一试。

相关文章:

  • 出道即封神的ChatGPT,现在怎么样了?
  • UDS 14229 -1 刷写34,36,37服务,标准加Trace讲解,没理由搞不明白
  • 朋友去华为面试,轻松拿到26K的Offer,羡慕了......
  • CPU平均负载高问题定位分析
  • 数据仓库相关概念的解释
  • 软测面试了一个00后,绝对能称为是内卷届的天花板
  • SpringBoot最常用的50个注解(全是干货,干的要死!)
  • TH-OCR文字识别SDK 12.X介绍
  • 滑动窗口算法
  • 【Kafka】MM2同步Kafka集群时如何自定义复制策略(ReplicationPolicy)
  • iOS 语言基础初探 Xcode 工具
  • 读书笔记——《富爸爸穷爸爸》
  • 【钓鱼实测】写bug给new bing和chatGPT查。问他们林黛玉倒拔垂杨柳
  • 蓝桥刷题总结1
  • 三天吃透MySQL面试八股文
  • (ckeditor+ckfinder用法)Jquery,js获取ckeditor值
  • 《网管员必读——网络组建》(第2版)电子课件下载
  • ABAP的include关键字,Java的import, C的include和C4C ABSL 的import比较
  • - C#编程大幅提高OUTLOOK的邮件搜索能力!
  • Effective Java 笔记(一)
  • el-input获取焦点 input输入框为空时高亮 el-input值非法时
  • iOS 系统授权开发
  • JS专题之继承
  • scala基础语法(二)
  • SSH 免密登录
  • ViewService——一种保证客户端与服务端同步的方法
  • 从零搭建Koa2 Server
  • 大快搜索数据爬虫技术实例安装教学篇
  • 代理模式
  • 浮动相关
  • 基于HAProxy的高性能缓存服务器nuster
  • 快速构建spring-cloud+sleuth+rabbit+ zipkin+es+kibana+grafana日志跟踪平台
  • 名企6年Java程序员的工作总结,写给在迷茫中的你!
  • 爬虫进阶 -- 神级程序员:让你的爬虫就像人类的用户行为!
  • 前端之React实战:创建跨平台的项目架构
  • 让你成为前端,后端或全栈开发程序员的进阶指南,一门学到老的技术
  • 微信小程序开发问题汇总
  • 如何通过报表单元格右键控制报表跳转到不同链接地址 ...
  • ​Distil-Whisper:比Whisper快6倍,体积小50%的语音识别模型
  • ​iOS安全加固方法及实现
  • ​决定德拉瓦州地区版图的关键历史事件
  • #我与Java虚拟机的故事#连载09:面试大厂逃不过的JVM
  • #我与Java虚拟机的故事#连载14:挑战高薪面试必看
  • (¥1011)-(一千零一拾一元整)输出
  • (1)(1.9) MSP (version 4.2)
  • (10)Linux冯诺依曼结构操作系统的再次理解
  • (12)Hive调优——count distinct去重优化
  • (poj1.2.1)1970(筛选法模拟)
  • (二)Pytorch快速搭建神经网络模型实现气温预测回归(代码+详细注解)
  • (排序详解之 堆排序)
  • (三)docker:Dockerfile构建容器运行jar包
  • (十六)Flask之蓝图
  • (转)大型网站的系统架构
  • (转)关于pipe()的详细解析
  • .net 生成二级域名