当前位置: 首页 > news >正文

【Pytorch】深度学习之数据读取

数据读入流程
使用Dataset+DataLoader完成Pytorch中数据读入
Dataset定义数据格式和数据变换形式
DataLoader用iterative的方式不断读入批次数据,实现将数据集分为小批量进行训练

使用PyTorch自带数据集
使用Dataset完成数据格式和数据变换的定义

import torch
from torchvision import datasets
train_data = datasets.ImageFolder(train_path, transform=data_transform)
val_data = datasets.ImageFolder(val_path, transform=data_transform)

参数说明:
transform实现对图像数据的变换处理

使用DataLoader完成按批次读取数据

from torch.utils.data import DataLoadertrain_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, num_workers=4, shuffle=False)

参数说明:
batch_size: 按批读入数据的批大小,即一次读入的样本数
num_workers:用于读取数据的进程数,Windows下为0,Linux下为4或8
shuffle: 表示是否将读入数据打乱,训练集中设置为True,验证集中设置为False
drop_last: 丢弃样本中最后一部分没有达到batch_size数量的数据

数据展示

import matplotlib.pyplot as plt
images, labels = next(iter(val_loader))
print(images.shape)
# 使用transpose()函数改变原始图像的表示形式,从(H,W,C)的表示转换为(C,H,W)的表示
plt.imshow(images[0].transpose(1,2,0)) 
plt.show()

自定义数据集方式

  1. 自定义Dataset类继承Dataset
  2. 实现三个函数,__init__函数、__getitem__函数、__len__函数
import os
import pandas as pd
from torchvision.io import read_imageclass MyDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):"""Args:annotations_file (string): Path to the csv file with annotations.img_dir (string): Directory with all the images.transform (callable, optional): Optional transform to be applied on a sample.target_transform (callable, optional): Optional transform to be applied on the target."""self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):"""Args:idx (int): Index"""# 使用path.join()函数构建图像路径,img_labels.iloc[行,列]用于通过行列索引访问DataFrame中的元素img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label
http://www.lryc.cn/news/188893.html

相关文章:

  • Maven教程
  • 一篇带你看懂异步:promise、async await
  • RocketMQ快速实战以及集群架构详解
  • 京东运营数据分析:2023年8月京东饮料行业品牌销售排行榜
  • ES6之函数的扩展二
  • Ubuntu-Ports更新源 ARM64更新源
  • 渗透测试怎么入门?(超详细解读)
  • MS31804四通道低边驱动器可pin对pin兼容DRV8804
  • Fastadmin 子级菜单展开合并,分类父级归纳
  • Idea创建springboot工程的时候,发现pom文件没有带<parent>标签
  • element树形控件编辑节点组装节点
  • 【算法-动态规划】斐波那契第 n 项
  • Linux系统运行级别详解,切换、配置和常见服务
  • 企业需要ERP系统的八大理由,最后一个尤其重要
  • Java-Atomic原子操作类详解及源码分析,Java原子操作类进阶,LongAdder源码分析
  • 算法通过村第十二关-字符串|黄金笔记|冲刺难题
  • 3ds Max渲染太慢?创意云“一键云渲染”提升3ds Max渲染体验
  • 记录一次公益SRC的常见的cookie注入漏洞(适合初学者)
  • [ACTF2020 新生赛]Exec1
  • DeepFace【部署 03】轻量级人脸识别和面部属性分析框架deepface在Linux环境下服务部署(conda虚拟环境+docker)
  • vuex的求和案例和mapstate,mapmutations,mapgetters
  • Docker 网络访问原理解密
  • 统信UOS离线安装nginx
  • 机器学习基础-手写数字识别
  • idea 插件推荐(持续更新)
  • 实现Promise所有核心功能和方法
  • 学习总结1
  • 使用 Apache Camel 和 Quarkus 的微服务(二)
  • pid-limit参数实验
  • jvm--执行引擎