当前位置: 首页 > news >正文

【GNN 03】PyG

工具包安装: 不要pip安装

https://github.com/pyg-team/pytorch_geometricicon-default.png?t=N7T8https://github.com/pyg-team/pytorch_geometric

 

import torch
import networkx as nx
import matplotlib.pyplot as pltdef visualize_graph(G, color):plt.figure(figsize=(7, 7))plt.xticks([])plt.yticks([])nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False, node_color=color, cmap="Set2")plt.show()def visualize_embedding(h, color, epoch=None, loss=None):plt.figure(figsize=(7, 7))plt.xticks([])plt.yticks([])h = h.detach().cpu().numpy()plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")if epoch is not None and loss is not None:plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)plt.show()

1 dataset

from torch_geometric.datasets import KarateClubdataset = KarateClub()
print(f'Dataset: idataset] :')
print('===================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
data = dataset[0]
print(data)

2 source-target

edge_index = data.edge_index
# print(edge_index.t())

3 Visual presentation using networkx

from torch_geometric.utils import to_networkxG = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

4 GCN model

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
import torch.sparseclass GCN(torch.nn.Module):def __init__(self):super().__init__()torch.manual_seed(1234)self.conv1 = GCNConv(dataset.num_features, 4, cache=False)self.conv2 = GCNConv(4, 4)self.conv3 = GCNConv(4, 2)self.classifier = Linear(2, dataset.num_classes)def forward(self, x, edge_index):h = self.conv1(x, edge_index) # edge_index 邻接矩阵h = h.tanh()h = self.conv2(h, edge_index)h = h.tanh()h = self.conv3(h, edge_index)h = h.tanh()out = self.classifier(h)return out, h

 

5 Two-dimensional vector

model = GCN()
print(model)_, h = model(data.x, data.edge_index)
visualize_embedding(h, color=data.y)

6 Training model(semi-supervised)

import timemodel = GCN()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.def train(data):optimizer.zero_grad()out, h = model(data.x, data.edge_index)  # h是两维向量,主要是为了咱们画个图loss = criterion(out[data.train_mask], data.y[data.train_mask])  # semi-supervisedloss.backward()optimizer.step()return loss, hfor epoch in range(401):loss, h = train(data)if epoch % 10 == 0:visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)time.sleep(0.3)

 

http://www.lryc.cn/news/165313.html

相关文章:

  • 每日刷题-5
  • RNN简介(深入浅出)
  • Leetcode137. 某一个数字出现一次,其余数字出现3次
  • 原子化CSS(Atomic CSS)
  • pandas 筛选数据的 8 个骚操作
  • 【随想】每日两题Day.3(实则一题)
  • 阿里后端开发:抽象建模经典案例【文末送书】
  • HarmonyOS Codelab 优秀样例——溪村小镇(ArkTS)
  • Mybatis---第二篇
  • 6.2.3 【MySQL】InnoDB的B+树索引的注意事项
  • 前端面试话术集锦第 12 篇:高频考点(Vue常考基础知识点)
  • 骨传导耳机危害有哪些?值得入手吗?
  • 网络爬虫-----初识爬虫
  • vue 功能:点击增加一项,点击减少一项
  • 我的编程学习笔记
  • 页面静态化、Freemarker入门
  • PCL (再探)点云配准精度评价指标——均方根误差
  • 【Redis速通】基础知识1 - 虚拟机配置与踩坑
  • 我的创作纪念日---从考研调剂到研一的旅程
  • Python-实现邮件发送:flask框架或django框架可以直接使用
  • 使用亚马逊云科技Amazon SageMaker,为营销活动制作广告素材
  • conda环境安装opencv带cuda版本
  • R语言中的数据结构----矩阵
  • Llama-2 推理和微调的硬件要求总结:RTX 3080 就可以微调最小模型
  • C++多线程的用法(包含线程池小项目)
  • react ant ice3 实现点击一级菜单自动打开它下面最深的第一个子菜单
  • 关于 Qt串口不同电脑出现不同串口号打开失败 的解决方法
  • 可观测性在灰度发布中的应用
  • vscode开发油猴插件环境配置指南
  • 网站不收录没排名降权怎么处理-紧急措施可恢复网站