当前位置: 首页 > news >正文

Pytorch构建自己的数据集

1.Pytorch内置的Dataset

Pytorch中内置了许多数据集,我们可以从torchvision库中进行导入。比如,我们可以导入Fashion-MNIST数据集

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

但如果torchvision库中没有该数据集,我们需要自己构建一个。
其中一个方法就是把构建好的数据集使用torch.utils.data.TensorDataset()封装以下,然后再传入torch.utils.data.DataLoader

trainloader  =  torch.utils.data.DataLoader(training_data, batch_size=32, shuffle=True)
testloader  =  torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False)

但是如果自己写一个类的话会高达上一些,嘻嘻。下面看看如何自己构建一个Dataset Class。

2.Build Custom Dataset

构建一个Custom Dataset需要继承``三个函数__init__, __len__, 和 __getitem__

  • __init__: 对类进行初始化
  • __len__: 使该类可以返回dataset样本数量
  • __getitem__: 给定一个idx,从数据集中导入并返回一个样本

下面我们来看看该如何构建Custom Dataset:

import os
import pandas as pd
from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file) # load labelself.img_dir = img_dirself.transform = transform # transformationself.target_transform = target_transformdef __len__(self):return len(self.img_labels) # 返回sample的个数def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path) # load idx-th imagelabel = self.img_labels.iloc[idx, 1] # load idx-th labelif self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

注意:同时,__len__控制着产生样本的总个数。例如,如果总共有20个样本,我们希望20个样本全都放入dataloader中,则:

def __len(self):return 20

但如果我们只希望有20个样本中的15个放入到dataloader中,则:

def __len(self):return 15

但值得注意的是,return返回的数不能大于样本的总个数,即要小于等于20。并且,当返回的数小于总样本个数的时候,是取索引的前几个数,最后的几个数不会被放入dataloader中。比如return 15,是将data[:15]个数放入dataloader,而后5个数要舍去。可以用如下代码验证:

>>> data = np.arange(15).reshape(5,3)
>>> print(data)
array([[ 0, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11],[12, 13, 14]])
>>> class Data(Dataset):
...		def __init__(self, data) -> None:
...			super(Data, self).__init__()
...			self.data = data
...		def __len__(self):
...			return 4
...		def __getitem__(self, index):
...			out = self.data[index]
...			return torch.from_numpy(out)
>>> loader = DataLoader(Data(data), batch_size=4, shuffle=True)
>>> for i, x in enumerate(loader):
...		print(i, x)0 tensor([[ 3, 4, 5],[ 9, 10, 11],[ 0, 1, 2],[ 6, 7, 8]])

可以发现,无论如何都不会输出[12, 13, 14]

Reference:
Pytorch official tutorial
Writing custom datasets dataloaders and transforms

http://www.lryc.cn/news/43762.html

相关文章:

  • 信息论小课堂:纠错码(海明码在信息传输编码时,通过巧妙的信道编码保证有了错误能够自动纠错。)
  • MySQL执行计划(explain)
  • 思必驰回复第二轮审核问询,如何与科大讯飞、阿里巴巴“虎口夺食”?
  • 基于Spring、SpringMVC、MyBatis的汽车租赁系统设计
  • 读《刻意练习》后感,与原文好句摘抄
  • 华为OD机试用java实现 -【选座位】
  • 国产蓝牙耳机怎么挑选?口碑最好的国产蓝牙耳机
  • seaborn从入门到精通03-绘图功能实现02-分类绘图Categorical plots
  • ❤️独特的算法❤️:一文解决编辑距离问题
  • 三次样条样条:Bézier样条和Hermite样条
  • Redis面试题 (2023最新版)
  • 基于springboot实现家乡特色食品景点推荐系统【源码+论文】分享
  • Spring MVC 启动之 HandlerMapping
  • 基于YOLOv5的停车位检测系统(清新UI+深度学习+训练数据集)
  • 【Linux系统编程】5.vim基本操作命令
  • 主流机器学习平台调研与对比分析
  • 作业帮基于明道云开展的硬件业务数字化建设
  • 位图及布隆过滤器的模拟实现与面试题
  • 在 Python 中将天数添加到日期
  • vue3知识点
  • 一行代码生成Tableau可视化图表
  • 链表——删除元素或插入元素(头插法及尾插法)
  • oracle容器的使用
  • 基于springboot会员制医疗预约服务管理信息系统演示【附项目源码】
  • GoogleAdsense国内加载慢怎么解决?
  • 【MySQL专题】03、性能优化之读写分离(MaxScale)
  • Redis7高级之BigKey(二)
  • flex弹性盒子
  • [Java Web]Cookie | 一文详细介绍会话跟踪技术中的Cookie
  • 这可能是2023最全的Java面试八股文,共计1658页,Java技术手册的天花板