当前位置: 首页 > news >正文

YOLOv8-OBB 角度预测原理解析

预测头

ultralytics/nn/modules/head.py

class OBB(Detect):
    """YOLOv8 OBB detection head: the standard Detect head plus a per-anchor rotation-angle branch."""

    def __init__(self, nc=80, ne=1, ch=()):
        """
        Build the OBB head.

        Args:
            nc (int): Number of classes.
            ne (int): Number of extra (angle) parameters predicted per anchor.
            ch (tuple): Input channel count of each feature level.
        """
        super().__init__(nc, ch)
        self.ne = ne  # number of extra parameters
        # Hidden width of the angle branch: a quarter of the first level's channels, but never below ne.
        hidden = max(ch[0] // 4, self.ne)
        branches = [
            nn.Sequential(Conv(c_in, hidden, 3), Conv(hidden, hidden, 3), nn.Conv2d(hidden, self.ne, 1))
            for c_in in ch
        ]
        self.cv4 = nn.ModuleList(branches)

    def forward(self, x):
        """
        Predict rotation angles and combine them with the inherited Detect outputs.

        Args:
            x (list[torch.Tensor]): One feature map per detection level.

        Returns:
            Training: ``(Detect output, angle)`` with ``angle`` of shape (batch, ne, total_anchors).
            Export: ``torch.cat([Detect output, angle], 1)`` permuted by ``permute(0, 2, 1)``.
            Otherwise: ``(torch.cat([out[0], angle], 1), (out[1], angle))`` where ``out`` is the Detect output.
        """
        batch = x[0].shape[0]
        # Flatten every level to (batch, ne, H*W) and join the levels along the anchor axis.
        logits = torch.cat([self.cv4[i](x[i]).view(batch, self.ne, -1) for i in range(self.nl)], 2)
        # sigmoid -> (0, 1); shifting by 0.25 and scaling by pi maps the angle into [-pi/4, 3pi/4].
        angle = (logits.sigmoid() - 0.25) * math.pi
        if not self.training:
            # Kept on the module so `decode_bboxes` can read it (presumably called from within
            # Detect.forward), hence it must be assigned before Detect.forward runs.
            self.angle = angle
        out = Detect.forward(self, x)
        if self.training:
            return out, angle
        if self.export:
            return torch.cat([out, angle], 1).permute(0, 2, 1)
        return torch.cat([out[0], angle], 1), (out[1], angle)

forward 输入值
（原文此处为 forward 输入值的截图，图片已缺失）

self.cv4 网络结构

ModuleList((0): Sequential((0): Conv((conv): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): Conv((conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(2): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1)))(1): Sequential((0): Conv((conv): Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): Conv((conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(2): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1)))(2): Sequential((0): Conv((conv): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(1): Conv((conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)(act): SiLU(inplace=True))(2): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1)))

angle 的维度为 (14, 1, 8400)，即 (batch_size, ne, 所有检测层的锚点总数)。经过 (sigmoid(·) − 0.25)·π 映射后，角度的取值范围为 (−π/4, 3π/4)。

损失函数

# Swap the last two axes: (b, ne, h*w) -> (b, h*w, ne), e.g. (14, 1, 8400) -> (14, 8400, 1),
# assuming pred_angle arrives in the layout produced by the OBB head. contiguous() materialises the permuted view.
pred_angle = pred_angle.permute(0, 2, 1).contiguous()
维度变为 (14, 8400, 1)，即 (batch_size, 锚点数, ne)

将预测的距离分布和角度解码为旋转框 bboxes

# Decode predicted distances plus angle into rotated boxes. These are later passed to probiou(), which
# expects xywhr boxes, so the layout is presumably (b, h*w, 5) rather than xyxy — confirm in bbox_decode.
pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # rotated boxes, presumably xywhr (b, h*w, 5)

计算回归损失

# RotatedBboxLoss.forward returns (loss_iou, loss_dfl); they fill loss slots 0 and 2 respectively.
loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask)

这里的bbox_loss指的是:

# Rotated-box regression loss (ProbIoU term + optional DFL term), moved to the training device.
# NOTE(review): reg_max - 1 is passed, so RotatedBboxLoss's own reg_max is one less than the caller's and
# its DFL reshape uses reg_max + 1 bins — confirm this matches the head's bin count.
self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)

接下来看 RotatedBboxLoss

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """
        Compute the rotated-box regression losses over foreground anchors.

        Args:
            pred_dist (torch.Tensor): Predicted distance-distribution logits; foreground rows are reshaped
                to (-1, reg_max + 1) for the DFL term.
            pred_bboxes (torch.Tensor): Decoded predicted rotated boxes (xywhr, as consumed by probiou).
            anchor_points (torch.Tensor): Anchor centres used to convert target boxes into distances.
            target_bboxes (torch.Tensor): Target rotated boxes; only ``[..., :4]`` (xywh) is used for DFL.
            target_scores (torch.Tensor): Per-anchor target scores, summed over the last dim to weight each anchor.
            target_scores_sum (torch.Tensor | float): Normaliser applied to both loss terms.
            fg_mask (torch.Tensor): Mask selecting the foreground anchors.

        Returns:
            (tuple[torch.Tensor, torch.Tensor]): ``(loss_iou, loss_dfl)``; ``loss_dfl`` is a scalar 0.0 tensor
                on ``pred_dist``'s device when DFL is disabled.
        """
        # Per-anchor weight: total target score of each foreground anchor, as a column vector.
        fg_weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)

        # IoU term: ProbIoU similarity between predicted and target rotated boxes, weighted and normalised.
        similarity = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
        loss_iou = ((1.0 - similarity) * fg_weight).sum() / target_scores_sum

        if not self.use_dfl:
            return loss_iou, torch.tensor(0.0, device=pred_dist.device)

        # DFL term: targets are built from the xywh part only (the angle is dropped by [..., :4]),
        # i.e. from the axis-aligned version of each target box.
        target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
        per_anchor_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * fg_weight
        return loss_iou, per_anchor_dfl.sum() / target_scores_sum

