当前位置: 首页 > news >正文

GNN code Tips

1. 重置label取值范围

 problem: otherwise occurs IndexError: target out of bounds

# reset labels value range, otherwise occurs IndexError: target out of bounds
uni_set = torch.unique(labels)
to_set = torch.tensor(list(range(len(uni_set))))
labels_reset = labels.clone().detach()
for from_val, to_val in zip(uni_set, to_set):labels_reset = torch.where(labels_reset == from_val, to_val, labels_reset)

2. 根据多个labels tensor从整体label数据中提取特定数据。

label_mask = (labels == label)  # numpy array, (100,), ([True, False, True, True])
label_indices = np.where(label_mask)[0]  # 同一标签索引, label_index, (3, ) array([0, 2, 3], dtype=int64)
negative_indices = np.where(np.logical_not(label_mask))[0]  # (97, ), 其他标签索引,作为负样本 ndarray
# anchor_pos_list = list(combinations(label_indices, 2))  # 2个元素的标签索引组合, list: 3, [(23, 66), (23, 79), (66, 79)]
extract_index_data = edge_index_mx[0: label_indices]

3. 构建Geometric GATConv和GCNConv的 edge_index

因为torch geometric 即PyG的edge_index数据shape是二维tensor,shape=[2, n]. 

# relations_ids = ['entity', 'userid', 'word'],分别读取这三个文件
def sparse_trans(datapath = None):relation = sparse.load_npz(datapath)  # (4762, 4762)all_edge_index = torch.tensor([], dtype=int)for node in range(relation.shape[0]):neighbor = torch.IntTensor(relation[node].toarray()).squeeze()  # IntTensor是torch定义的7中cpu tensor类型之一;# squeeze对数据维度进行压缩,删除所有为1的维度# del self_loop in advanceneighbor[node] = 0  # 对角线元素置0neighbor_idx = neighbor.nonzero()  # 返回非零元素的索引, size: (43, 1)neighbor_sum = neighbor_idx.size(0)  # 表示非零元素数据量,43loop = torch.tensor(node).repeat(neighbor_sum, 1)  # repeat表示按列重复node的次数edge_index_i_j = torch.cat((loop, neighbor_idx), dim=1).t()  # cat表示按dim=1按列拼接;t表示对二维矩阵进行转置, node -> neighborself_loop = torch.tensor([[node], [node]])all_edge_index = torch.cat((all_edge_index, edge_index_i_j, self_loop), dim=1)del neighbor, neighbor_idx, loop, self_loop, edge_index_i_jreturn all_edge_index  ## 返回二维矩阵,最后一维是node。 node -> nonzero neighbors

4. 为GCNConv从全部edge index抽取指定的batch edge index

因为GCNConv需要执行卷积操作convolution,index out of the size of batch, 就会报错!

  • step 1: 抽取batch nodes对应的edge index
  • step 2: 将edge index value重置 reset in the range of [0, batch_size]. 
def extract_batch_edge_idx(batch_nodes, edge_index):extract_edge_index = torch.Tensor()for i in batch_nodes:extract_edge_i = torch.Tensor()# extract 1-st row index and 2-nd row indexedge_index_bool_0 = edge_index[0, :]edge_index_bool_0 = (edge_index_bool_0 == i)if edge_index_bool_0 is None:continuebool_indices_0 = np.where(edge_index_bool_0)[0]# extract dataedge_index_0 = edge_index[0:, bool_indices_0]for j in batch_nodes:edge_index_bool_1 = edge_index_0[1, :]edge_index_bool_1 = (edge_index_bool_1 == j)if edge_index_bool_1 is None:continuebool_indices_1 = np.where(edge_index_bool_1)[0]edge_index_1 = edge_index_0[0:, bool_indices_1]extract_edge_i = torch.cat((extract_edge_i, edge_index_1), dim=1)extract_edge_index = torch.cat((extract_edge_index, extract_edge_i), dim=1)# reset index value in a specific rangeuni_set = torch.unique(extract_edge_index)to_set = torch.tensor(list(range(len(uni_set))))labels_reset = extract_edge_index.clone().detach()for from_val, to_val in zip(uni_set, to_set):labels_reset = torch.where(labels_reset == from_val, to_val, labels_reset)return labels_reset.type(torch.long)

5. 将edge index 二维tensor 向量转换为 tensor matrix格式

def relations_to_adj(filtered_multi_r_data, nb_nodes=None):relations_mx_list = []for r_data in filtered_multi_r_data:data = np.ones(r_data.shape[1])relation_mx = sp.coo_matrix((data, (r_data[0], r_data[1])), shape=(nb_nodes, nb_nodes), dtype=int)relations_mx_list.append(torch.tensor(relation_mx.todense()))return relations_mx_list

http://www.lryc.cn/news/110998.html

相关文章:

  • 物联网|按键实验---学习I/O的输入及中断的编程|函数说明的格式|如何使用CMSIS的延时|读取通过外部中断实现按键捕获代码的实现及分析-学习笔记(14)
  • Java对象的前世今生
  • Qt中JSON的使用
  • linux安装Tomcat部署jpress教程
  • 高并发负载均衡---LVS
  • 微前端中的 CSS
  • 在CSDN学Golang场景化解决方案(分布式日志系统)
  • 电脑第一次使用屏幕键盘
  • 【C#学习笔记】类型转换
  • SpringBoot+SSM实战<一>:打造高效便捷的企业级Java外卖订购系统
  • 笙默考试管理系统-MyExamTest--calculagraph
  • Mysql面试突击班索引,事务与锁
  • 数据结构——AVL树
  • AI写作宝有哪些,分享两种AI写作工具
  • 【uniapp 控制页面滑动速度】
  • 7-24 整数的分类处理 (20 分)
  • MYSQL事务同时修改单条记录
  • 安装skywalking并集成到微服务项目
  • 一支笔,一双手,一道力扣(Leetcode)做一宿
  • Kubernetes(K8s)从入门到精通系列之九:使用kubeadm工具快速安装K8s集群
  • RabbitMQ 教程 | 第11章 RabbitMQ 扩展
  • 一分钟完成centos7安装docker
  • NativePHP:使用PHP构建跨平台桌面应用的新框架
  • 删除这4个文件夹,流畅使用手机无忧
  • 使用Bert预训练模型处理序列推荐任务
  • 将word每页页眉单独设置
  • rust怎么生成随机数?
  • python-Excel数据模型文档转为MySQL数据库建表语句(需要连接数据库)-工作小记
  • 406 · 和大于S的最小子数组
  • xray的 webhook如何把它Hook住?^(* ̄(oo) ̄)^