当前位置: 首页 > news >正文

Dataset和DataLoader用法

Dataset和DataLoader用法

在d2l中有简洁的加载固定数据的方式,如下

d2l.load_data_fashion_mnist()
# 源码
Signature: d2l.load_data_fashion_mnist(batch_size, resize=None)
Source:   
def load_data_fashion_mnist(batch_size, resize=None):"""Download the Fashion-MNIST dataset and then load it into memory.Defined in :numref:`sec_fashion_mnist`"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))
File:      ~/anaconda3/envs/d2l/lib/python3.9/site-packages/d2l/torch.py
Type:      function

如果我们要自定义需要加载的数据集

数据集:一个图片文件夹,用csv文件来表示训练数据和标签

# 定义Dataset
import pandas as pd
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoaderfrom sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import torchvision.transforms as transformsclass CustomDataset(Dataset):def __init__(self, csv_file, root_dir, transform=None):self.data = pd.read_csv(csv_file) self.root_dir = root_dirself.transform = transformlabel_encoder = LabelEncoder()self.labels = label_encoder.fit_transform(self.data['label'])def __len__(self):return len(self.data)def __getitem__(self, idx):img_name = os.path.join(self.root_dir, self.data.iloc[idx, 0])# 读取图片并做增广image = Image.open(img_name)if self.transform is not None:image = self.transform(image)# 将数字转换成独热编码的张量(记得转换成float)label = F.one_hot(torch.tensor(self.labels[idx]), 		num_classes=self.data['label'].nunique()).float()return image, label# 定义参数和超参数训练
batch_size = 256
lr = num_epoch = 0.9, 10# 加载数据
sample = '/kaggle/input/classify-leaves/sample_submission.csv'
ts_path = "/kaggle/input/classify-leaves/test.csv"
tr_path = "/kaggle/input/classify-leaves/train.csv"
image_path = '/kaggle/input/classify-leaves'dataset = CustomDataset(csv_file = sample, root_dir = image_path, transform=transform_train)
train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size
tr_dataset, te_dataset = torch.utils.data.random_split(dataset, [train_size, valid_size])tr_dataloader = DataLoader(tr_dataset, batch_size, shuffle=True)
ts_dataloader = DataLoader(te_dataset, batch_size, shuffle=False)

总结

需要将__init__,len,__getitem__按照数据集和模型的要求,对应的编写好代码。

http://www.lryc.cn/news/174798.html

相关文章:

  • 【跟小嘉学习区块链】二、Hyperledger Fabric 架构详解
  • springboot下spring方式实现Websocket并设置session时间
  • LeetCode算法二叉树—相同的树
  • 搭建Flink集群、集群HA高可用以及配置历史服务器
  • vscode终端中打不开conda虚拟包管理
  • 【音视频】MP4封装格式
  • 环境-使用vagrant快速创建linux虚拟机
  • 10.1网站编写(Tomcat和servlet基础)
  • 10CQRS
  • DAZ To UMA⭐一.DAZ简单使用教程
  • 面试题 —— Java集合篇(23题)
  • SpringBoot2.7.14整合Swagger3.0的详细步骤及容易踩坑的地方
  • 题解:ABC321D - Set Menu
  • 什么是Progressive Web App(PWA)?它们有哪些特点?
  • MySQL的高级SQL语句
  • 基于人脸5个关键点的人脸对齐(人脸纠正)
  • vue3中两个el-select下拉框选项相互影响
  • 博弈论——反应函数
  • UE5读取json文件
  • Vue中的插槽--组件复用,内容自定义
  • 完全指南:mv命令用法、示例和注意事项 | Linux文件移动与重命名
  • gitee生成公钥和远程仓库与本地仓库使用验证
  • 请求后端接口413
  • HarmonyOS之 开发环境搭建
  • QTC++ day12
  • Vue3中使用Proxy API取代defineProperty API的原因
  • 构建工具Webpack简介
  • Docker部署单点Elasticsearch与Kibana
  • opencv实现仿射变换和透射变换
  • 抖音seo账号矩阵源码系统