
Graph U-Net Code [Graph Classification]

1. main.py


# GNet is the model used here
net = GNet(G_data.feat_dim, G_data.num_class, args)  # feature dimension, number of classes, args
trainer = Trainer(args, net, G_data)  # build the trainer from args, model and graph data
# start training
trainer.train()
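The args object passed around above carries the hyperparameters referenced throughout the rest of the code (act_n, l_dim, ks, ...). As a rough sketch only, not the original repo's main.py, they could be assembled with argparse like this; the attribute names come from the snippets below, while the default values are illustrative assumptions:

import argparse

# Sketch only: attribute names match the code in this post, defaults are assumptions.
parser = argparse.ArgumentParser()
parser.add_argument('--act_n', default='ELU')               # GNN activation name, resolved via getattr(nn, ...)
parser.add_argument('--act_c', default='ELU')               # classifier activation name
parser.add_argument('--l_dim', type=int, default=48)        # hidden dimension of each GCN layer
parser.add_argument('--l_num', type=int, default=3)         # number of GCN layers
parser.add_argument('--h_dim', type=int, default=128)       # hidden dimension of the classifier
parser.add_argument('--drop_n', type=float, default=0.3)    # dropout inside the GNN
parser.add_argument('--drop_c', type=float, default=0.2)    # dropout in the classifier
parser.add_argument('--ks', nargs='+', type=float, default=[0.9, 0.7, 0.6, 0.5])  # Graph U-Net pooling ratios
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--batch', type=int, default=32)
parser.add_argument('--num_epochs', type=int, default=200)
parser.add_argument('--acc_file', default='acc_results.txt')
args = parser.parse_args()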

2. network.py

import torch.nn as nn
# GCN, GraphUnet and Initializer are defined in this repo's own modules

class GNet(nn.Module):
    def __init__(self, in_dim, n_classes, args):
        super(GNet, self).__init__()
        # getattr() is a Python built-in that fetches an attribute by name;
        # here it picks the activation class (e.g. nn.ReLU) from its string name
        self.n_act = getattr(nn, args.act_n)()
        self.c_act = getattr(nn, args.act_c)()
        # print('GNet1: in_dim=', in_dim, 'n_class=', n_classes)  # GNet1: in_dim= 82 n_class= 2
        # GCN front end; its inputs are feat dim, layer dim, the network activation
        # and drop_n (the dropout used inside the GCN network itself)
        self.s_gcn = GCN(in_dim, args.l_dim, self.n_act, args.drop_n)
        self.g_unet = GraphUnet(args.ks, args.l_dim, args.l_dim, args.l_dim,
                                self.n_act, args.drop_n)
        # nn.Linear(in_features, out_features, bias=True) defines a fully connected layer
        self.out_l_1 = nn.Linear(3 * args.l_dim * (args.l_num + 1), args.h_dim)
        self.out_l_2 = nn.Linear(args.h_dim, n_classes)
        # nn.Dropout(p=0.3) drops each unit with probability 0.3 during training
        self.out_drop = nn.Dropout(p=args.drop_c)
        Initializer.weights_init(self)

    def forward(self, gs, hs, labels):
        print('GNet2: gs=', type(gs), len(gs), 'hs=', type(hs), len(hs),
              'labels:', type(labels), labels.shape)
        # GNet2: gs= <class 'list'> 32 hs= <class 'list'> 32 labels: <class 'torch.Tensor'> torch.Size([32])
        hs = self.embed(gs, hs)
        print('GNet2: hs=', type(hs), hs.shape)
        logits = self.classify(hs)
        return self.metric(logits, labels)
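A quick note on getattr(nn, args.act_n)(): it simply looks up an activation class on torch.nn by its string name and instantiates it, so args.act_n can be any module name such as 'ReLU' or 'ELU'. A minimal standalone check:

import torch
import torch.nn as nn

act = getattr(nn, 'ReLU')()              # equivalent to nn.ReLU()
print(act(torch.tensor([-1.0, 2.0])))    # tensor([0., 2.])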

3. trainer.py

import torch
import torch.optim as optim
from tqdm import tqdm
# GraphData comes from this repo's data-loading code

class Trainer:
    # __init__ takes the args, the GCN net and the graph data and stores them on self
    def __init__(self, args, net, G_data):
        self.args = args
        self.net = net
        self.feat_dim = G_data.feat_dim
        self.fold_idx = G_data.fold_idx
        self.init(args, G_data.train_gs, G_data.test_gs)
        # move the model to the GPU if one is available
        if torch.cuda.is_available():
            self.net.cuda()

    # set up the train/test loaders and the optimizer
    def init(self, args, train_gs, test_gs):
        print('#train: %d, #test: %d' % (len(train_gs), len(test_gs)))
        # split into training and test sets and load the data
        train_data = GraphData(train_gs, self.feat_dim)
        test_data = GraphData(test_gs, self.feat_dim)
        # loader() wraps a PyTorch DataLoader; only batch_size and shuffle are passed here
        self.train_d = train_data.loader(self.args.batch, True)
        self.test_d = test_data.loader(self.args.batch, False)
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.args.lr,
                                    amsgrad=True, weight_decay=0.0008)
    def train(self):
        max_acc = 0.0
        train_str = 'Train epoch %d: loss %.5f acc %.5f'
        test_str = 'Test epoch %d: loss %.5f acc %.5f max %.5f'
        line_str = '%d:\t%.5f\n'
        for e_id in range(self.args.num_epochs):
            self.net.train()
            # one training pass over the data
            loss, acc = self.run_epoch(e_id, self.train_d, self.net, self.optimizer)
            print(train_str % (e_id, loss, acc))

            with torch.no_grad():
                self.net.eval()
                # evaluation pass: no optimizer, so no gradient updates
                loss, acc = self.run_epoch(e_id, self.test_d, self.net, None)
            max_acc = max(max_acc, acc)
            print(test_str % (e_id, loss, acc, max_acc))

        with open(self.args.acc_file, 'a+') as f:
            f.write(line_str % (self.fold_idx, max_acc))
    def run_epoch(self, epoch, data, model, optimizer):
        # called as self.run_epoch(e_id, self.train_d, self.net, self.optimizer)
        losses, accs, n_samples = [], [], 0
        for batch in tqdm(data, desc=str(epoch), unit='b'):
            cur_len, gs, hs, ys = batch
            gs, hs, ys = map(self.to_cuda, [gs, hs, ys])
            loss, acc = model(gs, hs, ys)
            # weight each batch by its size so the epoch average is correct
            losses.append(loss * cur_len)
            accs.append(acc * cur_len)
            n_samples += cur_len
            if optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        avg_loss, avg_acc = sum(losses) / n_samples, sum(accs) / n_samples
        return avg_loss.item(), avg_acc.item()
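Note how run_epoch weights each batch's loss and accuracy by the batch size cur_len and only divides by n_samples at the end; this keeps the epoch average correct even when the last batch is smaller. A tiny numeric check of that weighting (the batch sizes and losses here are made up):

import torch

batch_sizes = [32, 32, 10]                  # last batch is smaller
batch_losses = [torch.tensor(0.50), torch.tensor(0.40), torch.tensor(0.80)]

weighted_sum = sum(l * n for l, n in zip(batch_losses, batch_sizes))
avg_loss = weighted_sum / sum(batch_sizes)
print(avg_loss.item())   # ~0.497, whereas the unweighted mean would be ~0.567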

Not yet understood: the GraphConvolution layer

import math
import torch
from torch.nn import Module, Parameter

class GraphConvolution(Module):
    """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907"""
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # wrapping the tensor in Parameter registers it as a learnable weight of the module,
        # so it shows up in .parameters() and is updated by the optimizer
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # uniform init in [-1/sqrt(out_features), 1/sqrt(out_features)]
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.mm(input, self.weight)   # X W
        output = torch.spmm(adj, support)        # A (X W), adj is a sparse matrix
        if self.bias is not None:
            return output + self.bias
        else:
            return output
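To answer the question above: Parameter is what makes the weight trainable, and the forward pass simply computes output = adj @ (input @ W) + b. A toy usage sketch, assuming the class above is in scope (the sizes and identity adjacency are made up for illustration):

import torch

layer = GraphConvolution(in_features=4, out_features=2)
x = torch.rand(3, 4)                  # 3 nodes, 4 features each
adj = torch.eye(3).to_sparse()        # toy sparse adjacency; torch.spmm expects a sparse matrix
out = layer(x, adj)                   # adj @ (x @ W) + b
print(out.shape)                      # torch.Size([3, 2])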