当前位置: 首页 > news >正文

bytetrack漏检补齐

bytetrack漏检补齐

1.人流慢速运动,跟踪效果比较好,偶尔有漏检,跟踪可以自动补齐。

2.目标快速运动、频繁遮挡时,跟踪效果可能不好。

*如果当前帧出现漏检,可以倒着跟踪:把上一帧中存在、当前帧丢失的检测框拷贝出来,补入当前帧的检测结果后再继续跟踪。

不过这种补齐方式有时候效果不是很好。

from collections import defaultdict
import cv2
import numpy as np
import torchvision
from ultralytics import YOLO
import pickle
import torch
from torchvision.ops import box_iou
from log import logger
import time
import os
from addict import Dict
from track.byte_tracker import BYTETracker
import math


def get_color(idx):
    """Return a deterministic colour triple for a track id.

    The id is scaled by 5 and multiplied by three different constants
    modulo 255, so the same id always maps to the same colour.

    Args:
        idx: Integer track id.

    Returns:
        A 3-tuple of ints, each in the range [0, 254].
    """
    idx = idx * 5
    return ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)


class YOLO_Class():
    """YOLO person detector combined with a BYTETracker instance."""

    def __init__(self, model_path, device="cuda:0"):
        """Load the detector weights and create the tracker.

        Args:
            model_path: Path to the YOLO weights file.
            device: Accepted for interface compatibility; this constructor
                does not currently use it.
        """
        self.model = YOLO(model_path)  # YOLO-12 detection + tracking
        self.par_args = Dict({
            "track_thresh": 0.5,
            "track_buffer": 30,
            "match_thresh": 0.9,
            "min_box_area": 10,
            "mot20": False,
        })
        self.tracker = BYTETracker(self.par_args, frame_rate=20)

    def yolo_byte_track(self, detect_bboxes, frame):
        """Feed detections to the tracker and draw the tracks on ``frame``.

        When ``detect_bboxes`` is empty the tracker is not updated and the
        frame is returned untouched.

        Args:
            detect_bboxes: Sequence/array of detection rows as produced by
                ``detect_image_yolo``.
            frame: Image to annotate; it is drawn on in place.

        Returns:
            The same ``frame`` object, annotated.
        """
        title_color = (0, 255, 255)
        if len(detect_bboxes) == 0:
            return frame

        # Crowded frames get a longer buffer and a looser match threshold.
        # NOTE(review): 1.6 is the original value; confirm how the project's
        # BYTETracker interprets a match threshold above 1.
        if len(detect_bboxes) > 4:
            self.par_args.track_buffer = 60
            self.par_args.match_thresh = 1.6
        else:
            self.par_args.track_buffer = 30
            self.par_args.match_thresh = 0.9

        online_targets = self.tracker.update(
            np.array(detect_bboxes),
            [frame.shape[0], frame.shape[1]],
            (frame.shape[0], frame.shape[1]),
            self.par_args,
        )

        for t in online_targets:
            x1, y1, box_w, box_h = t.tlwh
            if not (box_w > 0 and box_h > 0):
                continue
            track_id = t.track_id
            box_color = get_color(track_id)
            intbox = tuple(map(int, (x1, y1, x1 + box_w, y1 + box_h)))
            cv2.rectangle(frame, intbox[0:2], intbox[2:4], color=box_color, thickness=2)
            cv2.putText(frame, f'{track_id} {t.score:.2f} ', (intbox[0], intbox[1] - 5),
                        cv2.FONT_HERSHEY_PLAIN, 1.8, title_color, thickness=2)
        return frame

    def _backfill_lost_boxes(self, results, last_box, frame_hw, frame_id):
        """Append boxes that were in the previous frame but are missing now.

        A throw-away tracker is updated first with the current detections and
        then with the previous frame's detections; every track from the second
        update whose id was not returned by the first one is appended to
        ``results`` as an extra detection row with class 0.

        Args:
            results: Current detections, one row per box.
            last_box: Detections used for the previous frame.
            frame_hw: ``(height, width)`` of the frame.
            frame_id: Frame index, used only for the log line.

        Returns:
            ``results`` with the recovered rows stacked underneath.
        """
        pic_h, pic_w = frame_hw
        tracker2 = BYTETracker(self.par_args, frame_rate=3)
        track_now = tracker2.update(results, (pic_h, pic_w), (pic_h, pic_w), self.par_args)
        track_last = tracker2.update(last_box, (pic_h, pic_w), (pic_h, pic_w), self.par_args)
        now_ids = {t.track_id for t in track_now}
        for t in track_last:
            if t.track_id in now_ids:
                continue
            x1, y1, box_w, box_h = t.tlwh
            print('add box', frame_id, x1, y1, box_w, box_h)
            box_lost = np.asarray([x1, y1, x1 + box_w, y1 + box_h, t.score, 0])
            results = np.vstack([results, box_lost])
        return results

    def get_bytetrack_bbox(self, video_path, video_id, output_path="", debug: bool = False,
                           backfill_lost: bool = False):
        """Detect and track people in a video, preview it and save the result.

        Every frame is run through ``detect_image_yolo`` and
        ``yolo_byte_track``. The annotated frame is shown in a window (press
        any key for the next frame, Esc to stop) and written to
        ``output_path``.

        Args:
            video_path: Input video file.
            video_id: Identifier used to name the debug directory.
            output_path: Destination video file (mp4v codec).
            debug: When true, ``yolov12/debug/<video_id>`` is created.
            backfill_lost: Experimental and off by default. When true and the
                current frame has fewer detections than the previous one,
                missing boxes are copied over via ``_backfill_lost_boxes``.

        Raises:
            RuntimeError: If the input video or the output writer cannot be
                opened.
        """
        # Only create the directory when debugging; os.makedirs(None) raises.
        debug_dir = f"yolov12/debug/{video_id}" if debug else None
        if debug_dir is not None:
            os.makedirs(debug_dir, exist_ok=True)

        cap = cv2.VideoCapture(video_path)
        out = None
        try:
            if not cap.isOpened():
                raise RuntimeError(f"无法打开视频: {video_path}")

            fps = cap.get(cv2.CAP_PROP_FPS) or 30  # some files report no FPS
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            logger.info(f"视频总帧数: {total_frames}, fps: {fps}, 宽: {w}, 高: {h}")

            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
            if not out.isOpened():
                raise RuntimeError("VideoWriter 初始化失败,请检查编码器 fourcc 或路径。")

            frame_id = 0
            last_box = []
            while cap.isOpened():
                ok, frame = cap.read()
                if not ok:
                    break

                t0 = time.time()
                results = detect_image_yolo(self.model, frame)
                pic_h, pic_w = frame.shape[:2]

                pad_count = len(last_box) - len(results)
                if backfill_lost and pad_count > 0:
                    results = self._backfill_lost_boxes(results, last_box, (pic_h, pic_w), frame_id)
                last_box = results

                t1 = time.time()
                frame = self.yolo_byte_track(results, frame)
                print(f"{frame_id} det_track time {time.time() - t0:.3f}s track_time {time.time() - t1:.3f}s")

                # Downscale only the preview copy: the writer was opened at
                # (w, h) and must receive frames of exactly that size.
                preview = frame
                pixel_count = np.prod(frame.shape[:2])
                if pixel_count > 1000 * 1300:
                    x_scale = np.sqrt(1000 * 1200 / pixel_count)
                    preview = cv2.resize(frame, None, fx=x_scale, fy=x_scale,
                                         interpolation=cv2.INTER_AREA)
                cv2.imshow("YOLO Track", preview)
                # waitKey(0) blocks until a key is pressed, i.e. frame stepping.
                if cv2.waitKey(0) & 0xFF == 27:  # Esc to quit
                    break

                out.write(frame)
                frame_id += 1
        finally:
            # Always release handles so the output file is finalised.
            cap.release()
            if out is not None:
                out.release()
            cv2.destroyAllWindows()


