当前位置: 首页 > news >正文

【YOLO系列】YOLOv5 NMS源码理解、更换为DIoU-NMS

代码来源:GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite

使用的代码是YOLOv5 6.1版本

参考笔记:YOLOv5改进系列(八) 更换NMS非极大抑制DIoU-NMS、CIoU-NMS、EIoU-NMS、GIoU-NMS 、SIoU-NMS、Soft-NMS_diou nms-CSDN博客

yolov5 极大值抑制 nms 代码详解 - 金色旭光 - 博客园

https://zhuanlan.zhihu.com/p/511151467


目录

1.NMS源码理解

2.更换DIou-NMS


1.NMS源码理解

YOLOv5NMS的实现代码在utils/general.pynon_max_suppression

#对推理结果执行NMS
def non_max_suppression(prediction,#模型的预测结果,shape=[batch_size,预测框数量,5+类别数量=中心x+中心y+w+h+conf+类别数量]conf_thres=0.25,#置信度阈值,用于NMS,置信度低于此阈值的预测框会被去除iou_thres=0.45,#IoU阈值,用于NMS,去除冗余的预测框classes=None,#只对某些类别作NMS,None则表示所有类别都作NMSagnostic=False,#是否作类别无关的NMS,即所有预测框不分类别一起作NMS处理,通常不开启,都是各类别各自作NMSmulti_label=False,labels=(),max_det=300#每张图片作NMS之后剩余的最多预测框数):'''函数返回值:返回值output是一个列表,存放每张图片的检测结果eg:output[0]即第一张图片的检测结果,outout[0] shape=[预测框数量,6=xyxy+conf+cls]'''#类别数量ncnc = prediction.shape[2] - 5#符合置信度阈值的预测框bool数组,xc shape=[batch_size,预测框数量]xc = prediction[..., 4] > conf_thres#检查置信度、IoU阈值的有效性assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'#设置参数min_wh, max_wh = 2, 4096  #框的最小和最大宽高(像素)max_nms = 30000  #每张图片作NMS之前的最多预测框数time_limit = 10.0  #处理图片超过此时间则退出multi_label &= nc > 1  #没啥用t = time.time()  #记录开始时间output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]  #初始化返回值output#遍历每张图像的预测结果for xi, x in enumerate(prediction):'''xi:当前图片在batch中的idx:存放当前图片的预测框信息,shape=[预测框数量,5+类别数量]'''#仅保留大于置信度阈值的预测框,x shape=[预测框数量,5+类别数量]x = x[xc[xi]]#如果存在真实标签,则将其合并到预测结果中(这段代码不知道有什么用)if labels and len(labels[xi]):l = labels[xi]  #真实标签v = torch.zeros((len(l), nc + 5), device=x.device)  # 初始化与真实标签相同形状的张量v[:, :4] = l[:, 1:5]  # 提取真实框的坐标v[:, 4] = 1.0  # 置信度设为1.0v[range(len(l)), l[:, 0].long() + 5] = 1.0  # 设置类别x = torch.cat((x, v), 0)  # 合并预测框和真实框#如果预测框数量为0,则处理下一张图片if not x.shape[0]:continue#重置类别概率=conf置信度*原始类别概率x[:, 5:] *= x[:, 4:5]#将坐标值从(中心x, 中心y, w, h)转换为(x1, y1, x2, y2),box shape=[预测框数量,4=xyxy]box = xywh2xyxy(x[:, :4])#通常multi_label为False,执行else部分if multi_label:i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T  # 确定哪些框符合多标签条件x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)  # 合并框信息else:#将最大类别概率作为检测框的置信度存放于conf中,并将类别索引存放于j中conf, j = x[:, 5:].max(1, keepdim=True)x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]#合并xyxy+置信度+类别索引'''conf: shape=[预测框数,1=置信度]j: shape=[预测框数,1=类别索引]x: shape=[预测框数,6=xyxy+置信度+类别索引]'''#利用class进行过滤,筛选出指定的class,nms仅仅对指定的class进行nms;#若classes为None,则所有类别都需要作nmsif classes is not None:x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]#仅保留指定类别的预测框#预测框数量nn = x.shape[0]#如果没有预测框,则处理下一张图片if not n:continueelif n > max_nms: #如果作NMS之前预测框的数量大于max_nms,则按置信度排序并保留前max_nms个框x = x[x[:, 4].argsort(descending=True)[:max_nms]]#Batches NMS#这行代码是在多类别中应用NMS#多类别NMS的处理策略是为了让每个类都能独立执行NMS,所以给所有预测框的坐标值添加一个偏移量#偏移量仅取决于了类别的Id(也就是x[:, 5:6]),并且足够大,使得不同类的预测框不会重叠c = x[:, 5:6] * (0 if agnostic else max_wh)#创建类别偏移c,即c=原类别索引*max_wh#给每个预测框的坐标值加上类别偏移c,boxes shape=[预测框数量,4]boxes = x[:, :4] + c#获取所有预测框的置信度,scores shape=[预测框数量,]scores = x[:, 4]#执行NMS,i存放NMS之后的预测框id,shape=[NMS后的预测框数,]i = torchvision.ops.nms(boxes, scores, iou_thres)#每张图片NMS之后最多剩余max_det个预测框if i.shape[0] > max_det:i = i[:max_det]#将该图片的检测结果存储到输出output中output[xi] = x[i]#如果处理此图片超出时间限制if (time.time() - t) > time_limit:#提示超时print(f'WARNING: NMS time limit {time_limit}s exceeded')break  #超时退出#返回值output是一个列表,存放每张图片的检测结果#eg:output[0]即第一张图片的检测结果,outout[0] shape=[预测框数量,6=xyxy+conf+cls]return output #返回每张图片的检测结果

真正作NMS过滤的代码是如下几行代码:

#Batches NMS
#这行代码是在多类别中应用NMS
#多类别NMS的处理策略是为了让每个类都能独立执行NMS,所以给所有预测框的坐标值添加一个偏移量
#偏移量仅取决于了类别的Id(也就是x[:, 5:6]),并且足够大,使得不同类的预测框不会重叠
c = x[:, 5:6] * (0 if agnostic else max_wh)#创建类别偏移c,即c=类别索引*max_whboxes = x[:, :4] + c#给每个预测框的坐标值加上类别偏移c,boxes shape=[预测框数量,4]
scores = x[:, 4]#获取所有预测框的置信度,scores shape=[预测框数量,]#执行NMS,i存放NMS之后的预测框id,shape=[NMS后的预测框数,]
i = torchvision.ops.nms(boxes, scores, iou_thres)

代码重点是在 '+c’这里,c是偏移量

(1)agnostic参数为True,表示所有类别一起作NMS处理,偏移量c0

(2)agnostic参数为False,表示按照不同类别分别作NMS处理,c=类别索引*max_wh,对不同类别的预测框做一个偏移操作,防止不同类别的预测框互相影响

注意:源码中是传入参数boxes、scores、iou_thres调用torchvision.ops.nms实现NMS处理,下面是NMS的代码实现。看了下面的NMS代码可以发现上面说agnostic为False时表示按照不同类别分别作NMS处理,但源码这里应该不是特别严格按不同类别作NMS(因为连类别的索引都没有用到),添加偏移量c只是算是一种trick把(我个人的理解,如有错误请指出)

代码流程: 

  1. 将所有预测框按置信度从高到低排序,确保置信度高的预测框排在前面。order存放排序后的预测框索引
  2. 从置信度最高的框开始(即order[0]),计算它和剩下所有预测框的IoU。剩下的预测框中IoU低于设定的IoU阈值则保留下来,高于IoU阈值的预测框则去除(即在order中删除当前预测框和IoU大于阈值的预测框索引)
  3. 重复步骤2,直到遍历完order中的预测框,得到最终筛选出来的预测框
import torch
def NMS(boxes,scores, iou_thres):'''boxes:shape=[预测框数量,4=xyxy],存放预测框坐标值scores:shape=[预测框数量,],存放预测框的置信度iou_thres: IoU阈值'''x1 = boxes[:,0]y1 = boxes[:,1]x2 = boxes[:,2]y2 = boxes[:,3]#计算所有预测框的面积areas = (x2-x1)*(y2-y1)#将预测框按置信度从高到低排序,order存放预测框的索引值_,order = scores.sort(0,descending=True)#keep保存NMS之后剩余的预测框索引keep = []while order.numel() > 0:#循环条件'''注意:当order=tensor([2,0,1,3])时,用order[0]可以正常取出第1个值2当order=tensor([3])时,用order[0]取出第1个值3会报错,需要用order.item()取出'''i = order[0] if order.numel()>1 else order.item()#取出置信度最大的预测框索引keep.append(i)#将预测框索引加入keep中#如果只剩余1个预测框,则NMS执行结束if order.numel() == 1:break#计算当前预测框与剩下所有预测框的IoUxx1 = x1[order[1:]].clamp(min=x1[i])yy1 = y1[order[1:]].clamp(min=y1[i])xx2 = x2[order[1:]].clamp(max=x2[i])yy2 = y2[order[1:]].clamp(max=y2[i])w = (xx2-xx1).clamp(min=0)h = (yy2-yy1).clamp(min=0)inter = w*hovr = inter/(areas[i] + areas[order[1:]] - inter)#当前预测框与剩下所有预测框的IoU值#筛选出IOU小于阈值的预测框索引, 过滤掉所有IOU大于阈值的预测框ids = (ovr<=iou_thres).nonzero().squeeze()#重置order数组,丢弃和当前bbox的IOU大于阈值的预测框order = order[ids+1]#这里看代码会有点懵,可以debug一下#torch.LongTensor(keep)将keep列表转换为tensor,shape:[NMS后预测框数量,]return torch.LongTensor(keep)#实例
box = torch.tensor([[2, 3.1, 7, 5], [3, 4, 8, 4.8], [4, 4, 5.6, 7], [0.1, 0, 8, 1]])
score = torch.tensor([0.5, 0.3, 0.2, 0.4])
output =NMS(boxes=box, scores=score, iou_thres=0.3)
print(output)

2.更换DIou-NMS

YOLOv5源码中使用的是IoU-NMS,这里可以作一下改进,将其替换为DIoU-NMS,因为DIoU考虑到的要素比IoU更多,应用于NMS中,可以使得NMS后得到的结果更加合理


第1步:编写DIoU_NMS函数

def DIoU_NMS(boxes,scores, iou_thres):'''boxes:shape=[预测框数量,4=xyxy],存放预测框坐标值scores:shape=[预测框数量,],存放预测框的置信度iou_thres: DIoU阈值'''#将预测框按置信度从高到低排序,order存放预测框的索引值_,order = scores.sort(0,descending=True)#keep保存NMS之后剩余的预测框索引keep = []while order.numel() > 0:#循环条件'''注意:当order=tensor([2,0,1,3])时,用order[0]可以正常取出第1个值2当order=tensor([3])时,用order[0]取出第1个值3会报错,需要用order.item()取出'''i = order[0] if order.numel()>1 else order.item()#取出置信度最大的预测框索引keep.append(i)#将预测框索引加入keep中#如果只剩余1个预测框,则NMS执行结束if order.numel() == 1:break#计算当前预测框与剩下所有预测框的DIoU#boxes[i,:]为当前预测框的坐标值,shape=[4,]#boxes[order[1:],:]为其他预测框的坐标值,shape=[n,4]ovr = bbox_iou(boxes[i, :], boxes[order[1:], :], DIoU=True)#筛选出DIoU小于阈值的预测框索引, 过滤掉所有DIoU大于阈值的预测框ids = (ovr<=iou_thres).nonzero().squeeze()#重置order数组,丢弃和当前bbox的DIoU大于阈值的预测框order = order[ids+1]#这里看代码会有点懵,可以debug一下#torch.LongTensor(keep)将keep列表转换为tensor,shape:[NMS后预测框数量,]return torch.LongTensor(keep)

这里计算DIoU的函数bbox_iou是直接引用了YOLOv5中的代码,该函数的实现在utils/metrics.py中,此函数集成了IoU、GIoU、DIoU、CIoU的计算,其他XIoU_NMS的实现方法类似。PS:GIoU、DIoU、CIoU用于损失函数的情况比较多

最后将DIoU_NMS函数复制到utils/general.py


第2步:将IoU-NMS更换为DIoU-NMS

utils/general.pynon_max_suppression函数的

i = torchvision.ops.nms(boxes, scores, iou_thres)

替换为

i = DIoU_NMS(boxes, scores, iou_thres)

这样就将IoU-NMS更换为DIoU-NMS了,但是我用几张图片作测试,发现大多数时候使用IoU-NMS和DIoU-NMS的处理结果是完全一致的。如下:

处理结果

所以这种改进可能实际意义不大

更换其他XIoU-NMS的方法是一样的,这里不再赘述

http://www.lryc.cn/news/536137.html

相关文章:

  • Android RenderEffect对Bitmap高斯模糊(毛玻璃),Kotlin(1)
  • 【linux学习指南】线程同步与互斥
  • JavaScript函数与方法详解
  • 【论文笔记】ZeroGS:扩展Spann3R+GS+pose估计
  • AtCoder - arc058_d Iroha Loves Strings解答与注意事项
  • 企业使用统一终端管理(UEM)工具提高端点安全性
  • Leetcode 算法题 9 回文数
  • 设计模式Python版 命令模式(上)
  • C语言之循环结构:直到型循环
  • 细说STM32F407单片机RTC的备份寄存器原理及使用方法
  • MATLAB计算反映热需求和能源消耗的度数日指标(HDD+CDD)(全代码)
  • J6 X8B/X3C切换HDR各帧图像
  • 09-轮转数组
  • 用vue3写一个好看的wiki前端页面
  • 瑞芯微烧写工具
  • 说下JVM中一次完整的GC流程?
  • Open FPV VTX开源之OSD使用分类
  • 智慧农业-虫害及生长预测
  • Python 识别图片和扫描PDF中的文字
  • C语言如何知道当前系统中的编译器数据类型的大小是多少?
  • gitlab Webhook 配置jenkins时“触发远程构建 (例如,使用脚本)”报错
  • Mysql中使用sql语句生成雪花算法Id
  • /etc/profile vs ~/.bashrc:如何正确使用?
  • SpringBoot实战:高效获取视频资源
  • Flutter_学习记录_数据更新的学习
  • c++ 多线程知识汇总
  • day09_实时类标签/指标
  • 【前端开发学习笔记16】Vue_9
  • Bash 中的运算方式
  • 2025年3月营销灵感日历