当前位置: 首页 > news >正文

paddle.vision 与 torchvision 中的box NMS使用方式

torchvision 中有多个用于计算 BBox NMS 的 API, 在本篇氵文中, 使用

torchvision.ops.boxes.batched_nms

paddle.vision 中通过 paddle.vision.ops.nms 来进行多个 Box 的 NMS 操作

1. torchvision 中 batched_nms 操作

torchvision batched_nms

def batched_nms(boxes: torch.Tensor,scores: torch.Tensor,idxs: torch.Tensor,iou_threshold: float,
) -> torch.Tensor

传入的参数分别为

  • 边界框boxes, 格式[x1, y1, x2, y2],shape 为 [num, 4],dtype 为 float
  • 置信度scores, shape 为 [num,],dtype 为 float
  • 类别idxs, shape 为 [num,],dtype 为 int

来举个例子:

import numpy as np
import torch
from torchvision.ops import boxes as box_opsseed = 1107
iou_threshold = 0.35
box_num = 100000
cls_num = 80np.random.seed(seed)boxes = np.random.rand(box_num, 4).astype("float32")
boxes = torch.from_numpy(boxes)scores = np.random.rand(box_num).astype("float32")
scores = torch.from_numpy(scores)idxs = np.random.randint(0, cls_num, size=(box_num,))
idxs = torch.from_numpy(idxs)assert boxes.shape[-1] == 4keep = box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold)

2. paddle.vision.ops.nms 操作

paddle.vision.ops.nms(boxes, iou_threshold=0.3, scores=None, category_idxs=None, categories=None, top_k=None)

boxesiou_thresholdscorescategory_idxs 等参数和上述 torchvision 中 batched_nms 参数一样
不同的是 paddle 中还需要 categories 参数,(其实没什么必要)

category_idxs 是每个 bbox 的类别,而 categories 是一共的类别

比如 COCO 一共80类,则:

categories = paddle.arange(80)

Paddle 中的例子:

import numpy as np
import paddleseed = 1107
iou_threshold = 0.35
box_num = 100000
cls_num = 80np.random.seed(seed)boxes = np.random.rand(box_num, 4).astype("float32")
boxes = paddle.to_tensor(boxes)scores = np.random.rand(box_num).astype("float32")
scores = paddle.to_tensor(scores)idxs = np.random.randint(0, cls_num, size=(box_num,))
idxs = paddle.to_tensor(idxs)cls_list = paddle.arange(0, cls_num)assert boxes.shape[-1] == 4keep = paddle.vision.ops.nms(boxes, iou_threshold, scores, idxs, cls_list)
http://www.lryc.cn/news/20591.html

相关文章:

  • php mysql校园帮忙领取快递平台
  • C/C++开发,无可避免的内存管理(篇二)-约束好跳脱的内存
  • 【Java】让我们对多态有深入的了解(九)
  • 12 个适合做外包项目的开源后台管理系统
  • 鼠标更换指针图案和更改typora的主题
  • 【洛谷 P1563】[NOIP2016 提高组] 玩具谜题(模拟+结构体数组+指针)
  • 阿里测试经验7年,从功能测试到自动化测试,我整理的超全学习指南
  • Educational Codeforces Round 143 (Rated for Div. 2)
  • 业务代码编写过程中如何「优雅的」配置隔离
  • English Learning - L2-2 英音地道语音语调 2023.02.23 周四
  • java:线程等待与唤醒 - Object的wait()和notify()
  • 实现弹窗功能并修改其中一个系数
  • vue-draggable浏览器拖拽event事件对象拖动时 DragEvent path undefined
  • 【云原生】搭建k8s高可用集群—20230225
  • LeetCode121_121. 买卖股票的最佳时机
  • 收割不易,五面Alibaba终拿Java岗offer
  • 【离线数仓-4-数据仓库设计-分层规划构建流程】
  • SQL零基础入门学习(十一)
  • 排序基础之插入排序
  • LabVIEW控制DO通道输出一个精确定时的数字波形
  • openpnp - 零碎记录
  • Qt编写微信支付宝支付
  • LeetCode 剑指 Offer 64. 求1+2+…+n
  • Mapper代理开发
  • 为什么在连接mysql时,设置 SetConnMaxIdleTime 没有作用
  • 嵌入式开发利器
  • Qt 的QString类的使用
  • django项目部署(腾讯云服务器centos)
  • 计算机网络笔记、面试八股(一)——TCP/IP网络模型
  • 51单片机入门 - 简短的位运算实现扫描矩阵键盘