当前位置: 首页 > news >正文

YoloV8修改分类(Classify)的前处理(记录)

修改原因

  • yolo自带的分类前处理对于长方形的数据不够友好,存在特征丢失等问题
  • 修改后虽然解决了这个问题但是局部特征也会丢失因为会下采样程度多于自带的,总之具体哪种好不同数据应该表现不同
  • 我的数据中大量长宽比很大的数据所以尝试修改自带的前处理,以保证理论上的合理性。
修改过程
  1. yolo中自带的分类前处理和检测有一些差异

调试推理代码发现ultralytics/models/yolo/classify/predict.py中对图像进行前处理的操作主要是self.transforms

def preprocess(self, img):"""Converts input image to model-compatible data type."""if not isinstance(img, torch.Tensor):is_legacy_transform = any(self._legacy_transform_name in str(transform) for transform in self.transforms.transforms)if is_legacy_transform:  # to handle legacy transformsimg = torch.stack([self.transforms(im) for im in img], dim=0)else:# import ipdb;ipdb.set_trace()img = torch.stack([self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0)img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32

通过调试打印self.transforms得到

Compose(Resize(size=96, interpolation=bilinear, max_size=None, antialias=True)CenterCrop(size=(96, 96))ToTensor()Normalize(mean=tensor([0., 0., 0.]), std=tensor([1., 1., 1.]))
)

假设我设置的imgsz为96,从这里简单的解读可以理解为先进行resize然后进行中心裁切保证输入尺寸为96x96

具体的查看哪里可以修改前处理,首先发现在ultralytics/engine/predictor.py中
def setup_source(self, source):"""Sets up source and inference mode."""self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size# import ipdb; ipdb.set_trace()self.transforms = (getattr(self.model.model,"transforms",classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction), #dujiang)if self.args.task == "classify"else None)

可以发现self.transforms主要调用的是classify_transforms方法

进一步我们在ultralytics/data/augment.py中找到classify_transforms的实现
if scale_size[0] == scale_size[1]:# Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg)tfl = [T.Resize(scale_size[0], interpolation=getattr(T.InterpolationMode, interpolation))]else:# Resize the shortest edge to matching target dim for non-square targettfl = [T.Resize(scale_size)]tfl.extend([T.CenterCrop(size),T.ToTensor(),T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),])

发现和我们的设想基本一致,查看代码逻辑首先是针对正方形图像会将图像缩放到指定的高度,同时保持长宽比,确保较短的一边正好等于目标尺寸,非正方形图片将短边resize到指定大小,长边此时可能是超出的,所以 T.CenterCrop(size)进行中心裁切确保尺寸是我们指定的

针对上面的分析可能问题就很明显了,如果处理的图像是长宽比非常不均匀的图像,那么中心裁切会导致丢失大量信息,我参考了检测的方法,决定将分类的预处理修改为填充而不是裁切

  • 首先确定思想,我想做的是根据长边resize到指定尺寸并且保证长宽比,短边会不足,刚好与原本的代码逻辑相反
  • 然后短边不足的地方进行填充保证短边也达到指定尺寸(填充yolo好像一般是144,这里我也选择144)
  • 具体实现如下
  1. 添加两个类分别实现resizepadding
class ResizeLongestSide:def __init__(self, size, interpolation):self.size = sizeself.interpolation = interpolationdef __call__(self, img):# 获取图像的当前尺寸width, height = img.size# 计算缩放比例if width > height:new_width = self.sizenew_height = int(self.size * height / width)else:new_height = self.sizenew_width = int(self.size * width / height)# 按长边缩放return img.resize((new_width, new_height), Image.BILINEAR)class PadToSquare:def __init__(self, size, fill=(114)):self.size = sizeself.fill = filldef __call__(self, img):# 获取当前尺寸width, height = img.size# 计算需要填充的大小delta_w = self.size - widthdelta_h = self.size - heightpadding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))# 填充图像return F.pad(img, padding, fill=self.fill, padding_mode='constant')
  1. 调用上面的类进行实现
