当前位置: 首页 > news >正文

torchvision中数据集的使用与DataLoader 小土堆pytorch记录

torchvision中数据集的使用

import torchvision
import tarfile
from torch.utils.tensorboard import SummaryWriter# 定义数据预处理流水线:只包含ToTensor转换
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()  # 将PIL图像转换为Tensor格式
])# CIFAR10数据集说明:
# - 包含6万张32x32彩色图片
# - 10个类别:['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# - 训练集:50,000张,测试集:10,000张
train_set = torchvision.datasets.CIFAR10(root="./P_10_dataset",  # 数据集保存路径train=True,             # 加载训练集transform=dataset_transform,  # 应用转换download=True)          # 自动下载test_set = torchvision.datasets.CIFAR10(root="./P_10_dataset",train=False,            # 加载测试集transform=dataset_transform,download=True)# 如果需要手动解压(通常不需要,PyTorch会自动处理)
# with tarfile.open('P_10_dataset/cifar-10-python.tar.gz', 'r:gz') as tar:
#     tar.extractall(path='P_10_dataset')# 检查数据集样本:转换后是(tensor, 标签)元组
# print(test_set[0])  # 输出示例: (<Tensor>, 3)# 获取数据集类别名称
# print(test_set.classes)  # 输出10个类别的名称# 获取单个样本
# img, target = test_set[0]
# print(img)  # 输出Tensor对象
# print(target)  # 输出整数标签(如3)
# print(test_set.classes[target])  # 输出对应的类别名称(如'cat')# 验证转换是否成功
print(test_set[0])  # 确认输出为Tensor格式# 创建TensorBoard写入器
writer = SummaryWriter("p10")# 将测试集前10个样本写入TensorBoard
for i in range(10):img, target = test_set[i]  # 获取图像Tensor和标签writer.add_image("test_set", img, i)  # 添加到TensorBoard# 扩展:同时添加类别标签作为标题# writer.add_image(f"test_set/{test_set.classes[target]}", img, i)writer.close()  # 关闭写入器
1. torchvision.datasets 模块
  • 作用:提供常用计算机视觉数据集的便捷访问

  • 常用数据集

    • CIFAR10/100:小尺寸彩色图像分类

    • MNIST/FashionMNIST:手写数字/服装灰度图像

    • ImageNet:大规模图像分类(需单独下载)

    • COCO:目标检测与分割数据集

  • 核心参数

torchvision.datasets.XXX(root, train, transform, download)
  • root:数据集存储路径

  • train:True=训练集,False=测试集

  • transform:数据预处理流水线

  • download:自动下载数据集

CIFAR10数据集,有6w张彩色图片,5W张用作train,1w张用于test
类别包含以下10个类别,每个类别6000张
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

DataLoader的使用

视频代码

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter# 准备测试集:使用CIFAR10数据集,应用ToTensor转换
test_data = torchvision.datasets.CIFAR10(root="./P_10_dataset",   # 数据集存储路径train=False,             # 使用测试集transform=torchvision.transforms.ToTensor()  # 将PIL图像转为Tensor
)# 创建DataLoader
test_loader = DataLoader(dataset=test_data,       # 要加载的数据集batch_size=64,           # 每批加载的样本数shuffle=True,            # 每个epoch是否打乱数据顺序num_workers=0,           # 数据加载使用的子进程数(0表示主进程)drop_last=True           # 是否丢弃最后不足batch_size的批次
)# 从测试数据集中取出第一个样本
img, target = test_data[0]
print(img.shape)  # 输出: torch.Size([3, 32, 32]) - 3通道,32x32大小
print(target)     # 输出: 3 - 对应的类别标签# 创建TensorBoard写入器
writer = SummaryWriter('dataloader_logs')
step = 0  # 步数计数器# 进行2个epoch的迭代
for epoch in range(2):  # epoch数# 遍历DataLoader中的所有批次for data in test_loader:imgs, targets = data  # 解包批次数据# 将当前批次的所有图像写入TensorBoardwriter.add_images("Epoch:{}".format(epoch), imgs, step)step += 1  # 增加步数# 关闭写入器
writer.close()# 批次数据形状示例: 
#   imgs: torch.Size([64, 3, 32, 32]) 
#   targets: tensor([5, 6, 5, 0, ...]) - 64个标签
1. DataLoader 核心功能
功能说明
批量加载将数据集分成多个批次(batch)
数据打乱每个epoch重新随机排序数据
并行加载使用多进程加速数据加载
自动分批处理最后不足batch大小的批次
2. DataLoader 重要参数
参数作用常用值
batch_size每批样本数32/64/128
shuffle是否打乱数据True(训练)/False(测试)
num_workers加载数据的子进程数0(主进程)/2/4/8
drop_last是否丢弃最后不足batch的样本True/False
pin_memory是否将数据复制到CUDA固定内存True(GPU训练)
3. 批次数据结构
  • 图像数据(batch_size, channels, height, width)

    • 示例: [64, 3, 32, 32] 表示64张32x32的RGB图像

  • 标签数据(batch_size,)

    • 示例: tensor([3, 5, 9, ...]) 64个类别标签

4. TensorBoard可视化
  • add_image(): 添加单张图像

  • add_images(): 添加多张图像(整个批次)

  • 命名技巧: 包含epoch信息便于区分不同训练阶段

DataLoader是PyTorch数据管道的核心组件,它不仅仅是简单的数据分批工具,更是连接数据预处理与模型训练的桥梁。合理配置DataLoader参数可以显著提升训练效率和资源利用率,特别是在处理大规模数据集时。

http://www.lryc.cn/news/621332.html

相关文章:

  • # Vue 列表渲染详解
  • VLMs开发——基于Qwen2.5-VL 实现视觉语言模型在目标检测中的层级结构与实现方法
  • RxJava Android 创建操作符实战:从数据源到Observable
  • 中久数创——笔试题
  • PiscTrace基于YOLO追踪算法的物体速度检测系统详解
  • 2025.8.24复习总结
  • React.memo、useMemo 和 React.PureComponent的区别
  • 基于场景的无人驾驶叉车分类研究:应用场景与技术选型分析
  • springboot myabtis返回list对象集合,对象的一个属性为List对象
  • 飞算 JavaAI 真是 yyds
  • 一周学会Matplotlib3 Python 数据可视化-绘制面积图(Area)
  • [C++] Git 使用教程(从入门到常用操作)
  • TDengine IDMP 基本功能(6. 无问智推)
  • TDengine IDMP 基本功能(7. 智能问数)
  • C++11新特性深度解析
  • 【CF】Day127——杂题 (数论 gcd | 数论 gcd | 博弈论 | 二分图判断 | 贪心 + 暴力 / 二分答案 | 数论 gcd + 动态规划)
  • OSG+Qt —— 笔记1 - Qt窗口加载模型(附源码)
  • Mybatis 源码解读-SqlSession 会话源码和Executor SQL操作执行器源码
  • 《Python函数:从入门到精通,一文掌握函数编程精髓》
  • Transformer网络结构解析
  • 《嵌入式 C 语言编码规范与工程实践个人笔记》参考华为C语言规范标准
  • CNN - 卷积层
  • GaussDB数据库架构师修炼(十六) 如何选择磁盘
  • 《算法导论》第 24 章 - 单源最短路径
  • 20250814 最小生成树总结
  • Java 大视界 -- Java 大数据机器学习模型在金融欺诈检测与防范策略制定中的应用(397)
  • 【Demo】AI-ModelScope/bert-base-uncase 模型训练及使用
  • 市面上有没有可以导入自有AI算法模型的低空平台?
  • pytorch学习笔记-Loss的使用、在神经网络中加入Loss、优化器(optimizer)的使用
  • Linux 对 YUM 包的管理