• 【Graph Net学习】GNN/GCN代码实战


    【Graph Net】【专题系列】三、GNN/GCN代码实战

    目录

    一、简介

    二、代码

    三、结果与讨论

    四、展望

    本专栏更多好文,欢迎点击下方链接:

    【Graph Net系列文章】


    一、简介

            GNN(Graph Neural Network,图神经网络)和GCN(Graph Convolutional Network,图卷积网络)都是基于图结构的神经网络模型。本文的目标是夯实代码基础:不借助PyG,从零实现这两个Graph Net基础算法,借此扒一扒它们的原理。下面直接上代码。图的相关代码可见仓库:GitHub - mapstory6788/Graph-Networks

    二、代码

    1. import time
    2. import random
    3. import os
    4. import numpy as np
    5. import math
    6. from torch.nn.parameter import Parameter
    7. from torch.nn.modules.module import Module
    8. import torch
    9. import torch.nn as nn
    10. import torch.nn.functional as F
    11. import torch.optim as optim
    12. import scipy.sparse as sp
    # Configuration
    class configs():
        """Hyper-parameters and paths for the whole script; the shared instance is ``cfg``."""

        def __init__(self):
            # Data / output locations
            self.data_path = r'E:\code\Graph\data'
            # NOTE(review): unlike data_path this has no drive letter, so on Windows it
            # resolves against the current drive — confirm this is intended.
            self.save_model_dir = r'\code\Graph'
            self.model_name = r'GCN'  # 'GNN' or 'GCN'
            self.seed = 2023
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # NOTE(review): not referenced by the training code shown (training is full-batch).
            self.batch_size = 64
            self.epoch = 200
            self.in_features = 1433  # Cora: 1433 features per node
            self.hidden_features = 16  # hidden layer width
            # Cora has 7 paper classes; the former value 8 added an output unit
            # that no label could ever select.
            self.output_features = 7
            self.learning_rate = 0.01
            self.dropout = 0.5
            self.istrain = True
            self.istest = True


    cfg = configs()
    def seed_everything(seed=2023):
        """Seed the Python, NumPy and PyTorch random generators with ``seed``.

        Also exports PYTHONHASHSEED. The interpreter reads that variable only at
        startup, so it affects child processes, not string hashing in this process.
        """
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
    # Seed every RNG once at import time with the configured seed.
    seed_everything(cfg.seed)
    # Data
    class Graph_Data_Loader():
        """Load the Cora citation graph and move every tensor to ``cfg.device``.

        After construction the instance exposes:
            adj       -- row-normalized adjacency with self-loops (torch sparse COO, float32)
            features  -- row-normalized node features (dense float32 tensor)
            labels    -- class index per node (int64 tensor)
            idx_train, idx_val, idx_test -- fixed node-index splits (int64 tensors)
        """

        def __init__(self):
            (self.adj, self.features, self.labels,
             self.idx_train, self.idx_val, self.idx_test) = self.load_data()
            self.adj = self.adj.to(cfg.device)
            self.features = self.features.to(cfg.device)
            self.labels = self.labels.to(cfg.device)
            self.idx_train = self.idx_train.to(cfg.device)
            self.idx_val = self.idx_val.to(cfg.device)
            self.idx_test = self.idx_test.to(cfg.device)

        def load_data(self, path=cfg.data_path, dataset="cora"):
            """Load citation network dataset (cora only for now).

            Reads ``<path>/<dataset>/<dataset>.content`` (one row per node: id,
            feature columns, label) and ``<path>/<dataset>/<dataset>.cites``
            (one edge per row, given as two node ids).

            Returns:
                Tuple ``(adj, features, labels, idx_train, idx_val, idx_test)``.
            """
            print('Loading {} dataset...'.format(dataset))
            idx_features_labels = np.genfromtxt(os.path.join(path, dataset, dataset + '.content'),
                                                dtype=np.dtype(str))
            features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
            labels = self.encode_onehot(idx_features_labels[:, -1])

            # Build the graph: map raw node ids to contiguous row indices.
            idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
            idx_map = {j: i for i, j in enumerate(idx)}
            edges_unordered = np.genfromtxt(os.path.join(path, dataset, dataset + '.cites'),
                                            dtype=np.int32)
            edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                             dtype=np.int32).reshape(edges_unordered.shape)
            adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                                shape=(labels.shape[0], labels.shape[0]),
                                dtype=np.float32)

            # Symmetrize: each entry becomes max(adj[i, j], adj[j, i]).
            adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)

            features = self.normalize(features)
            # Add self-loops (A + I) before row-normalizing so each node keeps its own signal.
            adj = self.normalize(adj + sp.eye(adj.shape[0]))

            # Fixed splits: 140 training, 300 validation and 1000 test nodes.
            idx_train = range(140)
            idx_val = range(200, 500)
            idx_test = range(500, 1500)

            features = torch.FloatTensor(np.array(features.todense()))
            labels = torch.LongTensor(np.where(labels)[1])  # one-hot -> class index
            adj = self.sparse_mx_to_torch_sparse_tensor(adj)

            idx_train = torch.LongTensor(idx_train)
            idx_val = torch.LongTensor(idx_val)
            idx_test = torch.LongTensor(idx_test)
            return adj, features, labels, idx_train, idx_val, idx_test

        def encode_onehot(self, labels):
            """One-hot encode an array of class labels (int32 result).

            Classes are sorted so the class -> column mapping is identical in every
            process. Iterating a bare ``set`` of strings follows the per-process
            string-hash seed. The mapping could then differ between a training run
            and a later inference-only run, silently scrambling predicted classes.
            """
            classes = sorted(set(labels))
            identity = np.identity(len(classes))
            classes_dict = {c: identity[i, :] for i, c in enumerate(classes)}
            labels_onehot = np.array(list(map(classes_dict.get, labels)),
                                     dtype=np.int32)
            return labels_onehot

        def normalize(self, mx):
            """Row-normalize sparse matrix: divide each row by its sum (all-zero rows stay zero)."""
            rowsum = np.array(mx.sum(1))
            r_inv = np.power(rowsum, -1).flatten()
            r_inv[np.isinf(r_inv)] = 0.  # rows summing to 0 would otherwise become inf
            r_mat_inv = sp.diags(r_inv)
            mx = r_mat_inv.dot(mx)
            return mx

        def sparse_mx_to_torch_sparse_tensor(self, sparse_mx):
            """Convert a scipy sparse matrix to a float32 torch sparse COO tensor."""
            sparse_mx = sparse_mx.tocoo().astype(np.float32)
            indices = torch.from_numpy(
                np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
            values = torch.from_numpy(sparse_mx.data)
            shape = torch.Size(sparse_mx.shape)
            # torch.sparse.FloatTensor(...) is a deprecated legacy constructor;
            # sparse_coo_tensor builds the same COO tensor through the supported API.
            return torch.sparse_coo_tensor(indices, values, shape)
    # Accuracy metric
    def accuracy(output, labels):
        """Return the fraction of rows whose highest-scoring column equals ``labels``.

        ``output`` is indexed along dim 1 for the arg-max, so it is expected to be
        (N, C) scores; ``labels`` holds N class indices. The result is a 0-dim
        float64 tensor.
        """
        predicted = output.max(dim=1).indices.type_as(labels)
        hits = (predicted == labels).double().sum()
        return hits / len(labels)
    # Models
    # 01 - GNN
    class GNNLayer(nn.Module):
        """One message-passing layer computing ReLU(Linear(A @ H))."""

        def __init__(self, in_features, output_features):
            super().__init__()
            # Projection applied after neighbourhood aggregation.
            self.linear = nn.Linear(in_features, output_features)

        def forward(self, adj_matrix, features):
            aggregated = torch.matmul(adj_matrix, features)  # GNN rule: H' = A * H
            return F.relu(self.linear(aggregated))
    class GNN(nn.Module):
        """Stack of ``num_layers`` GNNLayer blocks followed by a linear classifier.

        Args:
            in_features: input feature size.
            hidden_features: output width of every GNNLayer.
            output_features: number of output classes.
            num_layers: number of GNNLayer blocks; the first maps
                in_features -> hidden_features, the rest hidden -> hidden.

        ``forward`` returns log-probabilities (log_softmax over dim 1).
        """

        def __init__(self, in_features, hidden_features, output_features, num_layers=2):
            super(GNN, self).__init__()
            blocks = []
            for i in range(num_layers):
                width_in = in_features if i == 0 else hidden_features
                blocks.append(GNNLayer(width_in, hidden_features))
            self.layers = nn.ModuleList(blocks)
            self.output_layer = nn.Linear(hidden_features, output_features)

        def forward(self, adj_matrix, features):
            hidden = features
            for gnn_layer in self.layers:
                hidden = gnn_layer(adj_matrix, hidden)
            return F.log_softmax(self.output_layer(hidden), dim=1)
    # 02 - GCN
    class GraphConvolution(Module):
        """
        Simple GCN layer, similar to https://arxiv.org/abs/1609.02907

        Computes ``adj @ (input @ weight)``, plus ``bias`` when enabled.
        """

        def __init__(self, in_features, out_features, bias=True):
            super(GraphConvolution, self).__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = Parameter(torch.FloatTensor(in_features, out_features))
            if not bias:
                self.register_parameter('bias', None)
            else:
                self.bias = Parameter(torch.FloatTensor(out_features))
            self.reset_parameters()

        def reset_parameters(self):
            """Draw weight (then bias) uniformly from [-1/sqrt(out), 1/sqrt(out)]."""
            bound = 1. / math.sqrt(self.weight.size(1))
            for param in (self.weight, self.bias):
                if param is not None:
                    param.data.uniform_(-bound, bound)

        def forward(self, input, adj):
            out = torch.spmm(adj, torch.mm(input, self.weight))
            return out if self.bias is None else out + self.bias

        def __repr__(self):
            return '{} ({} -> {})'.format(type(self).__name__, self.in_features, self.out_features)
    class GCN(nn.Module):
        """Two-layer GCN returning class log-probabilities.

        forward: gc1 -> ReLU -> dropout (training only) -> gc2 -> log_softmax(dim=1).
        """

        def __init__(self, in_features, hidden_features, output_features, dropout=cfg.dropout):
            super(GCN, self).__init__()
            self.gc1 = GraphConvolution(in_features, hidden_features)
            self.gc2 = GraphConvolution(hidden_features, output_features)
            self.dropout = dropout  # drop probability applied to the hidden layer

        def forward(self, adj_matrix, features):
            hidden = F.dropout(F.relu(self.gc1(features, adj_matrix)),
                               self.dropout, training=self.training)
            logits = self.gc2(hidden, adj_matrix)
            return F.log_softmax(logits, dim=1)
    class graph_run():
        """Driver that trains and/or evaluates the model named by ``cfg.model_name``."""

        def _build_model(self):
            """Instantiate the configured model on ``cfg.device``.

            Uses an explicit name -> class table instead of ``eval`` on a config string.
            """
            model_classes = {'GNN': GNN, 'GCN': GCN}
            if cfg.model_name not in model_classes:
                raise ValueError('Unknown cfg.model_name {!r}; expected one of {}'.format(
                    cfg.model_name, sorted(model_classes)))
            return model_classes[cfg.model_name](in_features=cfg.in_features,
                                                 hidden_features=cfg.hidden_features,
                                                 output_features=cfg.output_features).to(cfg.device)

        def train(self):
            """Full-batch training; prints per-epoch metrics and saves the weights."""
            t = time.time()
            # Create Train Processing
            all_data = Graph_Data_Loader()
            model = self._build_model()
            optimizer = optim.Adam(model.parameters(),
                                   lr=cfg.learning_rate, weight_decay=5e-4)
            idx_train, idx_val, labels = all_data.idx_train, all_data.idx_val, all_data.labels
            for epoch in range(cfg.epoch):
                # Training step (dropout active).
                model.train()
                optimizer.zero_grad()
                output = model(all_data.adj, all_data.features)
                loss_train = F.nll_loss(output[idx_train], labels[idx_train])
                acc_train = accuracy(output[idx_train], labels[idx_train])
                loss_train.backward()
                optimizer.step()

                # Validation: re-run the forward pass in eval mode (dropout off, no graph),
                # so the metrics reflect the updated weights rather than the noisy
                # training pass.
                model.eval()
                with torch.no_grad():
                    output = model(all_data.adj, all_data.features)
                loss_val = F.nll_loss(output[idx_val], labels[idx_val])
                acc_val = accuracy(output[idx_val], labels[idx_val])
                print('Epoch: {:04d}'.format(epoch + 1),
                      'loss_train: {:.4f}'.format(loss_train.item()),
                      'acc_train: {:.4f}'.format(acc_train.item()),
                      'loss_val: {:.4f}'.format(loss_val.item()),
                      'acc_val: {:.4f}'.format(acc_val.item()),
                      'time: {:.4f}s'.format(time.time() - t))
            # Save the state_dict, not the pickled module. torch.load defaults to
            # weights_only=True in recent PyTorch and refuses to unpickle whole modules.
            os.makedirs(cfg.save_model_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(cfg.save_model_dir, 'latest.pth'))

        def infer(self):
            """Load ``latest.pth`` and print loss/accuracy on the test split."""
            # Create Test Processing
            all_data = Graph_Data_Loader()
            model_path = os.path.join(cfg.save_model_dir, 'latest.pth')
            checkpoint = torch.load(model_path, map_location=torch.device(cfg.device))
            if isinstance(checkpoint, nn.Module):
                # Legacy checkpoint holding a whole pickled model (only loadable on
                # PyTorch versions that still unpickle full modules by default).
                model = checkpoint
            else:
                model = self._build_model()
                model.load_state_dict(checkpoint)
            model.eval()
            with torch.no_grad():
                output = model(all_data.adj, all_data.features)
            idx_test, labels = all_data.idx_test, all_data.labels
            loss_test = F.nll_loss(output[idx_test], labels[idx_test])
            acc_test = accuracy(output[idx_test], labels[idx_test])
            print("Test set results:",
                  "loss= {:.4f}".format(loss_test.item()),
                  "accuracy= {:.4f}".format(acc_test.item()))
    if __name__ == '__main__':
        # Run training and/or evaluation according to the cfg switches.
        mygraph = graph_run()
        if cfg.istrain:
            mygraph.train()
        if cfg.istest:
            mygraph.infer()

    三、结果与讨论

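            需要先从网上下载Cora数据集。按代码中的读取路径,数据组织形式为:data_path/cora/cora.content(每行依次为节点ID、特征列、类别标签)以及 data_path/cora/cora.cites(每行两个节点ID,表示一条引用边)。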

            测了一下模型的参数量(Params)和计算量(GFLOPs),数值还是比较大的;若要作为某个网络中的一个Block使用,还需要进一步优化,哈哈~

    | Model | Params  | GFLOPs   |
    |-------|---------|----------|
    | GNN   | 23.352K | 126.258M |
    | Model | Cora (train/val/test) |
    |-------|-----------------------|
    | GNN   | 1.0000/0.7800/0.7620  |
    | GCN   | 0.9714/0.7767/0.8290  |

    四、展望

            未来可以考虑使用PyG(PyTorch Geometric),毕竟用PyG实现GAT等图网络,以及组织、加载图数据,都会更加方便。此外,Graph Net通常可以用于属性数据的embedding:将属性数据作为一种补充特征加入网络训练,看能否发挥效能。

  • 相关阅读:
    总结:Spring Boot之@Value
    log is判断引发的一系列事件,哭哭
    【R语言文本挖掘】:tidy数据格式及词频计算
    前端之使用webpack打包TS
    HDLbits:Fsm onehot
    面试官:微服务通讯方式有哪些?
    ElasticSearch 同步数据变少了
    4.zigbee开发,串口的应用,ADC采集
    软件 | 快速计算网络自然连通度评估群落稳定性
    uview的input框,clear没反应的问题
  • 原文地址:https://blog.csdn.net/weian4913/article/details/132991337