当前位置: 首页 > news >正文

Pytorch使用Dataset加载数据

1、前言:

在阅读之前,需要配置好对应pytorch版本。
对于一般学习,使用cpu版本的即可。参考教程点我
导入pytorch包,使用如下命令即可。

import torch   # 注意虽然叫pytorch,但是在引用时是引用torch

2、神经网络获取数据

神经网络获取数据主要用到Dataset和Dataloader两个方法
Dataset主要用于获取数据以及对应的真实label
Dataloader主要为后面的网络提供不同的数据形式
在torch.utils.data包内提供了DataSet类,可在Pytorch官网看到对应的描述

class Dataset(Generic[T_co]):r"""An abstract class representing a :class:`Dataset`.All datasets that represent a map from keys to data samples should subclassit. All subclasses should overwrite :meth:`__getitem__`, supporting fetching adata sample for a given key. Subclasses could also optionally overwrite:meth:`__len__`, which is expected to return the size of the dataset by many:class:`~torch.utils.data.Sampler` implementations and the default optionsof :class:`~torch.utils.data.DataLoader`. Subclasses could alsooptionally implement :meth:`__getitems__`, for speedup batched samplesloading. This method accepts list of indices of samples of batch and returnslist of samples... note:::class:`~torch.utils.data.DataLoader` by default constructs an indexsampler that yields integral indices.  To make it work with a map-styledataset with non-integral indices/keys, a custom sampler must be provided."""def __getitem__(self, index) -> T_co:raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")# def __getitems__(self, indices: List) -> List[T_co]:# Not implemented to prevent false-positives in fetcher check in# torch.utils.data._utils.fetch._MapDatasetFetcherdef __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":return ConcatDataset([self, other])# No `def __len__(self)` default?# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]# in pytorch/torch/utils/data/sampler.py

根据上述描述可知,Dataset是一个抽象类,用于表示数据集。你可以通过继承这个类并实现以下方法来自定义数据集:

__len__(self): 返回数据集的大小,即数据集中有多少个样本。
__getitem__(self, idx): 根据索引 idx 返回数据集中的一个样本和对应的标签。

3、案例

使用Dataset读取文件夹E:\Python_learning\Deep_learning\dataset\hymenoptera_data\train\ants下所有图片。并获取对应的label,该数据集的文件夹的名字为对应的标签,而文件夹内为对应的训练集的图片

import os
from torch.utils.data import Dataset
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformsclass MyDataset(Dataset):def __init__(self, root_path, label):self.root_path = root_pathself.label = labelself.img_path = os.path.join(root_path, label)  # 拼接路径print(f"图片路径: {self.img_path}")  # 打印路径以进行调试try:self.img_path_list = os.listdir(self.img_path)  # 列出文件夹中的文件print(f"图片列表: {self.img_path_list}")  # 打印图片列表以进行调试except PermissionError as e:print(f"权限错误: {e}")except FileNotFoundError as e:print(f"文件未找到错误: {e}")def __getitem__(self, index):img_index = self.img_path_list[index]img_path = os.path.join(self.img_path, img_index)try:img = Image.open(img_path)except Exception as e:print(f"读取图片时出错: {e}, 图片路径: {img_path}")raise elabel = self.labelreturn img, labeldef __len__(self):return len(self.img_path_list)# 实例化这个类
my_data = MyDataset(root_path=r'E:\Python_learning\Deep_learning\dataset\hymenoptera_data\train', label='ants')
writer = SummaryWriter('logs')
for i in range(my_data.__len__()):img, label = my_data[i]  # 依次获取对应的图片# 此处img为PIL Image, 使用transforms中的ToTensor方法转化为tensor格式writer.add_image(tag=label, img_tensor=transforms.ToTensor()(img), global_step=i)
writer.close()
print(f"当前文件夹下{i + 1}张图片已读取完毕,请在Tensorboard中查看")

在这里插入图片描述
在控制台输入tensorboard --logdir='E:\Python_learning\Deep_learning\note\logs'打开tensorboard查看
在这里插入图片描述
在这里插入图片描述

http://www.lryc.cn/news/400243.html

相关文章:

  • 【nginx】nginx的优点
  • K8S ingress 初体验 - ingress-ngnix 的安装与使用
  • qt 获取父控件
  • flask基础配置详情
  • 单相整流-TI视频课笔记
  • 用GPT 4o提高效率
  • 20240711每日消息队列-------------MQ消息的积压的折磨
  • 推荐一个比 Jenkins 使用更简单的项目构建和部署工具
  • java 在pdf中根据关键字位置插入图片(公章、签名等)
  • 施耐德EOCR系列电机保护器全面升级后無端子型
  • 27.数码管的驱动,使用74HC595移位寄存器芯片
  • TCP/IP 原理、实现方式与优缺点
  • 利率债与信用债的区别及其与债券型基金的关系
  • linux下解压命令
  • Vulnhub靶场DC-3-2练习
  • Swift入门笔记
  • 【提交ACM出版 | EIScopus检索稳定 | 高录用】第五届大数据与社会科学国际学术会议(ICBDSS 2024,8月16-18)
  • Postman与世界相连:集成第三方服务的全面指南
  • Perl 语言开发(十四):数据库操作
  • Qt+ESP32+SQLite 智能大棚
  • Android Viewpager2 remove fragmen不生效解决方案
  • 桃园南路上的红绿灯c++
  • 有关去中心化算路大模型的一些误区:低带宽互连导致训练速度太慢;小容量设备无法生成基础规模的模型;去中心化总是会花费更多;虫群永远不够大
  • uni-app iOS上架相关App store App store connect 云打包有次数限制
  • python单测框架之pytest常见用法
  • [终端安全]-8 隐私保护和隐私计算技术
  • MySQL 日志深度解析:从查询执行到性能优化
  • sql server 练习题5
  • ai伪原创生成器app,一键伪原创文章效率高
  • 【ZhangQian AI模型部署】目标检测、SAM、3D目标检测、旋转目标检测、人脸检测、检测分割、关键点、分割、深度估计、车牌识别、车道线识别