当前位置: 首页 > news >正文

嵌入式AI---如何用C++实现YOLO的NMS(非极大值抑制)算法

文章目录

  • 前言
  • 一、为什么需要NMS算法?
  • 二、什么是NMS算法?
  • 三、如何使用C++编写一个NMS算法
    • 1、预测框定义
    • 2、滤除无效框
  • 总结


前言

YOLO系列的目标检测算法在边缘部署方面展现出了强大的性能和广泛的应用潜力。大部分业务场景是利用PyTorch在服务器端完成检测模型的训练,得到相应的.pt、.onnx检测模型文件。随后,对模型计算量和硬件成本进行综合考量,完成边缘计算设备选型。最后,根据不同的硬件设备,将.pt或onnx模型文件转化成适配对应硬件平台的模型文件再进行推理(如瑞芯微的rknn格式、昇腾的om格式)。
目前网上大多数资料用的是YOLOV5官方源码提供的Python推理版本,然而实际业务场景往往需要基于C++在板子上完成模型推理。这就涉及到了一些模型输入预处理,输出后处理的问题,本文将简单介绍如何利用C++实现YOLOV5的后处理NMS算法。


一、为什么需要NMS算法?

先不思考什么是NMS,先思考为什么需要引入这个算法:
以YOLOV5为例,假设YOLOV5的输入图像大小为320x320x3,那么输出特征图的大小就为40x40、20x20、10x10。输出特征图的每个点都铺设了3个锚框,故最终有(40x40+20x20+10x10)x3个预测框。实际业务场景不可能有这么多的预测目标,我们需要先基于每个框的置信度筛除一批无效预测框(这一步还不是NMS,只是基于置信度进行筛除,因为大多数框都是无效框,利用置信度可以筛除90%以上的预测框)。
在这里插入图片描述
筛除了一批预测框后,由于目标附近可能会有多个预测框的置信度较高(也就是有多个预测框同时选中了目标),因此我们需要从中选取一个作为结果输出,这就需要引入一种滤除算法消除其它预测框,YOLO中用的就是NMS。

二、什么是NMS算法?

非极大值抑制(NMS),如名字所示,目的在于抑制非极大值的预测框。那么什么是极大值呢,其实就是局部区域内可信度得分最高的预测框。NMS算法的作用就是抑制局部区域内得分较低的预测框,最后保留那个极大值预测框。
对于目标检测场景,为了解决同一个目标被多个锚框选中的问题,我们引入了非极大值抑制算法(NMS),局部区域内只保留一个得分最高的目标框。

三、如何使用C++编写一个NMS算法

1、预测框定义

typedef struct Box{float x;	//预测框左上角坐标xfloat y;	//预测框左上角坐标yfloat w;	//框宽float h;	//框高float score;  //得分
}Box;

假设预测框的结构体定义如上所示,Box结构体中包含了预测框的位置、大小以及该框的得分。注意(需提前处理YOLO的输出内容,将输出内容都转化为Box结构体变量,此处省略该代码)

2、滤除无效框

NMS算法的思路如下:
(1)将所有预测框按照得分从高到低进行排序。
(2)从得分最高的预测框开始,依次遍历排序后的预测框列表中的每一个预测框,计算它与列表中后续所有预测框之间的IOU值。在计算IOU后,将那些IOU值大于预设阈值的后续预测框从候选框列表中移除。
(3)完成上述步骤后,继续遍历候选框列表中的下一个预测框,重复执行上述计算IOU和剔除高重叠预测框的过程,直到候选框列表中的所有预测框都被遍历完毕。

因此,我们需要先对预测框进行排序,假设预测框全都存放在vector类对象boxVec中,那么我们需要对boxVec内的全部预测框进行排序。

bool compare(Box b1, Box b2)
{return b1.score>b2.score? true:false;
}
vector<Box> boxVec;
sort(boxVec.begin(), boxVec.end(), compare);

随后编写一个计算两个预测框IOU的函数:

float IOU(Box b1, Box b2)
{  float x1 = max(b1.x, b2.x); 	//重叠框的四个坐标float x2 = min(b1.x + b1.w, b2.x + b2.w); float y1 = max(b1.y, b2.y);float y2 = min(b1.y + b1.h, b2.y + b2.h);float overlap_area = max(0.0f, x2 - x1) * max(0.0f, y2 - y1); //重叠区域大小  if (overlap_area == 0) return 0.0f; // 如果没有重叠,IoU为0  float union_area = b1.w * b1.h + b2.w * b2.h - overlap_area; //联合区域大小// 使用更常见的分母  float iou = overlap_area / union_area ;  return iou;  
}  