def detect_image_yolo(yolo_model, image, imgsz=640, conf=0.4, min_area=60*40, max_len=0):
    """Run the YOLO model on one image and return class-0 (person) boxes.

    Class-0 detections are passed through NMS (IoU 0.5) and boxes whose area
    is below ``min_area`` are dropped.

    Args:
        yolo_model: Callable model invoked as
            ``yolo_model(image, verbose=False, imgsz=imgsz, conf=conf)``.
        image: Input image passed straight to the model.
        imgsz: Inference size forwarded to the model.
        conf: Confidence threshold forwarded to the model.
        min_area: Minimum box area (width * height) to keep.
        max_len: Unused; kept for interface compatibility.

    Returns:
        An ``[N, 6]`` numpy array with rows ``x1, y1, x2, y2, score, label``,
        or an empty ``(0, 6)`` array when nothing is kept.
    """
    with torch.no_grad():
        results = yolo_model(image, verbose=False, imgsz=imgsz, conf=conf)

    det = results[0].boxes
    cls = det.cls.int().cpu()
    indices = torch.where(cls == 0)[0]  # keep only the person class
    if len(indices) == 0:
        return np.empty((0, 6))  # empty, but with the expected column count

    labels = det.cls[indices]
    boxes = det.xyxy[indices].float()
    scores = det.conf[indices]

    keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)
    boxes = boxes[keep_indices]
    scores = scores[keep_indices]
    labels = labels[keep_indices]

    # Area filter
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    area_mask = areas >= min_area
    boxes = boxes[area_mask]
    scores = scores[area_mask]
    labels = labels[area_mask]
    if len(boxes) == 0:
        return np.empty((0, 6))

    # Convert to numpy and concatenate into the layout passed to the tracker.
    boxes = boxes.cpu().numpy()
    scores = scores.cpu().numpy()
    labels = labels.cpu().numpy()
    dets = np.concatenate([boxes, scores[:, None], labels[:, None]], axis=1)  # [N, 6]
    return dets


