当前位置: 首页 > news >正文

PyTorch Geometric基本教程

PyG官方文档


# Install torch geometric
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.2+cu102.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.10.2+cu102.html
!pip install -q torch-geometricimport torch
import networkx as nx
import matplotlib.pyplot as plt

1.内置数据集(以KarateClub为例)

from torch_geometric.datasets import KarateClubdataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
# 图的数量
print(f'Number of graphs: {len(dataset)}')
# 每个节点的特征尺寸
print(f'Number of features: {dataset.num_features}')
# 节点的类别数量
print(f'Number of classes: {dataset.num_classes}')
# 获取具体的图
data = dataset[0]
print(data)
print('==============================================================')# 获取图的属性
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {(2*data.num_edges) / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Contains isolated nodes: {data.has_isolated_nodes()}')
print(f'Contains self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
# 取出的图的数据对象为Data类型,包含以下属性
# 1. edge_index 每条边的两个端点的索引组成的元组
# 2. x 节点特征[节点数量,特征维数]
# 3. y 节点标签(类别),每个节点只分配一个类别
# 4. train_mask 
Data(edge_index=[2, 156], x=[34, 34], y=[34], train_mask=[34])
print(data)
# 获取所有的边
print(data.edge_idx.T)

2.可视化

def visualize(h, color, epoch=None, loss=None, accuracy=None):plt.figure(figsize=(7,7))plt.xticks([])plt.yticks([])if torch.is_tensor(h):h = h.detach().cpu().numpy()plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")if epoch is not None and loss is not None and accuracy['train'] is not None and accuracy['val'] is not None:plt.xlabel((f'Epoch: {epoch}, Loss: {loss.item():.4f} \n'f'Training Accuracy: {accuracy["train"]*100:.2f}% \n'f' Validation Accuracy: {accuracy["val"]*100:.2f}%'),fontsize=16)else:# networkx的draw_networkxnx.draw_networkx(h, pos=nx.spring_layout(h, seed=42), with_labels=False, node_color=color, cmap="Set2")   plt.show()
from torch_geometric.utils import to_networkx
# 将Data类型转换成networkx
G = to_networkx(data, to_undirected=True)
# 将图可视化,节点颜色为节点的类型
visualize(G, color=data.y)

3.搭建GNN(以GCN为例)

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConvclass GCN(torch.nn.Module):def __init__(self):super().__init__()self.conv1 = GCNConv(dataset.num_features, 4)self.conv2 = GCNConv(4, 4)self.conv3 = GCNConv(4, 2)self.classifier = Linear(2, dataset.num_classes)def forward(self, x, edge_index):h = self.conv1(x, edge_index)h = h.tanh()h = self.conv2(h, edge_index)h = h.tanh()h = self.conv3(h, edge_index)h = h.tanh()out = self.classifier(h)return out, hmodel = GCN()
print(model)
# 节点分类
model = GCN()out, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')visualize(h, color=data.y)

4.在KarateClub数据集上训练

import time
model = GCN()# 交叉熵损失,Adam优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())def train(data):optimizer.zero_grad()out, h  = model(data.x, data.edge_index)# 只对train_mask的节点计算lossloss = criterion(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()accuracy = {}# torch.argmax 取置信度最大的一类predicted_classes = torch.argmax(out[data.train_mask], axis=1) # [0.6, 0.2, 0.7, 0.1] -> 2target_classes = data.y[data.train_mask]accuracy['train'] = torch.mean(torch.where(predicted_classes == target_classes, 1, 0).float())predicted_classes = torch.argmax(out, axis=1)target_classes = data.yaccuracy['val'] = torch.mean(torch.where(predicted_classes == target_classes, 1, 0).float())return loss, h, accuracy
for epoch in range(500):loss, h, accuracy = train(data)if epoch % 10 == 0:visualize(h, color=data.y, epoch=epoch, loss=loss, accuracy=accuracy)time.sleep(0.3)
http://www.lryc.cn/news/130976.html

相关文章:

  • MAC 命令行启动tomcat的详细介绍
  • idea2023 springboot2.7.5+mybatisplus3.5.2+jsp 初学单表增删改查
  • 轻松搭建书店小程序
  • Spark MLlib机器学习库(一)决策树和随机森林案例详解
  • CI/CD入门(二)
  • 【BASH】回顾与知识点梳理(三十五)
  • excel逻辑函数篇2
  • 设计模式详解-解释器模式
  • 如何在React项目中动态插入HTML内容
  • 十六、Spring Cloud Sleuth 分布式请求链路追踪
  • ElasticSearch DSL语句(bool查询、算分控制、地理查询、排序、分页、高亮等)
  • 【考研数学】概率论与数理统计 | 第一章——随机事件与概率(2,概率基本公式与事件独立)
  • SpringBoot整合RabbitMQ,笔记整理
  • 搜狗拼音暂用了VSCode及微信小程序开发者工具快捷键Ctrl + Shit + K 搜狗拼音截图快捷键
  • Python包sklearn画ROC曲线和PR曲线
  • snpEff变异注释的一点感想
  • “保姆级”考研下半年备考时间表
  • 具有弱监督学习的精确3D人脸重建:从单幅图像到图像集的Python实现详解
  • 查询投稿会议的好用网址
  • 一元三次方程的解
  • aardio开发语言Excel数据表读取修改保存实例练习
  • webshell绕过
  • Spring Boot 统一功能处理
  • 图像处理常见的两种拉流方式
  • 数据可视化数据调用浅析
  • 恒运资本:CPO概念发力走高,兆龙互联涨超10%,华是科技再创新高
  • 【蓝桥杯】[递归]母牛的故事
  • 使用RDP可视化远程桌面连接Linux系统
  • 数据可视化diff工具jsondiffpatch使用学习
  • pdf 转 word