在这里插入图片描述

最后,利用排序好的boxVec和IOU函数完成无效框滤除:

size_t i = 0;  
float nms_ratio = 0.5;  
while(i < boxVec.size())  
{  size_t j = i + 1;  while(j < boxVec.size())  {  if(IOU(boxVec[i], boxVec[j]) > nms_ratio)  {  // 删除元素,并且不增加 j 的值  boxVec.erase(boxVec.begin() + j);  }  else  {  // 如果没有删除元素,则增加 j  j++;  }  }  i++;  
}

至此,boxVec中重叠的预测框就被滤除了。


总结

本文基于C++编写了一个简化版的NMS代码,简单介绍了相关的设计思路,实际使用可能仍需优化或存在疏漏,具体需根据业务需求动态调整代码。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 【Kubernetes】常见面试题汇总(十三)
  • 5.1 溪降技术:个人装备
  • 04_Python数据类型_列表
  • 【运维监控】Prometheus+grafana+kafka_exporter监控kafka运行情况
  • linux 操作系统下的cut命令介绍和使用案例
  • JVM字节码
  • UVA1395 Slim Span(最小生成树)
  • Unity3d 以鼠标位置点为中心缩放视角(正交模式下)
  • 详解c++多态---上
  • 动态规划---不相交的线
  • 【前端】ref引用的作用
  • Golang、Python、C语言、Java的圆桌会议
  • Vue.js 计算属性
  • 数据结构:堆的算法
  • Nginx 文件名逻辑漏洞(CVE-2013-4547)
  • Angular 响应式表单 基础例子
  • CSS 提示工具(Tooltip)
  •  D - 粉碎叛乱F - 其他起义
  • gulp 教程
  • MaxCompute访问TableStore(OTS) 数据
  • nodejs实现webservice问题总结
  • Redis的resp协议
  • SAP云平台运行环境Cloud Foundry和Neo的区别
  • tensorflow学习笔记3——MNIST应用篇
  • Xmanager 远程桌面 CentOS 7
  • 反思总结然后整装待发
  • 基于webpack 的 vue 多页架构
  • 前端面试题总结
  • 入口文件开始,分析Vue源码实现
  • 使用 Docker 部署 Spring Boot项目
  • 线性表及其算法(java实现)
  • 小程序滚动组件,左边导航栏与右边内容联动效果实现
  • 移动互联网+智能运营体系搭建=你家有金矿啊!
  • LIGO、Virgo第三轮探测告捷,同时探测到一对黑洞合并产生的引力波事件 ...
  • ​​​​​​​Installing ROS on the Raspberry Pi
  • ​configparser --- 配置文件解析器​
  • ​sqlite3 --- SQLite 数据库 DB-API 2.0 接口模块​
  • ​学习笔记——动态路由——IS-IS中间系统到中间系统(报文/TLV)​
  • ![CDATA[ ]] 是什么东东
  • (1)(1.11) SiK Radio v2(一)
  • (1)SpringCloud 整合Python
  • (2024,LoRA,全量微调,低秩,强正则化,缓解遗忘,多样性)LoRA 学习更少,遗忘更少
  • (C#)if (this == null)?你在逗我,this 怎么可能为 null!用 IL 编译和反编译看穿一切
  • (板子)A* astar算法,AcWing第k短路+八数码 带注释
  • (动手学习深度学习)第13章 计算机视觉---图像增广与微调
  • (仿QQ聊天消息列表加载)wp7 listbox 列表项逐一加载的一种实现方式,以及加入渐显动画...
  • (附源码)springboot学生选课系统 毕业设计 612555
  • (附源码)计算机毕业设计SSM基于java的云顶博客系统
  • (入门自用)--C++--抽象类--多态原理--虚表--1020
  • (三)Honghu Cloud云架构一定时调度平台
  • (三)模仿学习-Action数据的模仿
  • (十)DDRC架构组成、效率Efficiency及功能实现
  • (四) Graphivz 颜色选择
  • (四)【Jmeter】 JMeter的界面布局与组件概述
  • (完整代码)R语言中利用SVM-RFE机器学习算法筛选关键因子