if __name__ == "__main__":
    mp4_path = r"E:\data\tiaosheng\0706\5s.mp4"
    video_id = os.path.basename(mp4_path).split(".")[0]  # video id taken from the file name
    yolo_path = r"F:\BaiduNetdiskDownload\tiaosheng_new\model\best_new.pt"
    yolo_cls = YOLO_Class(yolo_path)
    yolo_cls.get_bytetrack_bbox(mp4_path, video_id, output_path=f"{video_id}_tracked.mp4", debug=True)

http://www.lryc.cn/news/589571.html

相关文章:

  • 2025年夏Datawhale AI夏令营机器学习
  • 数据怎么分层?从ODS、DW、ADS三大层一一拆解!
  • Flink Watermark原理与实战
  • omniparser v2 本地部署及制作docker镜像(20250715)
  • 驱动开发系列61- Vulkan 驱动实现-SPIRV到HW指令的实现过程(2)
  • 定时器更新中断与串口中断
  • Claude 背后金主亚马逊亲自下场,重磅发布 AI 编程工具 Kiro 现已开启免费试用
  • CUDA 环境下 `libcuda.so` 缺失问题解决方案
  • 2-Nodejs运行JS代码
  • 基于按键开源MultiButton框架深入理解代码框架(二)(指针的深入理解与应用)
  • css-css执行的三种方式和css选择器
  • 【leetcode】263.丑数
  • 邮件伪造漏洞
  • 再见吧,Windows自带记事本,这个轻量级文本编辑器太香了
  • Rust基础[part4]_基本类型,所有权
  • Java 集合 示例
  • 【Qt】插件机制详解:从原理到实战
  • redisson tryLock
  • HAProxy双机热备,轻松实现负载均衡
  • [Python] -实用技巧6-Python中with语句和上下文管理器解析
  • Hessian矩阵在多元泰勒展开中如何用于构造优化详解
  • 记一次POST请求中URL中文参数乱码问题的解决方案
  • LeetCode 1888. 使二进制字符串字符交替的最少反转次数
  • 整除分块练习题
  • 使用Spring Cloud LoadBalancer报错java.lang.IllegalStateException
  • AI助手指南:从零开始打造Python学习环境(VSCode + Lingma/Copilot + Anaconda + 效率工具包)
  • 学习秒杀系统-实现秒杀功能(商品列表,商品详情,基本秒杀功能实现,订单详情)
  • Sharding-JDBC 分布式事务实战指南:XA/Seata 方案解析(三)
  • 2HDMI/1DP转EDP/LVDS,支持4K,144HZ和240HZ.
  • LSA链路状态通告