def classify_transforms(size=96,mean=DEFAULT_MEAN,std=DEFAULT_STD,interpolation="BILINEAR",crop_fraction: float = DEFAULT_CROP_FRACTION,padding_color=(114, 114, 114),  # 默认填充为灰色
):import torchvision.transforms as Timport torchfrom torchvision.transforms import functional as F# import ipdb;ipdb.set_trace()tfl = [# T.ClassifyLetterBox(size),ResizeLongestSide(size, interpolation=getattr(T.InterpolationMode, interpolation)),  # 按长边缩放PadToSquare(size, fill=padding_color),  # 填充至正方形T.ToTensor(),T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),]return T.Compose(tfl)
  1. 想要训练前先确定自己修改是否符合预期进行如下测试
Examples:>>> from ultralytics.data.augment import LetterBox, classify_transforms, classify_transforms_with_padding>>> from PIL import Image>>> transforms = classify_transforms_with_padding(size=96)>>> img = Image.open('bus.jpg')  3ch img_rgb = Image.merge('RGB', (img, img, img))>>> transformed_img = transforms(img)>>>import torchvision.transforms as T>>>DEFAULT_MEAN = (0.0, 0.0, 0.0)>>>DEFAULT_STD = (1.0, 1.0, 1.0)>>>import torch>>>def save_transformed_image(transformed_img, save_path="transformed_image.png"):# 定义反向变换,将张量转换回 PIL 图像unnormalize = T.Normalize(mean=[-m / s for m, s in zip(DEFAULT_MEAN, DEFAULT_STD)],std=[1 / s for s in DEFAULT_STD])img_tensor = unnormalize(transformed_img)img_tensor = torch.clamp(img_tensor, 0, 1)to_pil = T.ToPILImage()img_pil = to_pil(img_tensor)img_pil.save(save_path)print(f"Image saved at {save_path}")>>>save_transformed_image(transformed_img, save_path="transformed_image.png")
  1. 效果图
    请添加图片描述
    请添加图片描述
  2. ok,效果预期一致,接下来可以训练了,之前对于矩形的图像会有裁切现在使用padding解决了。但是具体效果还得看结果。
  3. 补充一下修改一定要把类和方法分开,即不要在方法中定义类,这样会导致训练出错
总结中间遇到问题参考这里解决
http://www.lryc.cn/news/434463.html

相关文章:

  • 半监督学习能否帮助训练更好的模型?
  • VBA 获取字段标题代码轻松搞定
  • C++代码片段
  • Golang | Leetcode Golang题解之第388题文件的最长绝对路径
  • docker打包前端项目
  • 调度器怎么自己写?调度器在实现时需要注意哪些细节?请写一个jvm的调度器?如何在这个调度器中添加多个任务?
  • 创客匠人对话|德国临床营养学家单场发售百万秘笈大公开
  • 开源项目低代码表单FormCreate从Vue2到Vue3升级指南
  • 序偶解释:李冬梅老师书线性表一章第一页
  • 3GPP协议入门——物理层基础(二)
  • Java学习Day41:手刃青背龙!(spring框架之事务)
  • el-image(vue 总)
  • 餐饮「收尸人」,血亏奶茶店……
  • 【Python进阶】学习Python从入门到进阶,详细步骤,就看这一篇。文末附带项目演练!!!
  • OpenCV结构分析与形状描述符(9)检测轮廓相对于其凸包的凹陷缺陷函数convexityDefects()的使用
  • HTTP 之 响应头信息(二十三)
  • 智能风扇的全新升级:NRK3603语音芯片识别控制模块的应用
  • 如何通过pSLC技术实现性能与容量的双赢
  • 减速电机的基本结构及用料简介
  • 1688跨境电商接口开放接入,跨境电商的尽头到底谁在赚钱?
  • SpringBoot 增量部署发布
  • java八股!1
  • 【学术会议征稿】2024年智能驾驶与智慧交通国际学术会议(IDST 2024)
  • 2024最全网络安全工程师面试题(附答案)
  • 828华为云征文| 华为云 Flexus X 实例:引领云计算新时代的柔性算力先锋
  • 何时何地,你需要提示工程、函数调用、RAG还是微调大模型?
  • three.js线框模式
  • VScode 的简单使用
  • 五星级可视化页面(07):城市交通方向,城市畅通的保障。
  • 贪心+构造,1924A - Did We Get Everything Covered?