Dataset类案例 小土堆Pytorch入门视频记录
最近学完了python基础,过了遍机器学习理论基础和深度学习神经网络的部分,开始看小土堆的视频入门torch,简单记录下。
核心知识点
1. Dataset 的基本概念
Dataset 是 PyTorch 提供的一个抽象类,用于表示数据集
自定义数据集需要继承
torch.utils.data.Dataset
类必须实现三个方法:
__init__
,__getitem__
,__len__
2. 自定义 Dataset 的实现
from torch.utils.data import Dataset
from PIL import Image
import osclass MyData(Dataset):def __init__(self, root_dir, label_dir):self.root_dir = root_dir # 数据集根目录self.label_dir = label_dir # 类别标签目录self.path = os.path.join(root_dir, label_dir) # 拼接完整路径self.img_path = os.listdir(self.path) # 获取所有图片文件名def __getitem__(self, idx):img_name = self.img_path[idx] # 获取指定索引的图片名img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)img = Image.open(img_item_path) # 读取图片label = self.label_dir # 获取标签return img, label # 返回图片和标签def __len__(self):return len(self.img_path) # 返回数据集大小root_dir = "dataset/train"
# 创建蚂蚁数据集
ants_label_dir = "ants"
ants_dataset = MyData(root_dir, ants_label_dir)
# 创建蜜蜂数据集
bees_label_dir = "bees"
bees_dataset = MyData(root_dir, bees_label_dir)
# 合并数据集(简单示例)
train_dataset = ants_dataset + bees_dataset
在pycharm的python console中逐块运行可以更好的看到各元素的类型
__init__
: 初始化,设置路径和加载文件列表__getitem__
: 根据索引返回单个数据样本(图片+标签)__len__
: 返回数据集大小
掌握 Dataset 的使用是 PyTorch 数据处理的基础,后续的 DataLoader 和模型训练都依赖于良好的数据集实现。