当前位置: 首页 > news >正文

PyG-GAT-Cora(在Cora数据集上应用GAT做节点分类)

文章目录

  • model.py
  • main.py
  • 参数设置
  • 运行图

model.py

import torch.nn as nn
from torch_geometric.nn import GATConv
import torch.nn.functional as F
class gat_cls(nn.Module):def __init__(self,in_dim,hid_dim,out_dim,dropout_size=0.5):super(gat_cls,self).__init__()self.conv1 = GATConv(in_dim,hid_dim)self.conv2 = GATConv(hid_dim,hid_dim)self.fc = nn.Linear(hid_dim,out_dim)self.relu  = nn.ReLU()self.dropout_size = dropout_sizedef forward(self,x,edge_index):x = self.conv1(x,edge_index)x = F.dropout(x,p=self.dropout_size,training=self.training)x = self.relu(x)x = self.conv2(x,edge_index)x = self.relu(x)x = self.fc(x)return x

main.py

import torch
import torch.nn as nn
from torch_geometric.datasets import Planetoid
from model import gat_cls
import torch.optim as optim
dataset = Planetoid(root='./data/Cora', name='Cora')
print(dataset[0])
cora_data = dataset[0]epochs = 50
lr = 1e-3
weight_decay = 5e-3
momentum = 0.5
hidden_dim = 128
output_dim = 7net = gat_cls(cora_data.x.shape[1],hidden_dim,output_dim)
optimizer = optim.AdamW(net.parameters(),lr=lr,weight_decay=weight_decay)
#optimizer = optim.SGD(net.parameters(),lr = lr,momentum=momentum)
criterion = nn.CrossEntropyLoss()
print("****************Begin Training****************")
net.train()
for epoch in range(epochs):out = net(cora_data.x,cora_data.edge_index)optimizer.zero_grad()loss_train = criterion(out[cora_data.train_mask],cora_data.y[cora_data.train_mask])loss_val   = criterion(out[cora_data.val_mask],cora_data.y[cora_data.val_mask])loss_train.backward()print('epoch',epoch+1,'loss-train {:.2f}'.format(loss_train),'loss-val {:.2f}'.format(loss_val))optimizer.step()net.eval()
out = net(cora_data.x,cora_data.edge_index)
loss_test = criterion(out[cora_data.test_mask],cora_data.y[cora_data.test_mask])
_,pred = torch.max(out,dim=1)
pred_label = pred[cora_data.test_mask]
true_label = cora_data.y[cora_data.test_mask]
acc = sum(pred_label==true_label)/len(pred_label)
print("****************Begin Testing****************")
print('loss-test {:.2f}'.format(loss_test),'acc {:.2f}'.format(acc))

参数设置

epochs = 50
lr = 1e-3
weight_decay = 5e-3
momentum = 0.5
hidden_dim = 128
output_dim = 7

运行图

在这里插入图片描述

http://www.lryc.cn/news/170471.html

相关文章:

  • java专项练习(验证码)
  • MS1861 视频处理与显示控制器 HDMI转MIPI LVDS转MIPI带旋转功能 图像带缩放,旋转,锐化
  • 广州华锐互动:利用VR复原文化遗址,沉浸式体验历史文物古迹的魅力
  • 微信小程序——事件监听
  • View绘制流程的源码所得
  • 企业级数据仓库-理论知识
  • 解决flutter不识别yaml里面配置的git项目
  • rust结构体
  • Python - 小玩意 - 键盘记录器
  • msvcp71.dll丢失的解决方法分享,全面分析msvcp71.dll丢失原因
  • stm32----ADC模数转换
  • Unity SteamVR 开发教程:用摇杆/触摸板控制人物持续移动(2.x 以上版本)
  • 04条件构造器和常用接口
  • 什么是HTTP状态码?常见的HTTP状态码有哪些?
  • vue3的双向绑定原理分析
  • MySQL数据库时间计算的用法
  • 应用在儿童平板防蓝光中的LED防蓝光灯珠
  • BERT 快速理解——思路简单描述
  • 二叉树实现的相关函数
  • Redis面试题(二)
  • STP介绍
  • numpy 和 tensorflow 中的各种乘法(点乘和矩阵乘)
  • (图论) 1020. 飞地的数量 ——【Leetcode每日一题】
  • c++ 重载、重写、覆盖
  • Python异步编程高并发执行爬虫采集,用回调函数解析响应
  • SpriteKit与Swift配合:打造您的第一个简易RPG游戏的步骤指南
  • 服务网格的面临挑战:探讨服务网格实施中可能遇到的问题和解决方案
  • leetcode61 旋转链表
  • 【学习笔记】各类基于决策单调性的dp优化
  • 【C++】构造函数初始化列表 ⑤ ( 匿名对象 生命周期 | 构造函数 中 不能调用 构造函数 )