深度学习(4):数据加载器
一、Dataset:数据集类
1.数据集类需要继承Dataset类
2.实现__init__方法,完成数据的初始化(如保存样本和标签)
3.实现__len__方法,返回数据集的样本数量
4.实现__getitem__方法,根据索引返回对应的单个样本(通常为 (数据, 标签) 元组)
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import pandas as pd
from torchvision import transforms, datasetsclass MyDataset(Dataset):def __init__(self,data,labels):assert len(data) == len(labels)self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self,index):sample = self.data[index]label = self.labels[index]return sample,label
二、DataLoader:数据加载器
DataLoader 是一个可迭代对象,每次迭代返回一个批次(batch)的数据
参数:
dataset:要加载的数据集
batch_size:每个批次包含的样本数量
shuffle:是否在每个 epoch 开始时打乱样本顺序,True 表示打乱,False 表示不打乱
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import pandas as pd
from torchvision import transforms, datasets

# Random toy data: 1000 samples of 20 features, each paired with a 10-dim target.
x = torch.randn(1000, 20)
y = torch.randn(1000, 10)
dataset = MyDataset(x, y)
print(len(dataset))  # 1000
print(dataset[0])
# Example output (values are random and differ on every run):
# (tensor([ 0.1911, -0.0872,  0.4112, -0.3616, -2.4566, -0.5119,  0.1298,  1.0090,
#          -0.6610, -1.3058,  0.1351, -1.6622,  0.8579, -0.5143,  0.6540, -0.0464,
#           0.4354, -0.1966, -0.1209,  0.2876]),
#  tensor([ 1.8922,  1.4897, -1.4169, -1.2283, -0.9311, -0.7850,  0.9580,  0.3025,
#           0.3257, -0.3441]))

# Group samples into batches of 100, shuffling the order each epoch.
dataloader = DataLoader(dataset=dataset, batch_size=100, shuffle=True)
# Distinct loop names keep the full tensors ``x``/``y`` from being overwritten.
for batch_x, batch_y in dataloader:
    print(batch_x.shape, batch_y.shape)  # torch.Size([100, 20]) torch.Size([100, 10])
    break  # only inspect the first batch
三、TensorDataset:PyTorch 内置的 Dataset 类
如果数据不需要额外处理,可以直接使用 TensorDataset 将若干个第一维长度相同的张量(如特征和标签)包装成数据集
如果需要对数据进行特殊处理(如从文件读取、数据增强),可以自定义 Dataset 数据集类
# Random toy data: 1000 samples of 20 features with 10-dim targets.
x = torch.randn(1000, 20)
y = torch.randn(1000, 10)
# TensorDataset pairs the tensors along their first dimension: dataset[i] == (x[i], y[i]).
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset=dataset, batch_size=100, shuffle=True)
# Distinct loop names keep the full tensors ``x``/``y`` from being overwritten.
for batch_x, batch_y in dataloader:
    print(batch_x.shape, batch_y.shape)  # torch.Size([100, 20]) torch.Size([100, 10])
    break  # only inspect the first batch
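四、使用 ImageFolder 加载自定义图片数据集
ImageFolder 要求图片按类别存放在子文件夹中(如 animals/cat/xxx.jpg、animals/dog/yyy.jpg),并把子文件夹名(按字母顺序编号)作为类别标签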
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import pandas as pd
from torchvision import transforms, datasets

# Root folder of the image dataset; each subfolder is treated as one class.
filepath = './datasets/animals'
# Resize every image to the same 224x224 size so a batch can be stacked into
# one tensor, then convert the PIL image to a tensor.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
])
dataset = datasets.ImageFolder(filepath, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=20, shuffle=True)
for x, y in dataloader:
    # Print the batch shape rather than the raw pixel tensor, which is huge.
    print(x.shape, y)
    break  # only inspect the first batch
五、加载MNIST数据集
# MNIST: handwritten digits drawn in white on a black background, 28x28 pixels.
# It is split into a training set (60000 images) and a test set (10000 images).
def test05(root='./datasets', batch_size=20):
    """Load the MNIST training set and print its first batch.

    Args:
        root: directory where the MNIST files are stored (downloaded there if
            missing, because ``download=True``).
        batch_size: number of samples per batch.
    """
    # Image transform applied to every sample: PIL image -> tensor.
    transform = transforms.Compose([transforms.ToTensor()])
    # train=True selects the training split; root is where the data is saved.
    train_dataset = datasets.MNIST(
        root=root,
        train=True,
        download=True,
        transform=transform,
    )
    dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    # Iterate batch by batch; each batch holds batch_size samples.
    for x, y in dataloader:
        print(x, y)
        break  # only show the first batch
"""
tensor([[[[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],...,[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.]]],[[[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],...,[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.]]],[[[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],...,[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.]]],...,[[[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],...,[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.]]],[[[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],...,[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.]]],[[[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],...,[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.],[0., 0., 0., ..., 0., 0., 0.]]]]) tensor([9, 2, 2, 0, 9, 1, 3, 7, 2, 5, 1, 8, 8, 8, 6, 2, 7, 6, 4, 2])
"""