当前位置: 首页 > news >正文

应用DeepSORT实现目标跟踪

在ByteTrack被提出之前,可以说DeepSORT是最好的目标跟踪算法之一。本文,我们就来应用这个算法实现目标跟踪。

DeepSORT的官方网址是https://github.com/nwojke/deep_sort。但在这里,我们不使用官方的代码,而使用第三方代码,其网址为https://github.com/levan92/deep_sort_realtime。

下面我们就来应用DeepSORT。首先在虚拟环境内安装必要的软件包:

conda install python=3.8
pip install deep-sort-realtime

可以看出,DeepSORT算法只是需要几个常规的软件包:numpy、scipy和opencv-python,对用户十分友好。

使用DeepSORT也很方便,先导入DeepSORT:

from deep_sort_realtime.deepsort_tracker import DeepSort

实例化:

tracker = DeepSort()

DeepSort有一些输入参数,在这里只介绍几个常用的参数:

max_iou_distance:IoU的门控阈值,大于该值的关联会被忽略,默认值为0.7

max_age:当遗漏次数大于该值时轨迹会被删除,默认值为30

n_init:在初始阶段轨迹被保留的帧数,默认值为3

nms_max_overlap:非最大值抑制阈值,如果该值为1.0,表示不使用非最大值抑制,默认值为1.0

max_cosine_distance:余弦距离阈值,默认值为0.2

nn_budget:外观描述符的最大尺寸(int类型),如果为None,则不强制执行,默认值为None

实现目标跟踪:

tracks = tracker.update_tracks(bbs, frame=frame)

bbs为目标检测器的结果列表,每个结果是一个元组,形式为([left,top,w,h],置信值,类型),其中类型为字符串型

frame为帧图像

输出tracks为目标跟踪结果,使用for循环可以得到各个目标的跟踪信息:

for track in tracks:

下面介绍一些track的常用属性和方法:

track_id:目标ID

orginal_ltwh、det_conf、det_class:分别表示目标边框信息、置信值和类型,这三个值都是由tracker.update_tracks传入系统的原始目标的信息,但此时已匹配上了目标ID

to_ltrb()和to_ltwh():得到目标边框信息,两者的形式不同

is_confirmed():表示如果该目标ID被确认,则返回True

下面我们就给出DeepSORT实现目标跟踪的完整程序,在这里,我们仍然使用YOLOv8作为目标检测器:

import numpy as np
import cv2
from ultralytics import YOLO
from deep_sort_realtime.deepsort_tracker import DeepSortmodel = YOLO('yolov8l.pt')cap = cv2.VideoCapture("D:/track/Highway Traffic.mp4")
fps = cap.get(cv2.CAP_PROP_FPS)
size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
fNUMS = cap.get(cv2.CAP_PROP_FRAME_COUNT)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
videoWriter = cv2.VideoWriter("D:/track/mytrack.mp4", fourcc, fps, size)tracker = DeepSort(max_age=5)def box_label(image, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))cv2.rectangle(image, p1, p2, color, thickness=1, lineType=cv2.LINE_AA)if label:w, h = cv2.getTextSize(label, 0, fontScale=2 / 3, thickness=1)[0]  outside = p1[1] - h >= 3p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA)cv2.putText(image,label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),0, 2/3, txt_color, thickness=1, lineType=cv2.LINE_AA)while cap.isOpened():success, frame = cap.read()if success: results = model(frame,conf=0.4)outputs = results[0].boxes.data.cpu().numpy()detections = []if outputs is not None:for output in outputs:x1, y1, x2, y2 = list(map(int, output[:4]))if output[5] == 2:detections.append(([x1, y1, int(x2-x1), int(y2-y1)], output[4], 'car'))elif output[5] == 5:detections.append(([x1, y1, int(x2-x1), int(y2-y1)], output[4], 'bus'))elif output[5] == 7:detections.append(([x1, y1, int(x2-x1), int(y2-y1)], output[4], 'truck'))tracks = tracker.update_tracks(detections, frame=frame)for track in tracks:if not track.is_confirmed():continuetrack_id = track.track_idbbox = track.to_ltrb()box_label(frame, bbox, '#'+str(int(track_id))+ track.det_class , (167, 146, 11))cv2.putText(frame, "https://blog.csdn.net/zhaocj", (25, 50),cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)cv2.imshow("YOLOv8 Tracking", frame)videoWriter.write(frame)if cv2.waitKey(1) & 0xFF == ord("q"):breakelse:breakcap.release()
videoWriter.release()
cv2.destroyAllWindows()

http://www.lryc.cn/news/186346.html

相关文章:

  • Beyond Compare 4 30天评估到期 解决方法
  • 化妆品用乙基己基甘油全球市场总体规模2023-2029
  • springboot家政服务管理平台springboot29
  • 【网络安全】如何保护IP地址?
  • 2023年失业了,想学一门技术可以学什么?
  • MySQL-MVCC(Multi-Version Concurrency Control)
  • ArcGIS中的镶嵌数据集与接缝线
  • 网络安全工程师自主学习计划表(具体到阶段目标,保姆级安排,就怕你学不会!)
  • Linux 根据 PID 查看进程名称
  • Python二级 每周练习题21
  • 【算法训练-数组 三】【数组矩阵】螺旋矩阵、旋转图像、搜索二维矩阵
  • LED灯实验--汇编
  • Android多线程学习:线程池(一)
  • 网络安全(黑客技术)—小白自学笔记
  • 掌握核心技巧就能创建完美的目录!如何在Word中自动创建目录
  • 正则表达式中re.match、re.search、re.findall的用法和区别
  • 算法题:买卖股票的最佳时机含手续费(动态规划解法贪心解法-详解)
  • 【gcc】RtpTransportControllerSend学习笔记 4:码率分配
  • 「专题速递」AR协作、智能NPC、数字人的应用与未来
  • 什么是基于意图的网络(IBN)
  • 知识增强语言模型提示 零样本知识图谱问答10.8
  • 虚拟现实项目笔记:SDK、Assimp、DirectX Sample Browser、X86和X64
  • openwrt rm500u ncm方式拨号步骤记录
  • 使用js代码将一个值为“1=增量,2=全量“的字符串转化为一个数组,数据格式为[{value:““,label:“‘‘}]
  • 图片调色盘
  • 一文读懂Base64
  • CCF CSP认证 历年题目自练 Day20
  • 【Overload游戏引擎分析】从视图投影矩阵提取视锥体及overload对视锥体的封装
  • vue全局事件总线是什么?有什么用?解决了什么问题,与pinia有什么区别?
  • 【debian 12】:debian系统切换中文界面