当前位置: 首页 > news >正文

深度学习(4):数据加载器

一、Dataset:数据集类

1.数据集类需要继承Dataset类

2.实现__init__方法,数据初始化

3.实现__len__方法,返回数据集的长度

4.实现__getitem__方法,根据索引下标获取数据

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import pandas as pd
from torchvision import transforms, datasetsclass MyDataset(Dataset):def __init__(self,data,labels):assert len(data) == len(labels)self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self,index):sample = self.data[index]label = self.labels[index]return sample,label

二、DataLoader:数据加载器

返回一个迭代器

参数:

dataset:要加载的数据集

batch_size:每批次读取的样本数量

shuffle:是否打乱顺序,True-打乱,False-不打乱

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import pandas as pd
from torchvision import transforms, datasetsx = torch.randn(1000, 20)
y = torch.randn(1000, 10)dataset = MyDataset(x, y)
print(len(dataset))#1000
print(dataset[0])
"""
(tensor([ 0.1911, -0.0872,  0.4112, -0.3616, -2.4566, -0.5119,  0.1298,  1.0090,-0.6610, -1.3058,  0.1351, -1.6622,  0.8579, -0.5143,  0.6540, -0.0464,0.4354, -0.1966, -0.1209,  0.2876]), tensor([ 1.8922,  1.4897, -1.4169, -1.2283, -0.9311, -0.7850,  0.9580,  0.3025,0.3257, -0.3441]))
"""dataloader = DataLoader(dataset=dataset,batch_size=100,shuffle=True
)for x, y in dataloader:
print(x.shape, y.shape)#torch.Size([100, 20]) torch.Size([100, 10])break

三、TensorDataset: torch提供的dataset类

如果对数据没有特殊处理的情况下,可以考虑使用TensorDataset

如果需要对数据进行特殊处理,可以考虑自定义Dataset数据集

    x = torch.randn(1000, 20)y = torch.randn(1000, 10)dataset = TensorDataset(x, y)dataloader = DataLoader(dataset=dataset,batch_size=100,shuffle=True)for x, y in dataloader:print(x.shape, y.shape)#torch.Size([100, 20]) torch.Size([100, 10])break

四、自定义图片加载器

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import pandas as pd
from torchvision import transforms, datasetsfilepath = './datasets/animals'transform = transforms.Compose([transforms.Resize(size=(224, 224)),transforms.ToTensor()
])dataset = datasets.ImageFolder(filepath, transform=transform)
dataloader = DataLoader(dataset=dataset,batch_size=20,shuffle=True
)for x, y in dataloader:print(x, y)break

五、加载MNIST数据集

# MNIST数据集:黑底白字的手写数字,图片分辨率:28*28
# 分训练数据集(60000)和测试数据集(10000)
def test05():transform = transforms.Compose([transforms.ToTensor()])# train: 是否为训练数据集# root:保存数据集的路径# transform:图片转换器train_dataset = datasets.MNIST(root='./datasets',train=True,download=True,transform=transform)dataloader = DataLoader(dataset=train_dataset,batch_size=20,shuffle=True)# 按批次遍历,每批次读取batch_size个数据for x, y in dataloader:print(x, y)break
"""
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]],[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]],[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]],...,[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]],[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]],[[[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]]]) tensor([9, 2, 2, 0, 9, 1, 3, 7, 2, 5, 1, 8, 8, 8, 6, 2, 7, 6, 4, 2])
"""

http://www.lryc.cn/news/619905.html

相关文章:

  • go语言学习笔记
  • 初识神经网络05——构建神经网络3
  • C# 反射入门:如何获取 Type 对象?
  • 深度学习流体力学:基于PyTorch的物理信息神经网络(PINN)完整实现
  • Spring Boot项目通过Feign调用三方接口的详细教程
  • 力扣top100(day02-04)--二叉树 01
  • 阿里云Anolis OS 8.6的公有云仓库源配置步骤
  • 旧版MinIO的安装(windows)、Spring Boot 后端集成 MinIO 实现文件存储(超详细,带图文)
  • oss(阿里云)前端直传
  • 4G模块 ML307A通过MQTT协议连接到阿里云
  • ImportError: Encountered error: Failed to import NATTEN‘s CPP backend.
  • 事件处理与组件基础
  • 飞算JavaAI实现数据库交互:JPA/Hibernate + MyBatis Plus基础功能学习
  • 基于微信小程序的工作日报管理系统/基于asp.net的工作日报管理系统
  • CAD 的 C# 开发中,对多段线(封闭多边形)内部的点进行 “一笔连线且不交叉、不出界
  • 重生之我在公司写前端 | “博灵语音通知终端” | 登录页面
  • [量化交易](1获取加密货币的交易数据)
  • 01数据结构-Prim算法
  • Unity、C#常用的时间处理类
  • Gradle(三)创建一个 SpringBoot 项目
  • C++ 中构造函数参数对父对象的影响:父子控件管理机制解析
  • 【完整源码+数据集+部署教程】火柴实例分割系统源码和数据集:改进yolo11-rmt
  • 学习语言的一个阶段性总结
  • Linux操作系统应用编程——文件IO
  • Nginx的SSL通配符证书自动续期
  • 精准阻断内网渗透:联软科技终端接入方案如何“锁死”横向移动?
  • MySQL中的查询、索引与事务
  • MySQL三大存储引擎对比:InnoDB vs MyISAM vs MEMORY
  • RuoYi-Cloud 接入 Sentinel 的 3 种限流方式
  • Android 双屏异显技术全解析:从原理到实战的多屏交互方案