当前位置: 首页 > news >正文

目标检测 YOLOv5 - 推理时的数据增强

目标检测 YOLOv5 - 推理时的数据增强

flyfish

版本 YOLOv5 6.2

参考地址

https://github.com/ultralytics/yolov5/issues/303

在训练时可以使用数据增强,在推理阶段也可以使用数据增强
在测试使用数据增强有个名字叫做Test-Time Augmentation (TTA)

实际使用中使用了大中小三个不同分辨率,中间大小分辨率的图像进行了左右反转
大分辨率
480 * 640 宽度W 高度H 比例为1
在这里插入图片描述
中分辨率
416 * 544 宽度W 高度H 比例为0.83

在这里插入图片描述
小分辨率
352 * 448 宽度W 高度H 比例为0.67

在这里插入图片描述

命令

python detect.py --weights ./yolov5s.pt --source ./data/images/bus.jpg  --imgsz 640 --augment

--augment语法
推理时默认不使用增强

import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-v", "--verbose", help="increase output verbosity",action="store_true")
args = parser.parse_args()
if args.verbose:print("verbosity turned on")
else:print("verbosity turned off")

假如上段代码是test.py

# python test.py
# 输出     verbosity turned off# python test.py -v
# 输出 verbosity turned on

验证图像大小是每个维度上的stride的倍数,默认是32的倍数
例如 图像大小是1111 那么就是
--img-size [1111, 1111] 更新为 [1120, 1120]

def check_img_size(imgsz, s=32, floor=0):# Verify image size is a multiple of stride s in each dimensionif isinstance(imgsz, int):  # integer i.e. img_size=640new_size = max(make_divisible(imgsz, int(s)), floor)else:  # list i.e. img_size=[640, 480]imgsz = list(imgsz)  # convert to list if tuplenew_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]if new_size != imgsz:LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')return new_size

推理增强部分

def _forward_augment(self, x):img_size = x.shape[-2:]  # height, widths = [1, 0.83, 0.67]  # scalesf = [None, 3, None]  # flips (2-ud, 3-lr)y = []  # outputsfor si, fi in zip(s, f):xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))print("xi.shape[2:]:",xi.shape[2:])yi = self._forward_once(xi)[0]  # forwardprint("0 yi:",yi.shape)#cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # saveyi = self._descale_pred(yi, fi, si, img_size)print("1 yi.shape:",yi.shape)y.append(yi)y = self._clip_augmented(y)  # clip augmented tailsreturn torch.cat(y, 1), None  # augmented inference, traindef _descale_pred(self, p, flips, scale, img_size):# de-scale predictions following augmented inference (inverse operation)if self.inplace:p[..., :4] /= scale  # de-scaleif flips == 2:p[..., 1] = img_size[0] - p[..., 1]  # de-flip udelif flips == 3:p[..., 0] = img_size[1] - p[..., 0]  # de-flip lrelse:x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scaleif flips == 2:y = img_size[0] - y  # de-flip udelif flips == 3:x = img_size[1] - x  # de-flip lrp = torch.cat((x, y, wh, p[..., 4:]), -1)return pdef _clip_augmented(self, y):# Clip YOLOv5 augmented inference tailsnl = self.model[-1].nl  # number of detection layers (P3-P5)g = sum(4 ** x for x in range(nl))  # grid pointse = 1  # exclude layer counti = (y[0].shape[1] // g) * sum(4 ** x for x in range(e))  # indicesy[0] = y[0][:, :-i]  # largei = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indicesy[-1] = y[-1][:, i:]  # smallreturn y

关于翻转看

if self.inplace:p[..., :4] /= scale  # de-scaleif flips == 2:p[..., 1] = img_size[0] - p[..., 1]  # de-flip udelif flips == 3:p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr

2表示上下翻转
3表示左右翻转
s = [1, 0.83, 0.67] 是缩放比例,且能被32整除

这里的顺序是HW

xi.shape[2:]: torch.Size([640, 480])
xi.shape[2:]: torch.Size([544, 416])
xi.shape[2:]: torch.Size([448, 352])yi.shape: torch.Size([1, 18900, 85])
yi.shape: torch.Size([1, 13923, 85])
yi.shape: torch.Size([1, 9702, 85])

合并去冗余之后再进NMS

torch.Size([1, 34233, 85])

原来推理一张图像,增强后是推理3张

http://www.lryc.cn/news/273769.html

相关文章:

  • 篇二:springboot2.7 OAuth2 server使用jdbc存储RegisteredClient
  • 卷积神经网络|导入图片
  • 关于unity的组件VerticalLayoutGroup刷新显示不正常的问题
  • wait 和 notify 这个为什么要在synchronized 代码块中?
  • 大白话说区块链和通证
  • Jvm之垃圾收集器(个人见解仅供参考)
  • Minitab 21软件安装包下载及安装教程
  • Java版商城:Spring Cloud+SpringBoot b2b2c电子商务平台,多商家入驻、直播带货及免 费 小程序商城搭建
  • 阿里云被拉入黑洞模式怎么办?该怎么换ip-速盾网络
  • Android 13.0 recovery竖屏界面旋转为横屏
  • 异地环控设备如何远程维护?贝锐蒲公英解决远程互联难题
  • flutter 判断是否是web环境
  • 视频智能分析/云存储平台EasyCVR接入海康SDK,通道名称未自动更新该如何解决?
  • 后端开发——JDBC的学习(三)
  • Redis 生产环境查找无过期时间的 key
  • Visual Studio 2017编译Python3.8.18源码
  • 【mujoco】Ubuntu20.04中解决mujoco报错raise error.MujocoDependencyError
  • 机器学习的三个方面
  • 关于一名资深Java程序员在移动端的进阶之路
  • clickonce excel 插件发布安装的原理
  • 关于MySQL Cluster
  • 牵绳遛狗你我他文明家园每一天,助力共建文明社区,基于YOLOv7开发构建公共场景下未牵绳遛狗检测识别系统
  • 命令行艺术:简洁指南,效率倍增 | 开源日报 No.136
  • python基础教程五(字典概念和基本操作)
  • 【Delphi 基础知识 11】重载函数的使用
  • 经典目标检测YOLO系列(一)YOLOV1的复现(1)总体架构
  • 《设计模式》之策略模式
  • Django文章标签推荐
  • Git、TortoiseGit进阶
  • 山区老人爱的礼物丨守护银龄,情暖寒冬