• 【Graph Net学习】DeepWalk/Node2Vec实现Graph Embedding


    【Graph Net】【专题系列】一、DeepWalk/Node2Vec代码实战

    目录

    一、简介

    二、算法原理        

    三、代码

    四、结果及展望


    一、简介

            本文主要通过代码实战介绍基础的两种图嵌入方式DeepWalk、Node2Vec。

            DeepWalk(KDD 2014)首个影响至今的图的Embedding算法,DeepWalk算法是一种用于学习节点表示的方法,常用于网络图中的节点的嵌入表示。

    模型目标输入输出
    Word2VecWordSentenceWord Embedding
    DeepWalkNodeNode SequenceNode Embedding

    二、算法原理        

    算法流程:

            1.首先,DeepWalk算法会从一个随机的起始节点开始,比如说选择朋友A作为起点。然后,算法会从A的邻居节点中随机选择一个节点,比如说选择了B。接着,算法会再从B的邻居节点中随机选择一个节点,比如说选择了C。这样反复进行直到达到事先设定的步数。

            2.一旦完成了一次类似于“走迷宫”的遍历,DeepWalk算法会将这条路径视为一句话,其中包含了A、B和C三个节点。算法会重复这个过程,多次生成不同的句子。

            3.然后,DeepWalk算法会将这些句子作为文本输入给Word2Vec算法。

            Node2Vec:Node2Vec算法是一种能够学习网络节点表示的算法。它通过优化随机游走过程来最大化网络的邻域节点之间的相似度,从而得到每个节点的有效嵌入。

            算法流程:

    1. 首先,从图中的每个节点开始执行固定长度的随机游走。这个步骤旨在生成每个节点的上下文信息。随机游走的方法包括以一种偏好的方式回到先前访问的节点,或者探索之前未访问的节点。这通过参数p和q来调整,其中p控制返回预先访问节点的可能性,而q控制更偏向于访问较远的节点。

    2. 得到随机游走序列之后,使用Skip-gram模型训练节点嵌入。在Skip-gram模型中,我们试图预测节点的邻居节点。

    3. 在训练过程中,使用梯度下降等优化算法最小化预测错误,进而通过迭代更新嵌入向量,使得越相似的节点其嵌入向量越接近。

         4.最后经过这样的训练后,我们就可以获得每个节点的向量表示,这个向量反映了节点在网络中的位置和角色。

            直接刚代码,开箱即用。 图的相关代码可见仓库:GitHub - mapstory6788/Graph-Networks

    三、代码

    1. import torch
    2. import numpy as np
    3. import os
    4. import random
    5. import pandas as pd
    6. import scipy.sparse as sp
    7. from torch_geometric.data import Data
    8. from sklearn.preprocessing import LabelEncoder
    9. from node2vec import Node2Vec
    10. import networkx as nx
    11. from gensim.models import Word2Vec
    12. def seed_everything(seed=2023):
    13. random.seed(seed)
    14. os.environ['PYTHONHASHSEED']=str(seed)
    15. np.random.seed(seed)
    16. torch.manual_seed(seed)
    17. seed_everything()
    18. def load_cora_data(data_path = './data/cora'):
    19. content_df = pd.read_csv(os.path.join(data_path,"cora.content"), delimiter="\t", header=None)
    20. content_df.set_index(0, inplace=True)
    21. index = content_df.index.tolist()
    22. features = sp.csr_matrix(content_df.values[:,:-1], dtype=np.float32)
    23. # 处理标签
    24. labels = content_df.values[:,-1]
    25. class_encoder = LabelEncoder()
    26. labels = class_encoder.fit_transform(labels)
    27. # 读取引用关系
    28. cites_df = pd.read_csv(os.path.join(data_path,"cora.cites"), delimiter="\t", header=None)
    29. cites_df[0] = cites_df[0].astype(str)
    30. cites_df[1] = cites_df[1].astype(str)
    31. cites = [tuple(x) for x in cites_df.values]
    32. edges = [(index.index(int(cite[0])), index.index(int(cite[1]))) for cite in cites]
    33. edges = np.array(edges).T
    34. # 构造Data对象
    35. data = Data(x=torch.from_numpy(np.array(features.todense())),
    36. edge_index=torch.LongTensor(edges),
    37. y=torch.from_numpy(labels))
    38. idx_train = range(140)
    39. idx_val = range(200, 500)
    40. idx_test = range(500, 1500)
    41. # 读取Cora数据集 return geometric Data格式
    42. def index_to_mask(index, size):
    43. mask = np.zeros(size, dtype=bool)
    44. mask[index] = True
    45. return mask
    46. data.train_mask = index_to_mask(idx_train, size=labels.shape[0])
    47. data.val_mask = index_to_mask(idx_val, size=labels.shape[0])
    48. data.test_mask = index_to_mask(idx_test, size=labels.shape[0])
    49. def to_networkx(data):
    50. edge_index = data.edge_index.to(torch.device('cpu')).numpy()
    51. G = nx.DiGraph()
    52. for src, tar in edge_index.T:
    53. G.add_edge(src, tar)
    54. return G
    55. networkx_data = to_networkx(data)
    56. return data,networkx_data
    57. #获取数据:pyg_data:torch_geometric格式;networkx_data:networkx格式
    58. pyg_data,networkx_data = load_cora_data()
    59. #Node2Vec_Embedding方法
    60. def Node2Vec_run(networkx_data, dimensions=128, walk_length=30, num_walks=200):
    61. # 创建一个Node2Vec对象 #dimensions=64 embedding维度, walk_length=30 游走步长, num_walks=200 游走次数, workers=4 线程数
    62. node2vec = Node2Vec(networkx_data, dimensions=dimensions, walk_length=walk_length, num_walks=num_walks, workers=4)
    63. # 训练Node2Vec模型
    64. model = node2vec.fit(window=10, min_count=1, batch_words=4) #获得node2vec的所有内容
    65. nodes = model.wv.index_to_key # 得到所有节点的名字
    66. embeddings = model.wv[nodes] # 得到所有节点的嵌入向量
    67. return model,nodes,embeddings
    68. def DeepWalk_run(networkx_data,dimensions = 128, walk_length = 30, num_walks = 200):
    69. # 使用deepwalk算法进行graph embedding
    70. # DeepWalk算法
    71. def deepwalk(graph, num_walks, walk_length):
    72. walks = []
    73. for node in graph.nodes():
    74. if graph.degree(node) == 0:
    75. continue
    76. for _ in range(num_walks):
    77. walk = [node]
    78. target = node
    79. for _ in range(walk_length - 1):
    80. if len(list(graph.neighbors(target))) == 0: # 判断当前节点是否有邻居,如果为空邻居,则跳过当前节点
    81. continue
    82. target = random.choice(list(graph.neighbors(target)))
    83. walk.append(target)
    84. walks.append(walk)
    85. return walks
    86. walks = deepwalk(networkx_data, num_walks = num_walks, walk_length = walk_length)
    87. # 用Word2Vec训练节点向量
    88. model = Word2Vec(walks, vector_size=dimensions, window=5, min_count=0, sg=1) #参数sg=1表示选择Skip-Gram模型 window 影响着Word2Vec中词和其上下文词的最大距离
    89. nodes = model.wv.index_to_key # 得到所有节点的名字
    90. embeddings = model.wv[nodes] # 得到所有节点的嵌入向量
    91. return model,nodes,embeddings
    92. _,_,node2vec_embeddings = Node2Vec_run(networkx_data,num_walks=1)
    93. print("node2vec_embeddings :",np.array(node2vec_embeddings).shape) # print : "node2vec_embeddings : (2708, 64)"
    94. _,_,DeepWalk_embeddings = DeepWalk_run(networkx_data,num_walks=1)
    95. print("DeepWalk_embeddings :",np.array(DeepWalk_embeddings).shape) # print : "node2vec_embeddings : (2708, 64)"

    四、结果及展望

           上述是针对Cora的数据集做的Node Embedding输出,输出为:node2vec_embeddings : (2708, 128);DeepWalk_embeddings : (2708, 128)

            接下来大家就可拿到 (2708, 128)这个Embedding做各种下游了,如聚类、Net Feature等

            P.S.这些都是18年以前,NN不发达的Embedding产物,并未挖掘深层feature的embedding,接下来玩一玩NN的Graph Embedding

  • 相关阅读:
    【JavaScript】面试手撕深拷贝
    【UE 材质】制作加载图案(2)
    应力奇异,你是一个神奇的应力!
    Mock平台2-Java Spring Boot框架基础知识
    实习打怪之路:webpack概念【入口、输出、装载机、插件、模块】(引自官网)
    【IoT】产品经理:人性洞察的底层逻辑
    面向对象技术浅析
    如何开机自动清理系统临时文件
    期货价值计算方法(期货的价值怎么计算)
    QT快捷键
  • 原文地址:https://blog.csdn.net/weian4913/article/details/133082579