当前位置: 首页 > news >正文

【PyTorch】(二)加载数据集

文章目录

  • 1. 创建数据集
    • 1.1. 直接继承Dataset类
    • 1.2. 使用TensorDataset类
  • 2. 加载数据集
  • 3. 将数据转移到GPU

1. 创建数据集

主要是将数据集读入内存,并用Dataset类封装。

1.1. 直接继承Dataset类

必须要重写__getitem__方法,用于根据索引获得相应样本数据。必要时还可以重写__len__方法,用于返回数据集的大小。

from torch.utils.data import Datasetclass BostonHousingDataset(Dataset):"""定义波士顿房价数据集"""def __init__(self):self.data = np.load('../dataset/boston_housing/boston_housing.npz')def __getitem__(self, index):return self.data['x'][index], self.data['y'][index]def __len__(self):return self.data['x'].shape[0]

1.2. 使用TensorDataset类

将多个张量组合成一个数据集,要保证所有张量的第一个维度相等,保证每批样本数据格式相同。

import torch
from torch.utils.data import TensorDatasetdata = np.load('../dataset/boston_housing/boston_housing.npz')
X = torch.tensor(data['x'])
y = torch.tensor(data['y'])
dataset = TensorDataset(X, y)

2. 加载数据集

使用DataLoader类将Dataset封装的数据集分成批次并进行迭代,以便于模型训练。DataLoader常用参数如下:

  • dataset
    要加载的数据集。
  • batch_size
    每个数据批次中包含的样本数。默认为1。
  • shuffle
    是否打乱数据集。默认为False。
  • num_workers
    使用几个进程来加载数据。默认为0,即在主进程中加载数据。
  • drop_last
    当数据集样本数不能被batch_size整除时,是否舍弃最后一个不完整的batch。默认为False。
from torch.utils.data import DataLoaderdataloader = DataLoader(dataset, batch_size=16, shuffle=True)

3. 将数据转移到GPU

一般在要运算时才将数据转移到GPU,有以下两种方法:

  1. var.to(device)
  2. var.cuda()
import torchdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for X,y in dataloader:# 将数据转移到GPUX = X.to(device)y = y.to(device)# 也可以X = X.cuda()y = y.cuda()
http://www.lryc.cn/news/249328.html

相关文章:

  • 如何提高3D建模技能?
  • 【前端开发】Next.js与Nest.js之间的差异2023
  • 【CAN通信】CanIf模块详细介绍
  • PS最新磨皮软件Portraiture4.1.2
  • 旋转框(obb)目标检测计算iou的方法
  • render函数举例
  • 微信小程序文件预览和下载-文件系统
  • 图解Redis适用场景
  • 掌握Python BentoML:构建、部署和管理机器学习模型
  • 西南科技大学模拟电子技术实验二(二极管特性测试及其应用电路)预习报告
  • 熟悉SVN基本操作-(SVN相关介绍使用以及冲突解决)
  • 代码随想录二刷 |字符串 |反转字符串II
  • 哪吒汽车拔头筹,造车新势力首家泰国工厂投产
  • Redis String类型
  • lxd提权
  • Ubuntu+Tesla V100环境配置
  • leetcode:用栈实现队列(先进先出)
  • <JavaEE> 什么是进程控制块(PCB Process Control Block)?
  • 简历上的工作经历怎么写
  • 数值分析总结
  • osg demo汇总
  • Leetcode.1590 使数组和能被 P 整除
  • uniappios请求打开麦克风 uniapp发起请求
  • Java 注解在 Android 中的使用场景
  • 【开源】基于Vue和SpringBoot的数字化社区网格管理系统
  • Go语言简要介绍
  • STM32H7 RTC及PC13问题
  • AntDB“超融合+流式实时数仓”——颠覆50年未变的数据库内核
  • TZOJ 1376 母牛的故事(递推和递归)
  • 五种多目标优化算法(MOPSO、MOAHA、NSGA2、NSGA3、MOGWO)求解微电网多目标优化调度(MATLAB)