Exploring coal-mine skip state recognition with yolov5 + BoT-SORT/ByteTrack on Huawei Atlas

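Preface:

The code for this project is derived from yolov5 and yolov8: detection uses yolov5, and tracking uses the tracker code from yolov8.

A note on why I did not use yolov8 for everything. v8 is generally considered an improvement over v5, but after repeated attempts I could not get it working on Atlas, so detection had to stay on v5. Why take the tracker from yolov8, then? I need to process video in real time and did not want a heavyweight tracking framework like DeepSORT, which carries a separate re-identification model. After reading the yolov8 code I found its tracker a good fit, so I chose it. At first I doubted I could extract the tracking code at all, since it has been refined by experienced developers over years and is well beyond my own level. But the best way to get something done is to start right away: after two consecutive evenings of overtime I finally had it extracted. The process was tortuous, the result rewarding. As the saying goes, when two evenly matched sides meet, the braver one prevails.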

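Reference code (git links):

Yolov5: https://github.com/ultralytics/yolov5.git (v6.1 release)

Yolov8: https://github.com/ultralytics/ultralytics.git

Project goal:

Recognize the state of each skip, running (run) or stationary (still), and count the skips in the frame (num).

The method described here supports both the BoT-SORT and ByteTrack tracking algorithms.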

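A brief look at the tracking algorithms:

BoT-SORT:

BoT-SORT (from the paper "BoT-SORT: Robust Associations Multi-Pedestrian Tracking") is a tracking-by-detection multi-object tracker built on top of ByteTrack. Despite the name, it is not a Transformer-based method.

Its main features are:

  1. Improved motion modelling. It refines the Kalman filter state and adds camera motion compensation, so predicted boxes stay aligned with the targets even when the camera itself moves.
  2. Fusion of appearance and motion cues.
    Combining IoU-based motion matching with appearance (re-identification) features improves accuracy and stability. For example, when a target reappears after being occluded or briefly lost, it can be re-identified and tracked more reliably. The appearance branch is an optional component.

ByteTrack:

ByteTrack is an efficient and accurate multi-object tracking algorithm.

Its key features are:

  1. A simple but effective association strategy.
    It does not rely only on high-score detection boxes; it also uses the information in low-score boxes, which greatly reduces lost targets. In dense traffic, for example, it can keep tracking vehicles that are partially occluded.
  2. High computational efficiency.
    It keeps tracking quality while lowering compute cost, which suits real-time applications.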
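To make the high-score/low-score idea concrete, here is a minimal sketch of a two-stage association. It is an illustration only: it uses greedy IoU matching rather than the Kalman prediction and assignment step of the real tracker, and the names `iou`, `greedy_match`, `two_stage_associate` as well as the thresholds are assumptions of mine, not ByteTrack's API.

```python
def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_match(track_ids, tracks, det_ids, dets, iou_thresh):
    """Greedily pair tracks with detections in order of descending IoU."""
    pairs = sorted(
        ((iou(tracks[t], dets[d]), t, d) for t in track_ids for d in det_ids),
        reverse=True,
    )
    matches, used_t, used_d = [], set(), set()
    for score, t, d in pairs:
        if score < iou_thresh:
            break  # remaining pairs overlap even less
        if t in used_t or d in used_d:
            continue
        matches.append((t, d))
        used_t.add(t)
        used_d.add(d)
    return matches, [t for t in track_ids if t not in used_t]


def two_stage_associate(tracks, dets, scores, high=0.5, low=0.1, iou_thresh=0.3):
    """tracks: {track_id: box}; dets: list of boxes; scores: one float per box."""
    high_ids = [i for i, s in enumerate(scores) if s >= high]
    low_ids = [i for i, s in enumerate(scores) if low <= s < high]
    # Stage 1: match every track against the high-score detections.
    first, leftover = greedy_match(list(tracks), tracks, high_ids, dets, iou_thresh)
    # Stage 2: tracks still unmatched get a second chance with low-score detections.
    second, lost = greedy_match(leftover, tracks, low_ids, dets, iou_thresh)
    return first + second, lost


tracks = {1: (0, 0, 10, 10), 2: (50, 50, 60, 60)}
dets = [(1, 1, 11, 11), (51, 51, 61, 61)]
matches, lost = two_stage_associate(tracks, dets, scores=[0.9, 0.2])
print(matches, lost)
```

In this toy input the second detection has a score of only 0.2, so a tracker that discarded low-score boxes would lose track 2; the second stage recovers it.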

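Differences:

  1. Accuracy:

At the time of its publication, BoT-SORT ranked first on the MOTChallenge MOT17 and MOT20 test sets, reaching 80.5 MOTA, 80.2 IDF1 and 65.0 HOTA on MOT17. ByteTrack also improved on earlier results across the metrics while running at 30 FPS on a single V100. Compared with DeepSORT, ByteTrack improves markedly under occlusion.

  2. Speed:

Subjectively, ByteTrack inference felt somewhat faster and smoother than BoT-SORT.

  3. Other aspects:

BoT-SORT copes well with targets that are occluded or briefly lost and then reappear, and can re-identify and track them more reliably. ByteTrack does not use appearance features for matching, so its tracking quality depends heavily on detection quality: with a good detector the tracking is good too, but poor detections seriously degrade it.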

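Dataset preparation:

The images were extracted from video frames and annotated with labelimg. Labelling took me about four days, for 872 images in total.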
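For reference, yolov5 reads one text file per image in which every line is `class x_center y_center width height`, with the four box values normalized to the image size. The sketch below shows what one such line encodes. The helper name `to_yolo_line` is hypothetical, and the 1280x720 frame size is an assumption borrowed from the video writer settings in the test script.

```python
def to_yolo_line(class_id, x1, y1, x2, y2, img_w, img_h):
    """Convert a pixel box (x1, y1, x2, y2) into one YOLO label line."""
    xc = (x1 + x2) / 2.0 / img_w   # normalized box centre, x
    yc = (y1 + y2) / 2.0 / img_h   # normalized box centre, y
    w = (x2 - x1) / img_w          # normalized width
    h = (y2 - y1) / img_h          # normalized height
    return f"{class_id} {xc:.6f} {yc:.6f} {w:.6f} {h:.6f}"


# A box covering the central quarter of an assumed 1280x720 frame, class 0 (jidou).
print(to_yolo_line(0, 320, 180, 960, 540, 1280, 720))
```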

Yolov5 model training:

The dataset directory layout is as follows,

Contents of the data/jidou.yaml config file,

path: ./datasets/jidou  # dataset root dir
train: images/train  # train images (relative to 'path')
val: images/train  # val images (relative to 'path'), same folder as train here
test: images/train  # test images (optional)

# Classes
nc: 1  # number of classes
names: ['jidou']

Start training,

python3 train.py --img 640 --epochs 100 --data ./data/jidou.yaml --weights yolov5s.pt

Model conversion: convert the pt model to onnx,

python export.py --weights ./jidou_model/best.pt --simplify

Convert the onnx model to an Atlas model,

atc  --input_shape="images:1,3,640,640" --out_nodes="/model.24/Transpose:0;/model.24/Transpose_1:0;/model.24/Transpose_2:0" --output_type=FP32 --input_format=NCHW --output="./yolov5_add_bs1_fp16" --soc_version=Ascend310P3 --framework=5 --model="./best.onnx" --insert_op_conf=./insert_op.cfg

Contents of the fusion_result.json file,

[{"graph_fusion": {"AConv2dMulFusion": {"effect_times": "0","match_times": "57"},"ConstToAttrPass": {"effect_times": "5","match_times": "5"},"ConvConcatFusionPass": {"effect_times": "0","match_times": "13"},"ConvFormatRefreshFusionPass": {"effect_times": "0","match_times": "60"},"ConvToFullyConnectionFusionPass": {"effect_times": "0","match_times": "60"},"ConvWeightCompressFusionPass": {"effect_times": "0","match_times": "60"},"CubeTransFixpipeFusionPass": {"effect_times": "0","match_times": "3"},"FIXPIPEAPREQUANTFUSIONPASS": {"effect_times": "0","match_times": "60"},"FIXPIPEFUSIONPASS": {"effect_times": "0","match_times": "60"},"MulAddFusionPass": {"effect_times": "0","match_times": "14"},"MulSquareFusionPass": {"effect_times": "0","match_times": "57"},"RefreshInt64ToInt32FusionPass": {"effect_times": "1","match_times": "1"},"RemoveCastFusionPass": {"effect_times": "0","match_times": "123"},"ReshapeTransposeFusionPass": {"effect_times": "0","match_times": "3"},"SplitConvConcatFusionPass": {"effect_times": "0","match_times": "13"},"TransdataCastFusionPass": {"effect_times": "0","match_times": "63"},"TransposedUpdateFusionPass": {"effect_times": "3","match_times": "3"},"V200NotRequantFusionPass": {"effect_times": "0","match_times": "7"},"ZConcatDFusionPass": {"effect_times": "0","match_times": "13"}},"session_and_graph_id": "0_0","ub_fusion": {"AutomaticUbFusion": {"effect_times": "1","match_times": "1","repository_hit_times": "0"},"TbeAippCommonFusionPass": {"effect_times": "1","match_times": "1","repository_hit_times": "0"},"TbeConvSigmoidMulQuantFusionPass": {"effect_times": "56","match_times": "56","repository_hit_times": "0"}}
}]

Contents of the insert_op.cfg file,

aipp_op {
aipp_mode : static
related_input_rank : 0
input_format : YUV420SP_U8
src_image_size_w : 640
src_image_size_h : 640
crop : false
csc_switch : true
rbuv_swap_switch : false
matrix_r0c0 : 256
matrix_r0c1 : 0
matrix_r0c2 : 359
matrix_r1c0 : 256
matrix_r1c1 : -88
matrix_r1c2 : -183
matrix_r2c0 : 256
matrix_r2c1 : 454
matrix_r2c2 : 0
input_bias_0 : 0
input_bias_1 : 128
input_bias_2 : 128
var_reci_chn_0 : 0.0039216
var_reci_chn_1 : 0.0039216
var_reci_chn_2 : 0.0039216
}
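The matrix and bias values above describe a YUV to RGB colour space conversion followed by a scaling of each channel by roughly 1/255. The sketch below reproduces that arithmetic for a single pixel. It reflects my reading of how AIPP applies these fields (subtract the input bias, multiply by the matrix, shift right by 8 bits, then multiply by `var_reci_chn`); rounding and clipping on the actual hardware may differ, and the function names are hypothetical.

```python
# Values copied from insert_op.cfg above.
CSC_MATRIX = [(256, 0, 359), (256, -88, -183), (256, 454, 0)]
INPUT_BIAS = (0, 128, 128)
VAR_RECI = 0.0039216  # approximately 1 / 255


def aipp_yuv_to_rgb(y, u, v):
    """Apply the CSC matrix and input biases to one YUV pixel (assumed semantics)."""
    yuv = (y - INPUT_BIAS[0], u - INPUT_BIAS[1], v - INPUT_BIAS[2])
    rgb = []
    for row in CSC_MATRIX:
        value = sum(c * p for c, p in zip(row, yuv)) >> 8  # fixed-point divide by 256
        rgb.append(min(255, max(0, value)))                # clip to the uint8 range
    return tuple(rgb)


def aipp_normalize(rgb):
    """Scale each channel by var_reci_chn, mapping 0..255 to about 0..1."""
    return tuple(round(c * VAR_RECI, 4) for c in rgb)


grey = aipp_yuv_to_rgb(128, 128, 128)
print(grey, aipp_normalize(grey))
```

A neutral grey pixel (U and V at 128) keeps its luma on all three channels, which is a quick way to check that the biases and matrix rows line up.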

Contents of the jidou.names file,

jidou

Contents of the yolov5_add_bs1_fp16.cfg file,

CLASS_NUM=1
BIASES_NUM=18
BIASES=10,13,16,30,33,23,30,61,62,45,59,119,116,90,156,198,373,326
SCORE_THRESH=0.25
#SEPARATE_SCORE_THRESH=0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001,0.001
OBJECTNESS_THRESH=0.0
IOU_THRESH=0.5
YOLO_TYPE=3
ANCHOR_DIM=3
MODEL_TYPE=2
RESIZE_FLAG=0
YOLO_VERSION=5
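The 18 `BIASES` values are the anchor box sizes: nine (width, height) pairs, three for each of the three output layers. The sketch below parses the config text and groups the anchors accordingly. Treating `ANCHOR_DIM` as "anchors per output layer" is my interpretation of the field, and `parse_cfg` and `anchors_per_layer` are hypothetical helper names.

```python
def parse_cfg(text):
    """Parse KEY=VALUE lines, skipping blank lines and '#' comments."""
    cfg = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        key, value = line.split("=", 1)
        cfg[key.strip()] = value.strip()
    return cfg


def anchors_per_layer(cfg):
    """Group BIASES into (w, h) pairs, ANCHOR_DIM pairs per output layer."""
    values = [int(v) for v in cfg["BIASES"].split(",")]
    if len(values) != int(cfg["BIASES_NUM"]):
        raise ValueError("BIASES_NUM does not match the number of BIASES values")
    pairs = list(zip(values[0::2], values[1::2]))
    step = int(cfg["ANCHOR_DIM"])
    return [pairs[i:i + step] for i in range(0, len(pairs), step)]


CFG_TEXT = """CLASS_NUM=1
BIASES_NUM=18
BIASES=10,13,16,30,33,23,30,61,62,45,59,119,116,90,156,198,373,326
#SEPARATE_SCORE_THRESH=0.001
YOLO_TYPE=3
ANCHOR_DIM=3
"""

for layer in anchors_per_layer(parse_cfg(CFG_TEXT)):
    print(layer)
```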

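Writing the code: extracting the tracking code:

The overall approach to the extraction is as follows,

  1. First get the original code running and confirm that the results are correct.
  2. Get familiar with the code, mainly the py files under trackers, plus predictor.py, results.py and model.py in engine.
  3. Understand the essence of tracking: it boils down to two functions, an initialization function and an update function.
  4. Separate the model from the tracking part first (replace model.track with model.predict).
  5. Strip out the Results structure (replace Results with a plain list to get a more generic container for passing context around).
  6. Implement the update step (write my own code to replace tracker.update and predictor.results[i].update(**update_args)).
  7. Strip out the tracker's yaml config file and assign the values in the tracker's initializer instead.
  8. Strip out the other dependency files, metrics.py and ops.py.
  9. Remove the torch dependency: reimplement batch_probiou in metrics.py with numpy.
  10. Fix bugs in the details.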
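Steps 3 and 5 can be sketched as follows: a plain list holding one dict stands in for Results, and a tracker only needs a constructor and an `update` method. Everything here is illustrative. `make_track_results` and `EchoTracker` are hypothetical names, and the output column layout (box, track id, confidence, class, index of the source detection) is an assumption inferred from how track.py reads `tracks[:, 4]` and `tracks[:, -1]`.

```python
import numpy as np


def make_track_results(boxes_xyxy, confs, class_ids):
    """Pack raw detections into the list-of-dict container used instead of Results."""
    boxes = np.asarray(boxes_xyxy, dtype=float).reshape(-1, 4)
    xywh = boxes.copy()
    xywh[:, 2:] = boxes[:, 2:] - boxes[:, :2]  # width and height from the corners
    return [{
        "id": None,                               # filled in after tracking
        "cls": np.asarray(class_ids, dtype=int),
        "conf": np.asarray(confs, dtype=float),
        "xywh": xywh,                             # top-left x, y, width, height
    }]


class EchoTracker:
    """Stand-in tracker with the same two entry points: __init__ and update."""

    def __init__(self, frame_rate=30):
        self.frame_rate = frame_rate
        self.next_id = 1

    def update(self, results, frame=None):
        # A real tracker associates detections with existing tracks here;
        # this stub just hands every detection a fresh id.
        n = len(results["cls"])
        ids = np.arange(self.next_id, self.next_id + n)
        self.next_id += n
        return np.column_stack(
            [results["xywh"], ids, results["conf"], results["cls"], np.arange(n)]
        )


results = make_track_results([[10, 20, 110, 220]], [0.9], [0])
tracks = EchoTracker().update(results[0])
print(tracks)
```

Keeping the container this small is what lets the detector (Atlas) and the tracker (numpy only) be developed and tested separately.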

The final tracking code, track.py, is as follows,

import os
import json
import cv2
import numpy as np
from plots import box_label, colors
from collections import defaultdict
from trackers.bot_sort import BOTSORT
from trackers.byte_tracker import BYTETracker
from names import names


class TRACK(object):
    def __init__(self):
        # Tracking
        self.frame_rate = 30
        # BOTSORT
        self.tracker = BOTSORT(frame_rate=self.frame_rate)
        # BYTETracker
        # self.tracker = BYTETracker(frame_rate=self.frame_rate)
        self.track_history = defaultdict(lambda: [])
        self.move_state = defaultdict(lambda: [])
        self.move_state_dict = {0: "still", 1: "run"}
        self.distance = 5

    def track(self, track_results, frame):
        if len(track_results[0]["cls"]) != 0:
            tracks = self.tracker.update(track_results[0], frame)
            if len(tracks) != 0:
                # Last column: index of the detection each track came from
                idx = tracks[:, -1].astype(int)
                if track_results[0]["id"] is not None:
                    track_results[0]["id"] = np.array([track_results[0]["id"][i] for i in idx])
                else:
                    track_results[0]["id"] = np.array(tracks[:, 4].astype(int))
                track_results[0]["cls"] = np.array([track_results[0]["cls"][i] for i in idx])
                track_results[0]["conf"] = np.array([track_results[0]["conf"][i] for i in idx])
                track_results[0]["xywh"] = np.array([track_results[0]["xywh"][i] for i in idx])

        # Update track_history and move_state
        boxes = track_results[0]["xywh"]
        clses = track_results[0]["cls"]
        track_ids = []
        if track_results[0]["id"] is not None:
            track_ids = track_results[0]["id"].tolist()
            # Your code for processing track_ids
        else:
            print("No tracks found in this frame")

        for cls, box, track_id in zip(clses, boxes, track_ids):
            x, y, w, h = box
            track = self.track_history[track_id]
            track.append((float(x + w / 2.0), float(y + h / 2.0)))  # x, y center point
            if len(track) > 30:  # keep at most the 30 most recent points
                track.pop(0)
            if len(track) >= self.frame_rate:
                # Manhattan distance between the oldest and newest center point
                if abs(track[-1][0] - track[0][0]) + abs(track[-1][1] - track[0][1]) >= self.distance:
                    self.move_state[track_id] = self.move_state_dict[1]
                else:
                    self.move_state[track_id] = self.move_state_dict[0]
            else:
                self.move_state[track_id] = self.move_state_dict[0]
        return track_results

    def draw(self, image, track_results):
        # Draw the boxes, labels, count and track lines on the image
        annotated_frame = image  # stays defined even when there are no detections
        for index, info in enumerate(track_results[0]["xywh"]):
            xyxy = [int(info[0]), int(info[1]), int(info[0]) + int(info[2]), int(info[1]) + int(info[3])]
            classVec = int(track_results[0]["cls"][index])
            conf = float(track_results[0]["conf"][index])
            if track_results[0]["id"] is not None:
                tid = int(track_results[0]["id"][index])
            else:
                tid = ""
            if tid == "":
                label = f'{names[classVec]} {conf:.4f} track_id {tid}'
            else:
                label = f'{names[classVec]} {conf:.4f} track_id {tid} state {self.move_state[tid]}'
            annotated_frame = box_label(image, xyxy, label, color=colors[classVec])
        cv2.putText(annotated_frame, "num:{}".format(len(track_results[0]["cls"])), (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 0, 0), thickness=2, lineType=cv2.LINE_AA)

        boxes = track_results[0]["xywh"]
        clses = track_results[0]["cls"]
        track_ids = []
        if track_results[0]["id"] is not None:
            track_ids = track_results[0]["id"].tolist()
            # Your code for processing track_ids
        else:
            print("No tracks found in this frame")

        # Plot the tracks
        for cls, box, track_id in zip(clses, boxes, track_ids):
            x, y, w, h = box
            track = self.track_history[track_id]
            # Draw the tracking lines
            points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(annotated_frame, [points], isClosed=False, color=colors[cls], thickness=4)
        return annotated_frame
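The run/still decision inside `TRACK.track` can be isolated as a small pure function, which makes the rule easy to test without a video. This is a sketch of the same logic under the values used above (a window of 30 center points and a 5 pixel threshold on the Manhattan distance between the oldest and newest point); `motion_state` is a hypothetical name.

```python
def motion_state(centers, window=30, min_shift=5.0):
    """Return 'run' or 'still' from a list of (x, y) center points, oldest first."""
    recent = centers[-window:]
    if len(recent) < window:
        return "still"  # not enough history yet, default to still
    dx = abs(recent[-1][0] - recent[0][0])
    dy = abs(recent[-1][1] - recent[0][1])
    return "run" if dx + dy >= min_shift else "still"


print(motion_state([(100.0, 200.0)] * 30))                # stationary skip
print(motion_state([(100.0, 200.0 + i) for i in range(30)]))  # skip moving down
```

Because only the two endpoints of the window are compared, a skip that moves and returns to the same spot within 30 frames would read as still; comparing against the full path length would be one alternative if that case matters.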

The metrics.py code is as follows,

# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Model validation metrics."""

import numpy as np


def bbox_ioa(box1, box2, iou=False, eps=1e-7):
    """
    Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.

    Args:
        box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes.
        box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes.
        iou (bool): Calculate the standard IoU if True else return inter_area/box2_area.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

    Returns:
        (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area.
    """
    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.T

    # Intersection area
    inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * (
        np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)
    ).clip(0)

    # Box2 area
    area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
    if iou:
        box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
        area = area + box1_area[:, None] - inter_area

    # Intersection over box2 area
    return inter_area / (area + eps)


def batch_probiou(obb1, obb2, eps=1e-7):
    """
    Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.

    Args:
        obb1 (np.ndarray): An array of shape (N, 5) representing ground truth obbs, with xywhr format.
        obb2 (np.ndarray): An array of shape (M, 5) representing predicted obbs, with xywhr format.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

    Returns:
        (np.ndarray): An array of shape (N, M) representing obb similarities.
    """
    x1, y1 = np.split(obb1[..., :2], 2, axis=-1)
    x2, y2 = (x.squeeze(-1)[None] for x in np.split(obb2[..., :2], 2, axis=-1))
    a1, b1, c1 = _get_covariance_matrix(obb1)
    a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))

    # Determinant of the summed covariance matrices, shared by all three terms
    det_sum = (a1 + a2) * (b1 + b2) - np.power(c1 + c2, 2)

    t1 = (((a1 + a2) * np.power(y1 - y2, 2) + (b1 + b2) * np.power(x1 - x2, 2)) / (det_sum + eps)) * 0.25
    t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / (det_sum + eps)) * 0.5
    # The square root covers the product of both determinants, as in the ProbIoU paper
    # (the earlier numpy port applied it to the second factor only).
    det1 = np.clip(a1 * b1 - np.power(c1, 2), 0, np.inf)
    det2 = np.clip(a2 * b2 - np.power(c2, 2), 0, np.inf)
    t3 = np.log(det_sum / (4 * np.sqrt(det1 * det2) + eps) + eps) * 0.5
    bd = np.clip(t1 + t2 + t3, eps, 100.0)
    hd = np.sqrt(1.0 - np.exp(-bd) + eps)
    return 1 - hd


def _get_covariance_matrix(boxes):
    """
    Generating covariance matrix from obbs.

    Args:
        boxes (np.ndarray): An array of shape (N, 5) representing rotated bounding boxes, with xywhr format.

    Returns:
        (tuple of np.ndarray): Covariance matrix entries (a, b, c) corresponding to the original rotated boxes.
    """
    # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
    gbbs = np.concatenate((np.power(boxes[:, 2:4], 2) / 12, boxes[:, 4:]), axis=-1)
    a, b, c = np.split(gbbs, 3, axis=-1)
    cos = np.cos(c)
    sin = np.sin(c)
    cos2 = np.power(cos, 2)
    sin2 = np.power(sin, 2)
    return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin
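For readers who want to check `batch_probiou` against the maths, the quantities it computes can be written out as below. This is my transcription of the ProbIoU definition from the paper linked in the docstring, using the same symbols as the code ($a$, $b$, $c$ are the covariance entries returned by `_get_covariance_matrix`); the $\varepsilon$ terms that the code adds for numerical stability are omitted.

```latex
\begin{aligned}
a &= \tfrac{w^2}{12}\cos^2\theta + \tfrac{h^2}{12}\sin^2\theta, \qquad
b  = \tfrac{w^2}{12}\sin^2\theta + \tfrac{h^2}{12}\cos^2\theta, \qquad
c  = \left(\tfrac{w^2}{12} - \tfrac{h^2}{12}\right)\cos\theta\,\sin\theta \\[4pt]
D &= (a_1 + a_2)(b_1 + b_2) - (c_1 + c_2)^2 \\[4pt]
t_1 &= \frac{1}{4}\,\frac{(a_1 + a_2)(y_1 - y_2)^2 + (b_1 + b_2)(x_1 - x_2)^2}{D} \\[4pt]
t_2 &= \frac{1}{2}\,\frac{(c_1 + c_2)(x_2 - x_1)(y_1 - y_2)}{D} \\[4pt]
t_3 &= \frac{1}{2}\,\ln\frac{D}{4\sqrt{(a_1 b_1 - c_1^2)(a_2 b_2 - c_2^2)}} \\[4pt]
B_D &= t_1 + t_2 + t_3, \qquad
H_D = \sqrt{1 - e^{-B_D}}, \qquad
\mathrm{ProbIoU} = 1 - H_D
\end{aligned}
```

For two identical boxes $t_1 = t_2 = 0$ and the argument of the logarithm in $t_3$ is 1, so $B_D = 0$ and the similarity is 1, which is a convenient sanity check for any reimplementation.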

Writing the code: the detection code, yolov5.py:

import os
import json
import time
import cv2
from StreamManagerApi import StreamManagerApi, MxDataInput
import numpy as np
from plots import box_label, colors
from utils import scale_coords, xyxy2xywh, is_legal, preproc
from track import TRACK
from names import names


class YOLOV5(object):
    def __init__(self):
        # init stream manager
        self.streamManagerApi = StreamManagerApi()
        ret = self.streamManagerApi.InitManager()
        if ret != 0:
            print("Failed to init Stream manager, ret=%s" % str(ret))
            exit()

        # create streams by pipeline config file
        with open("./pipeline/jidou.pipeline", 'rb') as f:
            pipelineStr = f.read()
        ret = self.streamManagerApi.CreateMultipleStreams(pipelineStr)
        if ret != 0:
            print("Failed to create Stream, ret=%s" % str(ret))
            exit()

    def process(self, image):
        # Construct the input of the stream
        dataInput = MxDataInput()
        h0, w0 = image.shape[:2]
        r = 640 / max(h0, w0)  # ratio
        input_shape = (640, 640)
        pre_img = preproc(image, input_shape)[0]
        pre_img = np.ascontiguousarray(pre_img)
        image_bytes = cv2.imencode('.jpg', pre_img)[1].tobytes()
        dataInput.data = image_bytes

        # Inputs data to a specified stream based on streamName.
        STREAMNAME = b'classification+detection'
        INPLUGINID = 0
        uniqueId = self.streamManagerApi.SendDataWithUniqueId(STREAMNAME, INPLUGINID, dataInput)
        if uniqueId < 0:
            print("Failed to send data to stream.")
            exit()

        # Obtain the inference result by specifying streamName and uniqueId.
        inferResult = self.streamManagerApi.GetResultWithUniqueId(STREAMNAME, uniqueId, 10000)
        if inferResult.errorCode != 0:
            print("GetResultWithUniqueId error. errorCode=%d, errorMsg=%s" % (
                inferResult.errorCode, inferResult.data.decode()))
            exit()

        results = json.loads(inferResult.data.decode())
        track_results = [{"id": None, "className": [], "cls": [], "conf": [], "xywh": []}]
        # .get() guards against a result without any detected objects
        for num, info in enumerate(results.get('MxpiObject', [])):
            xyxy = [int(info['x0']), int(info['y0']), int(info['x1']), int(info['y1'])]
            # Map the box from the preprocessed image back to the original frame
            xyxy = scale_coords(pre_img.shape[:2], np.array(xyxy), image.shape[:2])
            classVec = info["classVec"]
            track_results[0]["className"].append(names[classVec[0]["classId"]])
            track_results[0]["cls"].append(classVec[0]["classId"])
            track_results[0]["conf"].append(classVec[0]["confidence"])
            # Stored as top-left x, y, width, height
            track_results[0]["xywh"].append([xyxy[0], xyxy[1], xyxy[2] - xyxy[0], xyxy[3] - xyxy[1]])
        track_results[0]["cls"] = np.array(track_results[0]["cls"])
        track_results[0]["conf"] = np.array(track_results[0]["conf"])
        track_results[0]["xywh"] = np.array(track_results[0]["xywh"])
        return track_results

    def __del__(self):
        # destroy streams
        self.streamManagerApi.DestroyAllStreams()

    def draw(self, image, track_results):
        # Draw the detection boxes and labels on the image
        annotated_frame = image  # stays defined even when there are no detections
        for index, info in enumerate(track_results[0]["xywh"]):
            xyxy = [int(info[0]), int(info[1]), int(info[0]) + int(info[2]), int(info[1]) + int(info[3])]
            classVec = int(track_results[0]["cls"][index])
            conf = float(track_results[0]["conf"][index])
            label = f'{names[classVec]} {conf:.4f}'
            annotated_frame = box_label(image, xyxy, label, color=colors[classVec])
        return annotated_frame


def test_img():
    # read image
    ORI_IMG_PATH = "./test_images/00004.jpg"
    image = cv2.imread(ORI_IMG_PATH, 1)
    yolov5 = YOLOV5()
    track_results = yolov5.process(image)
    print(track_results)
    save_img = yolov5.draw(image, track_results)
    cv2.imwrite('./result.jpg', save_img)


def test_video():
    yolov5 = YOLOV5()
    tracker = TRACK()

    # Open the video file
    video_path = "./test_images/jidou.mp4"
    cap = cv2.VideoCapture(video_path)
    fourcc = cv2.VideoWriter_fourcc('X', 'V', 'I', 'D')  # codec used for the saved video
    # Create the VideoWriter; the frame size here must match the frames written to it
    output = cv2.VideoWriter("output.mp4", fourcc, 20, (1280, 720))

    # Loop through the video frames
    while cap.isOpened():
        # Read a frame from the video
        success, frame = cap.read()
        if success:
            # Detect with yolov5 on Atlas, then update the tracker with the detections
            t1 = time.time()
            track_results = yolov5.process(frame)
            t2 = time.time()
            track_results = tracker.track(track_results, frame)
            t3 = time.time()
            annotated_frame = tracker.draw(frame, track_results)
            t4 = time.time()
            # detection, tracking, drawing, total (seconds)
            print("time", t2 - t1, t3 - t2, t4 - t3, t4 - t1)
            output.write(annotated_frame)

            # Display the annotated frame
            # cv2.imshow("YOLOv8 Tracking", annotated_frame)

            # Break the loop if 'q' is pressed
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
        else:
            # Break the loop if the end of the video is reached
            break

    # Release the capture and writer objects and close the display window
    cap.release()
    output.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    # test_img()
    test_video()

Final overall code directory structure:

Final result:
