当前位置: 首页 > news >正文

使用R-GCN处理异质图ACM的demo

加载和处理数据集

import torch
from torch_geometric.datasets import HGBDataset
from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关系的异质图数据集
dataset = HGBDataset(root='/tmp/HGB', name='ACM')
data = dataset[0]  # 使用数据集中的第一个图# 利用RandomLinkSplit将数据集随机划分为训练、验证和测试集
transform = RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True, split_labels=True,neg_sampling_ratio=1.0,edge_types=[('paper', 'has-subject', 'subject')]  # 指定需要划分的边类型
)
train_data, val_data, test_data = transform(data)

定义 R-GCN 模型

from torch_geometric.nn import RGCNConv
import torch.nn.functional as Fclass RGCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, num_relations):super().__init__()# 定义两层RGCN层,每层处理图中的不同关系类型self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations=num_relations)self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations=num_relations)def forward(self, x, edge_index, edge_type):# 使用ReLU激活函数处理第一层的输出x = F.relu(self.conv1(x, edge_index, edge_type))# 第二层RGCN处理并输出节点特征x = self.conv2(x, edge_index, edge_type)return xnum_relations = len(torch.unique(data.edge_type))  # 计算图中不同关系类型的数量
model = RGCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=32, num_relations=num_relations)

训练和测试函数

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()# 前向传播,计算模型在训练数据上的输出z = model(train_data.x, train_data.edge_index, train_data.edge_type)# 计算二进制交叉熵损失loss = criterion(z[train_data.edge_label_index], train_data.edge_label.float())loss.backward()optimizer.step()return loss.item()def test(data):model.eval()with torch.no_grad():z = model(data.x, data.edge_index, data.edge_type)loss = criterion(z[data.edge_label_index], data.edge_label.float())# 根据sigmoid阈值判断预测为正类或负类pred = z.sigmoid() > 0.5# 计算准确率correct = pred == data.edge_label.bool()acc = int(correct.sum()) / int(correct.size(0))return loss.item(), accfor epoch in range(100):loss = train()val_loss, val_acc = test(val_data)print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')test_loss, test_acc = test(test_data)
print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

完整代码

import torch
from torch_geometric.datasets import HGBDataset
from torch_geometric.transforms import RandomLinkSplit# 加载ACM数据集,这是一个包含论文(paper)、主题(subject)以及它们之间关系的异质图数据集
dataset = HGBDataset(root='/tmp/HGB', name='ACM')
data = dataset[0]  # 使用数据集中的第一个图# 利用RandomLinkSplit将数据集随机划分为训练、验证和测试集
transform = RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True, split_labels=True,neg_sampling_ratio=1.0,edge_types=[('paper', 'has-subject', 'subject')]  # 指定需要划分的边类型
)
train_data, val_data, test_data = transform(data)from torch_geometric.nn import RGCNConv
import torch.nn.functional as Fclass RGCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, num_relations):super().__init__()# 定义两层RGCN层,每层处理图中的不同关系类型self.conv1 = RGCNConv(in_channels, hidden_channels, num_relations=num_relations)self.conv2 = RGCNConv(hidden_channels, out_channels, num_relations=num_relations)def forward(self, x, edge_index, edge_type):# 使用ReLU激活函数处理第一层的输出x = F.relu(self.conv1(x, edge_index, edge_type))# 第二层RGCN处理并输出节点特征x = self.conv2(x, edge_index, edge_type)return xnum_relations = len(torch.unique(data.edge_type))  # 计算图中不同关系类型的数量
model = RGCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=32, num_relations=num_relations)optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()def train():model.train()optimizer.zero_grad()# 前向传播,计算模型在训练数据上的输出z = model(train_data.x, train_data.edge_index, train_data.edge_type)# 计算二进制交叉熵损失loss = criterion(z[train_data.edge_label_index], train_data.edge_label.float())loss.backward()optimizer.step()return loss.item()def test(data):model.eval()with torch.no_grad():z = model(data.x, data.edge_index, data.edge_type)loss = criterion(z[data.edge_label_index], data.edge_label.float())# 根据sigmoid阈值判断预测为正类或负类pred = z.sigmoid() > 0.5# 计算准确率correct = pred == data.edge_label.bool()acc = int(correct.sum()) / int(correct.size(0))return loss.item(), accfor epoch in range(100):loss = train()val_loss, val_acc = test(val_data)print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')test_loss, test_acc = test(test_data)
print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
http://www.lryc.cn/news/463756.html

相关文章:

  • 征程 6E DISPLAY 功能介绍及上手实践
  • 安卓窗口wms/input小知识NO_INPUT_CHANNEL剖析
  • 【2024最新版】Win10下 Java环境变量配置----适合入门小白
  • Servlet 生命周期详解及案例演示(SpringMVC底层实现)
  • 2024 kali系统2024版本,可视化界面汉化教程(需要命令行更改),英文版切换为中文版,基于Debian创建的kali虚拟机
  • 深入理解 CMake 中的 INCLUDE_DIRECTORIES 与 target_include_directories
  • 【不知道原因的问题】java.lang.AbstractMethodError
  • 分布式篇(分布式事务)(持续更新迭代)
  • [Linux] 逐层深入理解文件系统 (2)—— 文件重定向
  • html+css+js实现Badge 标记
  • 纯css 轮播图片,鼠标移入暂停 移除继续
  • iOS GCD的基本使用
  • 如何设计开发RTSP直播播放器?
  • Java基础系列-一文搞懂自定义排序
  • 扫普通链接二维码打开小程序
  • 计算机储存与分区
  • OpenCV之换脸技术:一场面部识别的奇妙之旅
  • Linux学习笔记9 文件系统的基础
  • Android OpenGL粒子特效
  • 5 -《本地部署开源大模型》在Ubuntu 22.04系统下ChatGLM3-6B高效微调实战
  • dpkg:错误:另外一个进程已经为dpkg前端锁加锁
  • 基于SSM服装定制系统的设计
  • RK3588开发笔记-usb3.0 xhci-hcd控制器挂死问题解决
  • 深入解析TCP/IP协议:网络通信的基石
  • 基于微信小程序的汽车预约维修系统(lw+演示+源码+运行)
  • wifi、热点密码破解 - python
  • bean的实例化2024年10月17日
  • 告别ELK,APO提供基于ClickHouse开箱即用的高效日志方案——APO 0.6.0发布
  • Excel使用技巧:定位Ctrl+G +公式+原位填充 Ctrl+Enter快速填充数据(处理合并单元格)
  • JAVA学习-练习试用Java实现“成绩归类”