当前位置: 首页 > news >正文

【PyTorch】数据集

文章目录

  • 1. 创建数据集
    • 1.1. 直接继承Dataset类
    • 1.2. 使用TensorDataset类
  • 2. 数据集的划分
  • 3. 加载数据集
  • 4. 将数据转移到GPU

1. 创建数据集

主要是将数据集读入内存,并用Dataset类封装。

1.1. 直接继承Dataset类

必须要重写__getitem__方法,用于根据索引获得相应样本数据。必要时还可以重写__len__方法,用于返回数据集的大小。

from torch.utils.data import Datasetclass BostonHousingDataset(Dataset):"""定义波士顿房价数据集"""def __init__(self):self.data = np.load('../dataset/boston_housing/boston_housing.npz')def __getitem__(self, index):return self.data['x'][index], self.data['y'][index]def __len__(self):return self.data['x'].shape[0]

1.2. 使用TensorDataset类

将多个张量组合成一个数据集,要保证所有张量的第一个维度相等,保证每批样本数据格式相同。

import torch
from torch.utils.data import TensorDatasetdata = np.load('../dataset/boston_housing/boston_housing.npz')
X = torch.tensor(data['x'])
y = torch.tensor(data['y'])
dataset = TensorDataset(X, y)

2. 数据集的划分

数据集可以划分为训练集、验证集和测试集。

  • 训练集:用于模型拟合的数据样本集合。
  • 验证集:通常被用来调整模型的参数,以找出效果最佳的模型。
  • 测试集:用于训练好的模型性能评估的数据样本集合。
from torch.utils.data import random_splittrain_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

3. 加载数据集

使用DataLoader类将Dataset封装的数据集分成批次并进行迭代,以便于模型训练。DataLoader常用参数如下:

  • dataset
    要加载的数据集。
  • batch_size
    每个数据批次中包含的样本数。默认为1。
  • shuffle
    是否打乱数据集。默认为False。
  • num_workers
    使用几个进程来加载数据。默认为0,即在主进程中加载数据。
  • drop_last
    当数据集样本数不能被batch_size整除时,是否舍弃最后一个不完整的batch。默认为False。
from torch.utils.data import DataLoaderdataloader = DataLoader(dataset, batch_size=16, shuffle=True)

4. 将数据转移到GPU

一般在要运算时才将数据转移到GPU,有以下两种方法:

  1. var.to(device)
  2. var.cuda()
import torchdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for X,y in dataloader:# 将数据转移到GPUX = X.to(device)y = y.to(device)# 也可以X = X.cuda()y = y.cuda()
http://www.lryc.cn/news/254145.html

相关文章:

  • oops-framework框架 之 本地存储(五)
  • 编程常见的问题
  • 针对Arrays.asList的坑,可以有哪些处理措施
  • SE考研真题总结(一)
  • Xshell远程登录AWS EC2 Linux实例
  • Elasticsearch:对时间序列数据流进行降采样(downsampling)
  • python自动化测试框架:unittest测试用例编写及执行
  • ctfhub技能树_web_web前置技能_HTTP
  • mysql8报sql_mode=only_full_group_by(存储过程一直报)
  • Vue2中v-html引发的安全问题
  • java内部类详解
  • Python 潮流周刊#29:Rust 会比 Python 慢?!
  • 吴恩达《机器学习》11-1-11-2:首先要做什么、误差分析
  • Pandas在Excel同一个sheet里插入多个Dataframe和行
  • 查看mysql 或SQL server 的连接数,mysql超时、最大连接数配置
  • C++学习之路(七)C++ 实现简单的Qt界面(消息弹框、按钮点击事件监听)- 示例代码拆分讲解
  • python实现一个计算器
  • C++ 共享内存ShellCode跨进程传输
  • 如何快速移植(从STM32F103到STM32F407)
  • python高级练习题库实验1(B)部分
  • Qt Rsa 加解密方法使用(pkcs1, pkcs8, 以及文件存储和内存存储密钥)
  • 区分物理端口与软件端口概念:以交换机端口和Linux系统中的端口为例
  • 力扣226:翻转二叉树
  • 亚马逊鲲鹏系统智能自动注册与AI角色养号,探索数字化新境界
  • AOP操作日志记录
  • Linux C语言 42-进程间通信IPC之网络通信(套接字)
  • 微服务知识大杂烩
  • 记录一次vscode markdown的图片路径相关插件学习配置过程
  • 设计原则 | 依赖转置原则
  • 前端开发实用技巧与经验分享