当前位置: 首页 > news >正文

DAY 38 Dataset和Dataloader类

知识点回顾:

  1. Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
  2. Dataloader类
  3. minist手写数据集的了解

作业:了解下cifar数据集,尝试获取其中一张图片

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)
# 1. 数据预处理,该写法非常类似于管道pipeline
# transforms 模块提供了一系列常用的图像预处理操作# 先归一化,再标准化
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])
# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)

一、Dataset类

现在我们想要取出来一个图片,看看长啥样,因为datasets.MNIST本质上集成了torch.utils.data.Dataset,所以自然需要有对应的方法。

import matplotlib.pyplot as plt# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签

__getitem__方法

# 示例代码
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]# 创建类的实例
my_list_obj = MyList()
# 此时可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2])  # 输出:30

__len__方法

class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)# 创建类的实例
my_list_obj = MyList()
# 使用len()函数获取元素数量,这会自动调用__len__方法
print(len(my_list_obj))  # 输出:5

# minist数据集的简化版本
class MNIST(Dataset):def __init__(self, root, train=True, transform=None):# 初始化:加载图片路径和标签self.data, self.targets = fetch_mnist_data(root, train) # 这里假设 fetch_mnist_data 是一个函数,用于加载 MNIST 数据集的图片路径和标签self.transform = transform # 预处理操作def __len__(self): return len(self.data)  # 返回样本总数def __getitem__(self, idx): # 获取指定索引的样本# 获取指定索引的图像和标签img, target = self.data[idx], self.targets[idx]# 应用图像预处理(如ToTensor、Normalize)if self.transform is not None: # 如果有预处理操作img = self.transform(img) # 转换图像格式# 这里假设 img 是一个 PIL 图像对象,transform 会将其转换为张量并进行归一化return img, target  # 返回处理后的图像和标签
# 可视化原始图像(需要反归一化)
def imshow(img):img = img * 0.3081 + 0.1307  # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray') # 显示灰度图像plt.show()print(f"Label: {label}")
imshow(image)

二、Dataloader类

# 3. 创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关shuffle=True # 随机打乱数据
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每个批次1000张图片# shuffle=False # 测试时不需要打乱数据
)

@浙大疏锦行

http://www.lryc.cn/news/573340.html

相关文章:

  • Python编程语言:2025年AI浪潮下的技术统治与学习红利
  • Python UDP Socket 实时在线刷卡扫码POS消费机服务端示例源码
  • 自动化立体仓库堆垛机控制系统STEP7 FC3功能块 I/O映射
  • `provide` 和 `inject` 组件通讯:实现跨组件层级通讯
  • 机器学习15-XGBoost
  • 微服务拆分——nacos/Feign
  • 华为云Flexus+DeepSeek征文 | 基于Flexus X实例的金融AI Agent开发:智能风控与交易决策系统
  • 李宏毅2025《机器学习》第三讲-AI的脑科学
  • 蓝牙数据通讯,实现内网电脑访问外网电脑
  • WPF调试三种工具介绍:Live Visual Tree、Live Property Explorer与Snoop
  • SylixOS 下的消息队列
  • Jupyter notebook调试:设置断点运行
  • Redis后端的简单了解与使用(项目搭建前置)
  • DeepEP开源MoE模型分布式通信库
  • 洛谷P3953 [NOIP 2017 提高组] 逛公园
  • 【DCS开源项目】—— Lua 如何调用 DLL、DLL 与 DCS World 的交互
  • day44-硬件学习之arm启动代码
  • 【Datawhale组队学习202506】零基础学爬虫 02 数据解析与提取
  • 【simulink】IEEE5节点系统潮流仿真模型(2机5节点全功能基础模型)
  • 【智能体】dify部署本地步骤
  • LeetCode第279题_完全平方数
  • 湖北理元理律师事务所企业债务纾困路径:司法重整中的再生之道
  • 蓝桥杯备赛篇(上) - 参加蓝桥杯所需要的基础能力 1(C++)
  • 华为OD机试_2025 B卷_判断一组不等式是否满足约束并输出最大差(Python,100分)(附详细解题思路)
  • 车载电子电器架构 --- 电子电气架构设计方案
  • QC -io 服务器排查报错方式/报错: Failed to convert string to integer of varId variable!“
  • 2.7 Python方法调用机制解析:从描述符到字节码执行
  • 学习C++、QT---03(C++的输入输出、C++的基本数据类型介绍)
  • 【无标题】使用 Chocolatey 安装 WSL 管理工具 LxRunOffline
  • 贪心算法思路详解