当前位置: 首页 > news >正文

Pytorch建立MyDataLoader过程详解

简介

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device=‘’)

详细:DataLoader

自己基于DataLoader实现各个模块

代码实现

MyDataset基于torch中的Data实现对个人数据集的载入,例如图像和标签载入
SingleSampler基于torch中的Sampler实现对于数据的batch个数图像的载入,例如,Batch_Size=4,实现对所有数据中选取4个索引作为一组,然后在MyDataset中基于__getitem__根据图像索引去进行图像操作
MyBathcSampler基于torch的BatchSampler实现自己对于batch_size数据的处理。需要基于SingleSampler实现Sampler的处理,更为灵活。MyBatchSampler的存在会自动覆盖DataLoader中的batch_size参数
注:Sampler的实现,将会与shuffer冲突,shuffer是在没有实现sampler前提下去自动判断选择的sampler类型
collate_fn是实现将batch_size的图像数据进行打包,遍历过程中就可以实现batch_size的images和labels对应
在这里插入图片描述

sampler

from typing import Iterator, List
import torch
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Samplerclass MyDataset(Dataset):def __init__(self) -> None:self.data = torch.arange(20)def __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]@staticmethoddef collate_fn(batch):return torch.stack(batch, 0)class MyBatchSampler(BatchSampler):def __init__(self, sampler: Sampler[int], batch_size: int) -> None:self._sampler = samplerself._batch_size = batch_sizedef __iter__(self) -> Iterator[List[int]]:batch = []for idx in self._sampler:batch.append(idx)if len(batch) == self._batch_size:yield batchbatch = []yield batchdef __len__(self):return len(self._sampler) // self._batch_sizeclass SingleSampler(Sampler):def __init__(self, data_source) -> None:self._data = data_sourceself.num_samples = len(self._data)def __iter__(self):# 顺序采样# indices = range(len(self._data))# 随机采样indices = torch.randperm(self.num_samples).tolist()return iter(indices)def __len__(self):return self.num_samplestrain_set = MyDataset()
single_sampler = SingleSampler(train_set)
batch_sampler = MyBatchSampler(single_sampler, 8)
train_loader = DataLoader(train_set, batch_size=4, sampler=single_sampler, pin_memory=True, collate_fn=MyDataset.collate_fn)
for data in train_loader:print(data)

batch_sampler

from typing import Iterator, List
import torch
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Samplerclass MyDataset(Dataset):def __init__(self) -> None:self.data = torch.arange(20)def __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]@staticmethoddef collate_fn(batch):return torch.stack(batch, 0)class MyBatchSampler(BatchSampler):def __init__(self, sampler: Sampler[int], batch_size: int) -> None:self._sampler = samplerself._batch_size = batch_sizedef __iter__(self) -> Iterator[List[int]]:batch = []for idx in self._sampler:batch.append(idx)if len(batch) == self._batch_size:yield batchbatch = []yield batchdef __len__(self):return len(self._sampler) // self._batch_sizeclass SingleSampler(Sampler):def __init__(self, data_source) -> None:self._data = data_sourceself.num_samples = len(self._data)def __iter__(self):# 顺序采样# indices = range(len(self._data))# 随机采样indices = torch.randperm(self.num_samples).tolist()return iter(indices)def __len__(self):return self.num_samplestrain_set = MyDataset()
single_sampler = SingleSampler(train_set)
batch_sampler = MyBatchSampler(single_sampler, 8)
train_loader = DataLoader(train_set, batch_sampler=batch_sampler, pin_memory=True, collate_fn=MyDataset.collate_fn)
for data in train_loader:print(data)

参考

Sampler:https://blog.csdn.net/lidc1004/article/details/115005612

http://www.lryc.cn/news/135738.html

相关文章:

  • 十问华为云 Toolkit:开发插件如何提升云上开发效能
  • NO.06 自定义映射resultMap
  • 国产精品:讯飞星火最新大模型V2.0
  • 网络综合布线实训室方案(2023版)
  • Qt应用开发(基础篇)——文本编辑窗口 QTextEdit
  • NineData中标移动云数据库传输项目(2023)
  • Java面向对象三大特性之多态及综合练习
  • HTTPS 握手过程
  • docker之Consul环境的部署
  • 服务机器人,正走向星辰大海
  • SciencePub学术 | 计算机及交叉类重点SCIE征稿中
  • Java面试题--SpringCloud篇
  • 【linux】常用的互斥手段及实例简述
  • STM32 F103C8T6学习笔记12:红外遥控—红外解码-位带操作
  • linux 环境收集core文件步骤
  • Git企业开发控制理论和实操-从入门到深入(一)|为什么需要Git|Git的安装
  • 上篇——税收大数据应用研究
  • 疲劳驾驶检测和识别4:C++实现疲劳驾驶检测和识别(含源码,可实时检测)
  • Android WakefulBroadcastReceiver的使用
  • python知识:什么是字符编码?
  • Vue2中使用Pinia
  • Docker关于下载,镜像配置,容器启动,停止,查看等基础操作
  • 穿越网络迷雾的神奇通道 - WebSocket详解
  • 无脑入门pytorch系列(五)—— nn.Dropout
  • Python土力学与基础工程计算.PDF-压水试验
  • Linux入门
  • 适合国内用户的五款ChatGPT插件
  • Dubbo Spring Boot Starter 开发微服务应用
  • linux中互斥锁,自旋锁,条件变量,信号量,与freeRTOS中的消息队列,信号量,互斥量,事件的区别
  • 安装docker服务,配置镜像加速器