当前位置: 首页 > news >正文

【PYG】使用datalist定义数据集,创建一个包含多个Data对象的列表并使用DataLoader来加载这些数据

为了使用你提到的封装方式来创建一个包含多个 Data 对象的列表并使用 DataLoader 来加载这些数据,我们可以按照以下步骤进行:

  1. 创建数据:生成节点特征矩阵、边索引矩阵和标签。
  2. 封装数据:使用 Data 对象将这些数据封装起来。
  3. 使用 DataLoader:确保批次数据的形状符合期望。

具体步骤

1. 创建数据

首先,我们创建节点特征矩阵、边索引矩阵和标签数据。

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader  # 更新导入路径# 参数设置
num_samples = 100  # 样本数
num_nodes = 10  # 每个图中的节点数
num_node_features = 8  # 每个节点的特征数# 生成数据
features = [torch.randn((num_nodes, num_node_features)) for _ in range(num_samples)]
labels = [torch.randn((num_nodes, 1)) for _ in range(num_samples)]
adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
for i in range(num_nodes):adj_matrix[i, (i + 1) % num_nodes] = 1adj_matrix[(i + 1) % num_nodes, i] = 1
print(adj_matrix)
2. 封装数据

使用 Data 对象将每个样本的数据封装起来。

data_list = [Data(x=features[i], adj=adj_matrix, y=labels[i]) for i in range(num_samples)]
3. 使用 DataLoader
# 创建 DataLoader
loader = DenseDataLoader(data_list, batch_size=32, shuffle=True)# 从 DenseDataLoader 中获取一个批次的数据并查看其形状
for data in loader:print("Batch node features shape:", data.x.shape)  # 期望输出形状为 (32, 10, 8)print("Batch adjacency matrix shape:", data.adj.shape)  # 期望输出形状为 (32, 10, 10)print("Batch labels shape:", data.y.shape)  # 期望输出形状为 (32, 10, 1)break  # 仅查看第一个批次的形状

总结

  1. 生成数据:我们生成了包含节点特征、边索引和标签的样本数据。
  2. 封装数据:我们使用 Data 对象将每个样本的数据封装起来。

完整代码

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader  # 更新导入路径# 参数设置
num_samples = 100  # 样本数
num_nodes = 10  # 每个图中的节点数
num_node_features = 8  # 每个节点的特征数# 生成数据
features = [torch.randn((num_nodes, num_node_features)) for _ in range(num_samples)]
labels = [torch.randn((num_nodes, 1)) for _ in range(num_samples)]
adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
for i in range(num_nodes):adj_matrix[i, (i + 1) % num_nodes] = 1adj_matrix[(i + 1) % num_nodes, i] = 1
print(adj_matrix)data_list = [Data(x=features[i], adj=adj_matrix, y=labels[i]) for i in range(num_samples)]# 创建 DataLoader
loader = DenseDataLoader(data_list, batch_size=32, shuffle=True)# 从 DenseDataLoader 中获取一个批次的数据并查看其形状
for data in loader:print("Batch node features shape:", data.x.shape)  # 期望输出形状为 (32, 10, 8)print("Batch adjacency matrix shape:", data.adj.shape)  # 期望输出形状为 (32, 10, 10)print("Batch labels shape:", data.y.shape)  # 期望输出形状为 (32, 10, 1)break  # 仅查看第一个批次的形状
http://www.lryc.cn/news/387703.html

相关文章:

  • 【设计模式】【创建型5-2】【工厂方法模式】
  • python API自动化(Pytest+Excel+Allure完整框架集成+yaml入门+大量响应报文处理及加解密、签名处理)
  • 【Postman学习】
  • 【Linux】IO多路复用——select,poll,epoll的概念和使用,三种模型的特点和优缺点,epoll的工作模式
  • IBCS 虚拟专线——让企业用于独立IP
  • 驾驭巨龙:Perl中大型文本文件的处理艺术
  • Kafka~特殊技术细节设计:分区机制、重平衡机制、Leader选举机制、高水位HW机制
  • springcloud-config 客户端启用服务发现client的情况下使用metadata中的username和password
  • 云计算 | 期末梳理(中)
  • pytest测试框架pytest-order插件自定义用例执行顺序
  • 吴恩达机器学习 第三课 week2 推荐算法(上)
  • MySQL CASE 表达式
  • Unity3D 游戏数据本地化存储与管理详解
  • 昇思25天学习打卡营第1天|初学教程
  • ctfshow-web入门-命令执行(web59-web65)
  • Websocket在Java中的实践——最小可行案例
  • python请求报错::requests.exceptions.ProxyError: HTTPSConnectionPool
  • 【Unity】Excel配置工具
  • 001 线性查找(lua)
  • 数据结构之链表
  • 【小工具】 Unity相机宽度适配
  • centos误删yum和python
  • WP黑格导航主题BlackCandy
  • elasticsearch底层核心组件
  • EasyExcel数据导入
  • 20240630 每日AI必读资讯
  • 第十一章 Qt的模型视图
  • 力扣 单词规律
  • 10款好用不火的PC软件,真的超好用!
  • Windows怎么实现虚拟IP