当前位置: 首页 > news >正文

图神经网络 pytorch GCN torch_geometric KarateClub 数据集

图神经网络

安装Pyg

首先安装torch_geometric需要安装pytorch然后查看一下自己电脑Pytorch的版本

import torch
print(torch.__version__)
#1.12.0+cu113

然后进入官网文档网站

链接: https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html
安装自己的版本选择安装命令,我python用的稳定的3.8版本。如果安装失败可以考虑降低python的版本

在这里插入图片描述
因为我之前安装过所以显示如下
在这里插入图片描述

图信号数据集初入门

本次入门选用Karateclub数据集
在这里插入图片描述
这个数据集讲诉的是一个空手道俱乐部之间人和人的关系,每个节点代表一个人说俱乐部的两个教练吵架了,要每一个节点所代表的人进行站队通过图信号预测。
首先读取数据集

from torch_geometric.datasets import KarateClubdataset = KarateClub()
print(f'Number of graphs:{len(dataset)}')
print(f'Number of features:{dataset.num_features}')
print(f'Number of classes:{dataset.num_classes}')
#Number of graphs:1
#Number of features:34
#Number of classes:4

可以看到只有一张图每个节点有34个特征,每个特征代表的应该是每一个会员的信息,分成四类我们可以暂时理解成跟了教练A的,跟了教练B的,换了一个新教练的,和退出俱乐部的这四类。

然后我们将这个图打出来进行观察可以看到节点是分成了四类

import matplotlib.pyplot as plt
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx
import networkx as nxdataset = KarateClub()
print(f'Dataset{dataset}')
print(f'Number of graphs:{len(dataset)}')
print(f'Number of features:{dataset.num_features}')
print(f'Number of classes:{dataset.num_classes}')def visualize_graph(G,color):plt.figure(figsize=(7,7))plt.xticks([])plt.yticks([])nx.draw_networkx(G,pos = nx.spring_layout(G,seed = 42),with_labels=False,node_color=color,cmap="Set2")plt.savefig("net.jpg")plt.show()data = dataset[0]
print(data)
G = to_networkx(data,to_undirected=True)
visualize_graph(G,color=data.y)

在这里插入图片描述
然后我们观察一个图的数据可以观察到一共有34个节点每个节点有34个数据一共有156条边

from torch_geometric.datasets import KarateClubdataset = KarateClub()
data = dataset[0]
print(data)
#Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])

训练代码

接下来使用pyg进行训练

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub
import matplotlib.pyplot as pltdataset = KarateClub()
data = dataset[0]def visualize_embedding(h,color,epoch=None,loss=None):global iplt.figure(figsize=(7,7))plt.xticks([])plt.yticks([])h = h.detach().cpu().numpy()plt.scatter(h[:,0],h[:,1],s=140,c=color,cmap="Set2")if epoch is not None and loss is not None:plt.xlabel(f"Epoch:{epoch},Loss: {loss.item():.4f}",fontsize = 16,)plt.show()class GCN(torch.nn.Module):def __init__(self):super().__init__()torch.manual_seed(1234)self.conv1 = GCNConv(dataset.num_features,4)self.conv2 = GCNConv(4,4)self.conv3 = GCNConv(4,2)self.classifier = Linear(2,dataset.num_classes)def forward(self,x,edge_index):h = self.conv1(x,edge_index)h = h.tanh()h = self.conv2(h,edge_index)h = h.tanh()h = self.conv3(h,edge_index)h = h.tanh()out = self.classifier(h)# return F.softmax(out,dim=1),hreturn out,hmodel = GCN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)loss_list = []def train(data):optimizer.zero_grad()out, h = model(data.x, data.edge_index)loss = criterion(out[data.train_mask], data.y[data.train_mask])loss.backward()loss_list.append(loss.item())optimizer.step()return loss, hfor epoch in range(401):loss, h = train(data)if epoch % 10 == 1:visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)plt.plot(loss_list)
plt.show()

损失曲线如下
在这里插入图片描述

训练集可视化动图如下
![在这里插入图片描述](https://img-blog.csdnimg.cn/e1cff7bf8ba0498cb8abd6f61a2f83a6.gif在这里插入图片描述

http://www.lryc.cn/news/30587.html

相关文章:

  • 【博学谷学习记录】超强总结,用心分享丨人工智能 自然语言处理 文本特征处理小结
  • 2023年中职网络安全竞赛解析——隐藏信息探索
  • 实用操作--迁移到Spring Boot 3 和 Spring 6 需要关注的JAVA新特性
  • 等保检测风险处理方案
  • java 包装类 万字详解(通俗易懂)
  • 为什么我复制的中文url粘贴出来会是乱码的? 浏览器url编码和解码
  • 移动端适配
  • 【FPGA】Verilog:时序电路应用 | 序列发生器 | 序列检测器
  • Biomod2 (下):物种分布模型建模
  • Linux性能学习(2.2):内存_进程线程内存分配机制探究
  • BPMN2.0规范及流程引擎选型方案
  • VMware虚拟机安装Linux教程
  • 多人协作|RecyclerView列表模块新架构设计
  • SpringBoot (六) 整合配置文件 @Value、ConfigurationProperties
  • docker 入门篇
  • MapReduce的shuffle过程详解
  • 【软件使用】MarkText下载安装与汉化设置 (markdown快捷键收藏)
  • LeetCode笔记:Biweekly Contest 99
  • 初探富文本之CRDT协同实例
  • 团队死气沉沉?10种玩法激活你的项目团队拥有超强凝聚力
  • Spring三级缓存核心思想
  • 深度学习算法训练和部署流程介绍--让初学者一篇文章彻底理解算法训练和部署流程
  • 计算机网络整理
  • 闲人闲谈PS之三十八——混合制生产下WBS-BOM价格发布增强
  • Java 根类 Object
  • 04_Apache Pulsar的可视化监控管理、Apache Pulsar的可视化监控部署
  • 【算法】期末复盘,酒店住宿问题——勿向思想僵化前进
  • Java中的Comparator 与 Comparable详解
  • 计算机科学导论笔记(二)
  • GEC6818开发板JPG图像显示,科大讯飞离线语音识别包Linux_aitalk_exp1227_1398d7c6运行demo程序,开发板实现录音