当前位置: 首页 > news >正文

动手学深度学习13.6. 目标检测数据集-笔记练习(PyTorch)

以下内容为结合李沐老师的课程和教材补充的学习笔记,以及对课后练习的一些思考,自留回顾,也供同学之人交流参考。

本节课程地址:数据集_哔哩哔哩_bilibili

本节教材地址:13.6. 目标检测数据集 — 动手学深度学习 2.0.0 documentation

本节开源代码:…>d2l-zh>pytorch>chapter_optimization>object-detection-dataset.ipynb


目标检测数据集

目标检测领域没有像MNIST和Fashion-MNIST那样的小数据集。
为了快速测试目标检测模型,[我们收集并标记了一个小型数据集]。
首先,我们拍摄了一组香蕉的照片,并生成了1000张不同角度和大小的香蕉图像。
然后,我们在一些背景图片的随机位置上放一张香蕉的图像。
最后,我们在图片上为这些香蕉标记了边界框。

[下载数据集]

包含所有图像和CSV标签文件的香蕉检测数据集可以直接从互联网下载。

%matplotlib inline
import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l
#@save
d2l.DATA_HUB['banana-detection'] = (d2l.DATA_URL + 'banana-detection.zip','5de26c8fce5ccdea9f91267273464dc968d20d72')

读取数据集

通过read_data_bananas函数,我们[读取香蕉检测数据集]。
该数据集包括一个的CSV文件,内含目标类别标签和位于左上角和右下角的真实边界框坐标。

#@save
def read_data_bananas(is_train=True):"""读取香蕉检测数据集中的图像和标签"""data_dir = d2l.download_extract('banana-detection')csv_fname = os.path.join(data_dir, 'bananas_train' if is_trainelse 'bananas_val', 'label.csv')csv_data = pd.read_csv(csv_fname)csv_data = csv_data.set_index('img_name')images, targets = [], []for img_name, target in csv_data.iterrows():images.append(torchvision.io.read_image(os.path.join(data_dir, 'bananas_train' if is_train else'bananas_val', 'images', f'{img_name}')))# 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y),# 其中所有图像都具有相同的香蕉类(索引为0)targets.append(list(target))# /256将像素坐标从[0,256]归一化到[0,1]return images, torch.tensor(targets).unsqueeze(1) / 256

通过使用read_data_bananas函数读取图像和标签,以下BananasDataset类别将允许我们[创建一个自定义Dataset实例]来加载香蕉检测数据集。

#@save
class BananasDataset(torch.utils.data.Dataset):"""一个用于加载香蕉检测数据集的自定义数据集"""def __init__(self, is_train):self.features, self.labels = read_data_bananas(is_train)print('read ' + str(len(self.features)) + (f' training examples' ifis_train else f' validation examples'))def __getitem__(self, idx):return (self.features[idx].float(), self.labels[idx])def __len__(self):return len(self.features)

最后,我们定义load_data_bananas函数,来[为训练集和测试集返回两个数据加载器实例]。对于测试集,无须按随机顺序读取它。

#@save
def load_data_bananas(batch_size):"""加载香蕉检测数据集"""train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),batch_size, shuffle=True)val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),batch_size)return train_iter, val_iter

让我们[读取一个小批量,并打印其中的图像和标签的形状]。
图像的小批量的形状为(批量大小、通道数、高度、宽度),看起来很眼熟:它与我们之前图像分类任务中的相同。
标签的小批量的形状为(批量大小, m m m,5),其中 m m m是数据集的任何图像中边界框可能出现的最大数量。

小批量计算虽然高效,但它要求每张图像含有相同数量的边界框,以便放在同一个批量中。
通常来说,图像可能拥有不同数量个边界框;因此,在达到 m m m之前,边界框少于 m m m的图像将被非法边界框填充。
这样,每个边界框的标签将被长度为5的数组表示。
数组中的第一个元素是边界框中对象的类别,其中-1表示用于填充的非法边界框。
数组的其余四个元素是边界框左上角和右下角的( x x x y y y)坐标值(值域在0~1之间)。
对于香蕉数据集而言,由于每张图像上只有一个边界框,因此 m = 1 m=1 m=1

batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
# 取出第一个batch
batch = next(iter(train_iter))
batch[0].shape, batch[1].shape
read 1000 training examples
read 100 validation examples(torch.Size([32, 3, 256, 256]), torch.Size([32, 1, 5]))

[演示]

让我们展示10幅带有真实边界框的图像。
我们可以看到在所有这些图像中香蕉的旋转角度、大小和位置都有所不同。
当然,这只是一个简单的人工数据集,实践中真实世界的数据集通常要复杂得多。

# 因为每个像素值的范围是[0, 255],0是纯黑,255是纯白
# /255是将像素值从[0, 255]​归一化​​到[0, 1]
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][0:10]):# * edge_size返回原尺寸d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])

在这里插入图片描述

小结

  • 我们收集的香蕉检测数据集可用于演示目标检测模型。
  • 用于目标检测的数据加载与图像分类的数据加载类似。但是,在目标检测中,标签还包含真实边界框的信息,它不出现在图像分类中。

练习

  1. 在香蕉检测数据集中演示其他带有真实边界框的图像。它们在边界框和目标方面有什么不同?

解:
不同图像中的边界框的缩放比scale、宽高比ratio、旋转角度以及在图像中的位置不同,与背景图的色彩对比度也不同。

  1. 假设我们想要将数据增强(例如随机裁剪)应用于目标检测。它与图像分类中的有什么不同?提示:如果裁剪的图像只包含物体的一小部分会怎样?

解:
对于图像分类任务,仅需增强图像,且保证图像分类的主体可识别即可。
但对于目标检测任务,需要同时增强图像和边界框,并且需要同步调整边界框坐标,同时需要保证目标以及目标相关图像的完整性,否则可能导致数据增强后图像检测无效。因此,目标检测的数据增强要求更高。

http://www.lryc.cn/news/581962.html

相关文章:

  • DSP学习笔记2
  • 轨迹优化 | 基于激光雷达的欧氏距离场ESDF地图构建(附ROS C++仿真)
  • 7月份最新代发考试战报:思科CCNP 华为HCIP HCSP 售前售后考试成绩单
  • 网络安全之XSS漏洞:原理、危害与防御实践
  • 南柯电子|显示屏EMC整改:工业屏与消费屏的差异化策略
  • 接口漏洞怎么抓?Fiddler 中文版 + Postman + Wireshark 实战指南
  • 告别Root风险:四步构建安全高效的服务器管理体系
  • AJAX vs axios vs fetch
  • 【算法笔记】5.LeetCode-Hot100-矩阵专项
  • 腾讯云录音文件快速识别实战教程
  • Java后端技术博客汇总文档
  • 无人机声学探测模块技术分析!
  • 【C++开源库使用】使用libcurl开源库发送url请求(http请求)去下载用户头像文件(附完整源码)
  • RESTful API概念和设计原则
  • Ubunt20.04搭建GitLab服务器,并借助cpolar实现公网访问
  • 01、通过内网穿透工具把家中闲置电脑变成在线服务器
  • Java 大视界 -- 基于 Java 的大数据可视化在企业供应链动态监控与优化中的应用(336)
  • 迅为RK3568开发板基本工程目录-OpenHarmony APP工程结构
  • 上传Vue3+vite+Ts组件到npm官方库保姆级教程
  • 基于ArcGIS的洪水灾害普查、风险评估及淹没制图技术研究​
  • 【LeetCode 热题 100】206. 反转链表——(解法二)指针翻转
  • UE5详细保姆教程(第四章)
  • Post-Training on PAI (2):Ray on PAI,云上一键提交强化学习
  • 暑假算法日记第三天
  • C++笔记之开关控制的仿真与实际数据处理优雅设计
  • GNN--知识图谱(逐步贯通基础到项目实践)
  • 数学建模从入门到国奖——备赛规划优秀论文学习方法
  • 【HarmonyOS Next之旅】DevEco Studio使用指南(四十一) -> 获取自定义编译参数
  • 深入解析解释器模式:从理论到实践的完整指南
  • 浅学 Kafka