当前位置: 首页 > news >正文

Day 38: Dataset类和DataLoader类

核心概念

在处理大规模数据集时,显存往往无法一次性存储所有数据,因此需要使用分批训练的方法。PyTorch提供了两个关键类来解决这个问题:

  1. DataLoader类:决定数据如何加载
  2. Dataset类:告诉程序去哪里找数据,如何读取单个样本,以及如何预处理

实战演练:MNIST数据集

1. 导入必要的库

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)

2. 数据预处理

# 数据预处理管道
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的标准化参数
])

3. 加载MNIST数据集

# 加载训练集
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)# 加载测试集
test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)

🔧 Dataset类详解

Dataset类的核心方法

PyTorch的torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:

  • __len__():返回数据集的样本总数
  • __getitem__(idx):根据索引idx返回对应样本的数据和标签

魔术方法示例

# __getitem__方法示例
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]# 创建类的实例
my_list_obj = MyList()
# 可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2])  # 输出:30
# __len__方法示例
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)# 使用len()函数获取元素数量,这会自动调用__len__方法
my_list_obj = MyList()
print(len(my_list_obj))  # 输出:5

查看单个样本

# 获取一个样本
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()
image, label = train_dataset[sample_idx]
print(f"Label: {label}")# 可视化图像
def imshow(img):img = img * 0.3081 + 0.1307  # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray')plt.show()imshow(image)

DataLoader类详解

DataLoader负责将Dataset中的数据按批次加载,并提供多种数据加载策略:

# 创建训练数据加载器
train_loader = DataLoader(train_dataset,batch_size=64,    # 每个批次64张图片shuffle=True      # 随机打乱数据
)# 创建测试数据加载器
test_loader = DataLoader(test_dataset,batch_size=1000   # 每个批次1000张图片# shuffle=False   # 测试时不需要打乱数据
)

Dataset vs DataLoader 对比

维度DatasetDataLoader
核心职责定义"数据是什么"和"如何获取单个样本"定义"如何批量加载数据"和"加载策略"
核心方法__getitem____len__无自定义方法,通过参数控制
预处理位置__getitem__中通过transform执行无预处理逻辑
并行处理无(仅单样本处理)支持多进程加载
典型参数roottransformbatch_sizeshufflenum_workers

总结

Dataset类的职责

  • 数据内容定义:数据存储路径、读取方式
  • 预处理逻辑:图像变换、数据增强等
  • 返回格式:如(image_tensor, label)

DataLoader类的职责

  • 批量处理:控制batch_size
  • 数据打乱:shuffle参数
  • 并行加载:num_workers参数
  • 内存管理:防止一次性加载过多数据

实用技巧

  1. batch_size选择:通常选择2的幂次方(32、64、128等),这与GPU计算效率相关
  2. 数据预处理时机:在Dataset的__getitem__方法中进行,而不是DataLoader中
  3. 内存优化:DataLoader的num_workers参数可以开启多进程加载,提高效率

@浙大疏锦行

http://www.lryc.cn/news/618526.html

相关文章:

  • Labelme从安装到标注:零基础完整指南
  • 【完美解决】在 Ubuntu 24.04 上为小米 CyberDog 2 刷机/交叉编译:终极 Docker 环境搭建指南
  • mimiconda+vscode
  • HeidiSQL 连接 MySQL 报错 10061
  • vue+Django农产品推荐与价格预测系统、双推荐+机器学习预测+知识图谱
  • 跨界重构规则方法论
  • ubuntu24下keychorn键盘连接不了的改建页面的问题修复
  • 深入理解哈希结构及其应用
  • secureCRT ymodem协议连续传输文件速率下降
  • 鸿蒙开发教程实战案例源码分享-好看的SwitchButton
  • [SC]SystemC中的SC_FORK和SC_JOIN用法详细介绍
  • 17、CryptoMamba论文笔记
  • 42.【.NET8 实战--孢子记账--从单体到微服务--转向微服务】--扩展功能--集成网关--网关集成认证(一)
  • UNet改进(32):结合CNN局部建模与Transformer全局感知
  • Day45--动态规划--115. 不同的子序列,583. 两个字符串的删除操作,72. 编辑距离
  • DeepSeek-R1-0528 推理模型完整指南:领先开源推理模型的运行平台与选择建议
  • XC7A15T-1FTG256C Xilinx AMD Artix-7 FPGA
  • Linux中Apache与Web之虚拟主机配置指南
  • git config的配置全局或局部仓库的参数: local, global, system
  • 【unity实战】使用Splines+DOTween制作弯曲手牌和抽牌动画效果
  • 有限元方法中的数值技术:行列式、求逆、矩阵方程
  • 【bug 解决】串口输出字符乱码的问题
  • 【Datawhale夏令营】多模态RAG学习
  • 【Bug经验分享】由jsonObject-TypeReference引发的序列化问题
  • 【昇腾】关于Atlas 200I A2加速模块macro0配置3路PCIE+1路SATA在hboot2中的一个bug_20250812
  • STM32_bug总结(TIM定时中断进不去和只进1次)
  • 高性能web服务器Nginx
  • 【Android】【bug】Json解析错误Expected BEGIN_OBJECT but was STRING...
  • linux 开机进入initramfs无法开机
  • 跨设备开发不再难:HarmonyOS 分布式任务管理应用全解析