当前位置: 首页 > news >正文

PyTorch 之 Dataset 类入门学习

PyTorch 之 Dataset 类入门学习

Dataset 类简介

  • PyTorch 中的 Dataset 类是一个抽象类,用来表示数据集。通过继承 Dataset 类可以进行自定义数据集的格式、大小和其它属性,供后续使用;
    在这里插入图片描述

  • 可以看到官方封装好的数据集也是直接或间接的继承自 Dataset
    在这里插入图片描述

自定义数据集逻辑

  • 继承 Dataset 类;
  • 重写 init():构造函数,可自定义数据读取方法以及进行数据预处理;
  • 重写 len():返回数据集大小;
  • 重写 getitem_():索引数据集中的某一个数据

代码实现

import torch
from torch.utils.data import Dataset# 自定义数据集继承 pytorch 内置的 Dataset 类class GreenDataset(Dataset):"""重写构造函数Args:data_tensor 数据或数据集合target_tensor 数据标签或数据标签集合"""def __init__(self, data_tensor, target_tensor):self.data_tensor = data_tensorself.target_tensor = target_tensor# 重写 len 方法: return 数据集大小def __len__(self):return self.data_tensor.size(0)# 重写 getitem 方法:基于索引,return 对应的数据及其标签,组合成 1 个元组返回def __getitem__(self, index):return self.data_tensor[index], self.target_tensor[index]def test_data_set():"""自定义数据集测试"""# 生成数据集和标签集 (数据元素长度=标签元素长度)# 10 行 3 列数据,可以理解为 10 个元素,每个元素是一维的 3个元素列表data_tensor = torch.randn(10, 3)# 对应方法 torch.randint(low, high, size)标签是 0 或 1 的 10 个元素# low ( int , optional ) – 要从分布中提取的最小整数。默认值:0# high ( int ) – 高于要从分布中提取的最高整数# size ( tuple ) – 定义输出张量形状的元组# 以下示例中 low 取默认值 0target_tensor = torch.randint(2, (10,))# 将数据封装成自定义数据集的 Datasetmy_dataset = GreenDataset(data_tensor, target_tensor)# 调用方法:查看数据集大小print('dataset size info:', len(my_dataset))# 根据索引获取数据print('tensor_data[0]: ', my_dataset[0])# 打印数据集for i, my_dataset in enumerate(my_dataset):print('索引值:%s 数据:%s' % (i, my_dataset))if __name__ == '__main__':test_data_set()

重点函数

  • torch.randn()
    在这里插入图片描述

  • torch.randint()

执行结果

在这里插入图片描述

http://www.lryc.cn/news/238964.html

相关文章:

  • Java update scheduler
  • 常见树种(贵州省):006栎类
  • 拓扑排序-
  • Oracle数据库如何定位trace file位置
  • 电脑盘符错乱,C盘变成D盘怎么办?
  • Android WMS——客户端输入事件处理(十九)
  • Python基础学习__测试报告
  • bclinux aarch64 ceph 14.2.10 云主机 4节点 fio
  • 智能座舱架构与芯片- (14) 测试篇 上
  • 【Django-DRF用法】多年积累md笔记,第3篇:Django-DRF的序列化和反序列化详解
  • Redis主从复制,哨兵和Cluster集群
  • Linux嵌入式I2C协议笔记
  • 科技的成就(五十三)
  • Ubuntu22.04 编译 AOSP
  • 【计算机网络】多路复用的三种方案
  • 供应链和物流的自动化新时代
  • Python与ArcGIS系列(九)自定义python地理处理工具
  • Nginx部署前端项目
  • 根据文件类型进行下载, 文档/图片
  • 赋范线性空间3
  • XSLVGL2.0 User Manual 缩略图生成器(v2.0)
  • 练习八-利用有限状态机进行时序逻辑的设计
  • WebAssembly照亮了 Web端软件的未来
  • PDF文件无密码,如何解密?
  • 搜维尔科技:Movella Xsens MVN LINK 实际应用,一镜到底!
  • wsl安装ubuntu的问题点、处理及连接
  • Flutter在web项目中使用iframe
  • 阿里云高校计划学生和教师完成认证领取优惠权益
  • 劲松HPV防治诊疗中心提醒:做完HPV检查后,需留意这些事项!
  • InfoNCE Loss公式及源码理解