两个旋转矩形如何计算 IoU（probiou，即基于二维高斯分布的 ProbIoU）：

def probiou(obb1, obb2, CIoU=False, eps=1e-7):
    """
    Calculate the probabilistic IoU (ProbIoU) between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.

    Each box is modelled as a 2-D Gaussian. The Bhattacharyya distance ``bd`` between the paired Gaussians is
    turned into a Hellinger distance ``hd = sqrt(1 - exp(-bd))`` and the similarity is ``1 - hd``. Every term is
    symmetric in ``obb1``/``obb2``, so which argument holds predictions and which holds ground truth does not matter.

    Args:
        obb1 (torch.Tensor): A tensor of shape (N, 5) with oriented boxes in xywhr format.
        obb2 (torch.Tensor): A tensor of shape (N, 5) with oriented boxes in xywhr format, paired row-wise with obb1.
        CIoU (bool, optional): If True, also subtract the CIoU aspect-ratio (w/h) penalty. Defaults to False.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

    Returns:
        (torch.Tensor): Similarities of shape (N, 1). The trailing singleton dim comes from ``split(1, dim=-1)``;
            this assumes ``_get_covariance_matrix`` likewise returns (N, 1) terms.
    """
    x1, y1 = obb1[..., :2].split(1, dim=-1)
    x2, y2 = obb2[..., :2].split(1, dim=-1)
    a1, b1, c1 = _get_covariance_matrix(obb1)
    a2, b2, c2 = _get_covariance_matrix(obb2)

    # Entries of the summed covariance Sigma1 + Sigma2 = [[A, C], [C, B]] and its determinant,
    # shared by all three Bhattacharyya terms below (computed once instead of three times).
    a_sum, b_sum, c_sum = a1 + a2, b1 + b2, c1 + c2
    det_sum = a_sum * b_sum - c_sum.pow(2)

    # Mahalanobis-style centre-offset terms of the Bhattacharyya distance.
    t1 = ((a_sum * (y1 - y2).pow(2) + b_sum * (x1 - x2).pow(2)) / (det_sum + eps)) * 0.25
    t2 = ((c_sum * (x2 - x1) * (y1 - y2)) / (det_sum + eps)) * 0.5
    # Covariance term; clamp_(0) keeps rounding-induced negative determinants out of the sqrt.
    det_prod = (a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)
    t3 = (det_sum / (4 * det_prod.sqrt() + eps) + eps).log() * 0.5

    bd = (t1 + t2 + t3).clamp(eps, 100.0)
    hd = (1.0 - (-bd).exp() + eps).sqrt()
    iou = 1 - hd
    if CIoU:  # only include the wh aspect ratio part
        w1, h1 = obb1[..., 2:4].split(1, dim=-1)
        w2, h2 = obb2[..., 2:4].split(1, dim=-1)
        v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
        with torch.no_grad():  # alpha is a trade-off weight, not a gradient path
            alpha = v / (v - iou + (1 + eps))
        return iou - v * alpha  # CIoU
    return iou
http://www.lryc.cn/news/390184.html

相关文章:

  • CICD之Git版本管理及基本应用
  • Python作用域及其应用
  • 谷歌上架,应用被Google play下架之后,活跃用户会暴跌?这是为什么?
  • web安全渗透测试十大常规项(一):web渗透测试之Fastjson反序列化
  • Unity 3D软件下载安装;Unity 3D游戏制作软件资源包获取!
  • PyTorch之nn.Module与nn.functional用法区别
  • 2024.06.24 校招 实习 内推 面经
  • 【C++】using namespace std 到底什么意思
  • 基于ESP32 IDF的WebServer实现以及OTA固件升级实现记录(三)
  • 116-基于5VLX110T FPGA FMC接口功能验证6U CPCI平台
  • Android - Json/Gson
  • 盲信号处理的发展现状
  • 二轴机器人装箱机:重塑物流效率,精准灵活,引领未来装箱新潮流
  • 使用python做飞机大战
  • Python面向对象编程:派生
  • 华为仓颉编程语言
  • 【微信小程序开发实战项目】——如何制作一个属于自己的花店微信小程序(2)
  • 解锁数据资产的无限潜能:深入探索创新的数据分析技术,挖掘其在实际应用场景中的广阔价值,助力企业发掘数据背后的深层信息,实现业务的持续增长与创新
  • Bridging nonnull in Objective-C to Swift: Is It Safe?
  • 算法训练 | 图论Part1 | 98.所有可达路径
  • 【JVM基础篇】垃圾回收
  • Spark join数据倾斜调优
  • YOLOv5初学者问题——用自己的模型预测图片不画框
  • 【linux学习---1】点亮一个LED---驱动一个GPIO
  • Redis分布式锁代码实现详解
  • Day01-02-gitlab
  • PyCharm远程开发配置(2024以下版本)
  • 解决Ucharts在小程序上的层级过高问题
  • 重保期间的网站安全防护:网站整站锁的应用与实践
  • Qt自定义类型