• GraphSAGE Source Code -- Minibatch (Subgraph) Training


    Source code: GitHub - twjiang/graphSAGE-pytorch: A PyTorch implementation of GraphSAGE.

    1 The dataset

    The dataset is Cora:

            a citation graph of 2708 scientific publications with 5429 edges, divided into 7 classes;

            every paper cites at least one other paper or is cited by at least one other paper (every node has at least one incoming or outgoing edge, so no sample is completely disconnected from the rest and there are no isolated nodes);

            after stemming and stop-word removal, all words with a document frequency below 10 were discarded, leaving a vocabulary of 1433 unique words;

            the Cora dataset consists of two files: cora.content and cora.cites.

    Sample of cora.content:

            Each line has three parts: the paper ID, the paper's feature vector, and the paper's class label.

    31336 0 0 0 0 0 0 0 0 ... 0 Neural_Networks
    1061127 0 0 0 0 ... 0 0 0 0 0 Rule_Learning
    1106406 0 ... 0 0 0 0 0 0 0 0 0 Reinforcement_Learning

    Sample of cora.cites: each line holds two paper IDs, meaning there is an edge between them;

    35 1033
    35 103482
    35 103515
    35 1050679
    35 1103960
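    A quick way to sanity-check the numbers quoted above (a minimal sketch; it assumes the two files sit in a local ./cora/ directory, so adjust the paths):

    # Minimal sketch to verify the dataset statistics: node count, edge
    # count, and number of classes. Paths are assumptions, not repo code.
    labels = set()
    n_nodes = 0
    with open('./cora/cora.content') as fp:
        for line in fp:
            parts = line.strip().split()
            labels.add(parts[-1])        # the last column is the class label
            n_nodes += 1
    with open('./cora/cora.cites') as fp:
        n_edges = sum(1 for _ in fp)     # one citation edge per line
    print(n_nodes, n_edges, len(labels)) # expected: 2708 5429 7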

    2 The main script src/main.py

            The first part parses the arguments, sets the device and random seeds, and loads the config:

    import sys
    import os
    import torch
    import argparse
    import pyhocon
    import random

    from src.dataCenter import *   # also brings numpy in as np via the wildcard
    from src.utils import *
    from src.models import *

    parser = argparse.ArgumentParser(description='pytorch version of GraphSAGE')
    parser.add_argument('--dataSet', type=str, default='cora')
    parser.add_argument('--agg_func', type=str, default='MEAN')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--b_sz', type=int, default=20)
    parser.add_argument('--seed', type=int, default=824)
    parser.add_argument('--cuda', action='store_true', help='use CUDA')
    parser.add_argument('--gcn', action='store_true')
    parser.add_argument('--learn_method', type=str, default='sup')
    parser.add_argument('--unsup_loss', type=str, default='normal')
    parser.add_argument('--max_vali_f1', type=float, default=0)
    parser.add_argument('--name', type=str, default='debug')
    parser.add_argument('--config', type=str, default='experiments.conf')
    args = parser.parse_args()

    if torch.cuda.is_available():
        if not args.cuda:
            print("WARNING: You have a CUDA device, so you should probably run with --cuda")
        else:
            device_id = torch.cuda.current_device()
            print('using device', device_id, torch.cuda.get_device_name(device_id))

    device = torch.device("cuda" if args.cuda else "cpu")
    print('DEVICE:', device)

    if __name__ == '__main__':
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        # load the config file (HOCON format, parsed with pyhocon)
        config = pyhocon.ConfigFactory.parse_file(args.config)
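            The keys used later ('setting.num_layers', 'setting.hidden_emb_size', 'file_path.*') are resolved from this config file. A minimal sketch of how pyhocon handles such dotted keys (the structure and values of the repo's actual experiments.conf may differ):

    # Sketch of pyhocon key resolution; the conf_text here is an assumed
    # example, not the repo's real experiments.conf.
    import pyhocon

    conf_text = """
    setting {
        num_layers: 2
        hidden_emb_size: 128
    }
    file_path {
        cora_content: ./cora/cora.content
        cora_cite: ./cora/cora.cites
    }
    """
    config = pyhocon.ConfigFactory.parse_string(conf_text)
    print(config['setting.num_layers'])      # 2
    print(config['file_path.cora_content'])  # ./cora/cora.content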

            Loading the data:

    # load data
    ds = args.dataSet
    dataCenter = DataCenter(config)
    dataCenter.load_dataSet(ds)
    # pull out the node features loaded by DataCenter
    features = torch.FloatTensor(getattr(dataCenter, ds+'_feats')).to(device)

            Setting up GraphSage:

    graphSage = GraphSage(config['setting.num_layers'], features.size(1),
                          config['setting.hidden_emb_size'], features,
                          getattr(dataCenter, ds+'_adj_lists'), device,
                          gcn=args.gcn, agg_func=args.agg_func)
    graphSage.to(device)
    # number of distinct labels: 7 for cora
    num_labels = len(set(getattr(dataCenter, ds+'_labels')))
    # the classifier that consumes the embeddings GraphSage produces
    classification = Classification(config['setting.hidden_emb_size'], num_labels)
    classification.to(device)

            Because GraphSAGE supports both supervised and unsupervised training, an unsupervised loss is also constructed here;

    # the supervised model is used below, so this can be skipped on a first read
    unsupervised_loss = UnsupervisedLoss(getattr(dataCenter, ds+'_adj_lists'), getattr(dataCenter, ds+'_train'), device)

            Then the training loop runs;

            two models are trained: first GraphSage itself, and second the classification model.

    # determine the learning mode; supervised learning is used here
    if args.learn_method == 'sup':
        print('GraphSage with Supervised Learning')
    elif args.learn_method == 'plus_unsup':
        print('GraphSage with Supervised Learning plus Net Unsupervised Learning')
    else:
        print('GraphSage with Net Unsupervised Learning')

    for epoch in range(args.epochs):
        print('----------------------EPOCH %d-----------------------' % epoch)
        # apply_model runs one epoch of training
        graphSage, classification = apply_model(dataCenter, ds, graphSage, classification, unsupervised_loss, args.b_sz, args.unsup_loss, device, args.learn_method)
        if (epoch+1) % 2 == 0 and args.learn_method == 'unsup':
            classification, args.max_vali_f1 = train_classification(dataCenter, graphSage, classification, ds, device, args.max_vali_f1, args.name)
        if args.learn_method != 'unsup':
            args.max_vali_f1 = evaluate(dataCenter, ds, graphSage, classification, device, args.max_vali_f1, args.name, epoch)
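            With the defaults above, supervised training on cora should start with something along the lines of python -m src.main --dataSet cora --learn_method sup (add --cuda for the GPU); the exact invocation depends on the package layout, so check the repo's README.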

    3 Loading the dataset: src/dataCenter.py

            First the __init__() method;

            then load_dataSet() reads the data:

    import sys
    import os
    from collections import defaultdict

    import numpy as np

    class DataCenter(object):
        """docstring for DataCenter"""
        def __init__(self, config):
            super(DataCenter, self).__init__()
            self.config = config

        def load_dataSet(self, dataSet='cora'):
            if dataSet == 'cora':
                # cora_content_file = self.config['file_path.cora_content']
                # cora_cite_file = self.config['file_path.cora_cite']
                # hard-coded local paths (the commented lines above are the
                # repo's original config lookups)
                cora_content_file = '/Users/qiaoboyu/Desktop/pythonProject1/graphSAGE-pytorch-master/cora/cora.content'
                cora_cite_file = '/Users/qiaoboyu/Desktop/pythonProject1/graphSAGE-pytorch-master/cora/cora.cites'
                feat_data = []
                labels = []     # label sequence of node
                node_map = {}   # map node to Node_ID
                label_map = {}  # map label to Label_ID
                with open(cora_content_file) as fp:
                    for i, line in enumerate(fp):
                        info = line.strip().split()
                        feat_data.append([float(x) for x in info[1:-1]])
                        node_map[info[0]] = i
                        if not info[-1] in label_map:
                            label_map[info[-1]] = len(label_map)
                        labels.append(label_map[info[-1]])
                feat_data = np.asarray(feat_data)            # (2708, 1433)
                labels = np.asarray(labels, dtype=np.int64)  # (2708,)
                adj_lists = defaultdict(set)
                with open(cora_cite_file) as fp:
                    for i, line in enumerate(fp):
                        info = line.strip().split()
                        assert len(info) == 2
                        paper1 = node_map[info[0]]
                        paper2 = node_map[info[1]]
                        # undirected: insert both directions,
                        # e.g. defaultdict(set, {163: {402, 659}, 402: {163}})
                        adj_lists[paper1].add(paper2)
                        adj_lists[paper2].add(paper1)
                assert len(feat_data) == len(labels) == len(adj_lists)  # 2708
                test_indexs, val_indexs, train_indexs = self._split_data(feat_data.shape[0])
                setattr(self, dataSet+'_test', test_indexs)
                setattr(self, dataSet+'_val', val_indexs)
                setattr(self, dataSet+'_train', train_indexs)
                setattr(self, dataSet+'_feats', feat_data)
                setattr(self, dataSet+'_labels', labels)
                setattr(self, dataSet+'_adj_lists', adj_lists)
            elif dataSet == 'pubmed':
                pubmed_content_file = self.config['file_path.pubmed_paper']
                pubmed_cite_file = self.config['file_path.pubmed_cites']
                feat_data = []
                labels = []    # label sequence of node
                node_map = {}  # map node to Node_ID
                with open(pubmed_content_file) as fp:
                    fp.readline()
                    feat_map = {entry.split(":")[1]: i-1 for i, entry in enumerate(fp.readline().split("\t"))}
                    for i, line in enumerate(fp):
                        info = line.split("\t")
                        node_map[info[0]] = i
                        labels.append(int(info[1].split("=")[1])-1)
                        tmp_list = np.zeros(len(feat_map)-2)
                        for word_info in info[2:-1]:
                            word_info = word_info.split("=")
                            tmp_list[feat_map[word_info[0]]] = float(word_info[1])
                        feat_data.append(tmp_list)
                feat_data = np.asarray(feat_data)
                labels = np.asarray(labels, dtype=np.int64)
                adj_lists = defaultdict(set)
                with open(pubmed_cite_file) as fp:
                    fp.readline()
                    fp.readline()
                    for line in fp:
                        info = line.strip().split("\t")
                        paper1 = node_map[info[1].split(":")[1]]
                        paper2 = node_map[info[-1].split(":")[1]]
                        adj_lists[paper1].add(paper2)
                        adj_lists[paper2].add(paper1)
                assert len(feat_data) == len(labels) == len(adj_lists)
                test_indexs, val_indexs, train_indexs = self._split_data(feat_data.shape[0])
                setattr(self, dataSet+'_test', test_indexs)
                setattr(self, dataSet+'_val', val_indexs)
                setattr(self, dataSet+'_train', train_indexs)
                setattr(self, dataSet+'_feats', feat_data)
                setattr(self, dataSet+'_labels', labels)
                setattr(self, dataSet+'_adj_lists', adj_lists)
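            The adjacency structure is a defaultdict of sets, with each edge inserted in both directions. A toy version of that loop:

    # Toy version of the adjacency-building loop above: undirected edges
    # stored as a map from node index to the set of its neighbors.
    from collections import defaultdict

    adj_lists = defaultdict(set)
    for a, b in [(0, 1), (0, 2), (1, 2)]:
        adj_lists[a].add(b)
        adj_lists[b].add(a)
    print(dict(adj_lists))  # {0: {1, 2}, 1: {0, 2}, 2: {0, 1}}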

            Splitting the dataset:

    def _split_data(self, num_nodes, test_split=3, val_split=6):
        rand_indices = np.random.permutation(num_nodes)   # shuffled node indices
        test_size = num_nodes // test_split               # 902
        val_size = num_nodes // val_split                 # 451
        train_size = num_nodes - (test_size + val_size)   # 1355
        test_indexs = rand_indices[:test_size]
        val_indexs = rand_indices[test_size:(test_size+val_size)]
        train_indexs = rand_indices[(test_size+val_size):]
        return test_indexs, val_indexs, train_indexs
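            For Cora's 2708 nodes, the defaults test_split=3 and val_split=6 give roughly a 1/3 test, 1/6 validation, 1/2 train split:

    # Reproducing the split sizes in the comments above for num_nodes=2708:
    num_nodes = 2708
    test_size = num_nodes // 3                       # 902
    val_size = num_nodes // 6                        # 451
    train_size = num_nodes - (test_size + val_size)  # 1355
    print(test_size, val_size, train_size)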

    4 The GraphSage model: src/models.py

            First the __init__() method:

    class GraphSage(nn.Module):
        """docstring for GraphSage"""
        def __init__(self, num_layers, input_size, out_size, raw_features, adj_lists, device, gcn=False, agg_func='MEAN'):
            super(GraphSage, self).__init__()
            self.input_size = input_size      # 1433
            self.out_size = out_size          # 128
            self.num_layers = num_layers      # 2
            self.gcn = gcn                    # False
            self.device = device
            self.agg_func = agg_func          # 'MEAN': how sampled neighbors are combined
            self.raw_features = raw_features  # node features, torch.Size([2708, 1433])
            self.adj_lists = adj_lists        # adjacency sets, not an adjacency matrix
            # build one SageLayer per layer
            for index in range(1, num_layers+1):
                # the first layer consumes raw features; later layers consume the updated embeddings
                layer_size = out_size if index != 1 else input_size
                setattr(self, 'sage_layer'+str(index), SageLayer(layer_size, out_size, gcn=self.gcn))

     

            The weights: the first layer's W has shape [128, 1433*2];

                                   the second layer's W has shape [128, 128*2].

            For example, to build the representation of a node a: the first layer aggregates the outermost (2-hop) sampled nodes into 128-dimensional vectors for each of a's 1-hop neighbors;

                    the second layer then aggregates those 1-hop vectors to produce a's final representation.
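            A quick shape check of the two weights (a sketch; the factor 2 comes from gcn=False, where self and neighbor features are concatenated):

    # Shape check for the two SageLayer weights described above.
    import torch

    input_size, out_size = 1433, 128
    w1 = torch.empty(out_size, 2 * input_size)  # layer 1: torch.Size([128, 2866])
    w2 = torch.empty(out_size, 2 * out_size)    # layer 2: torch.Size([128, 256])
    print(w1.shape, w2.shape)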

             Defining forward():

    def forward(self, nodes_batch):
        """
        Generates embeddings for a batch of nodes.
        nodes_batch -- batch of nodes to learn the embeddings
        """
        # work on the batch as a plain list of node ids
        lower_layer_nodes = list(nodes_batch)
        # level 0 of the sampling hierarchy: the batch nodes themselves
        nodes_batch_layers = [(lower_layer_nodes,)]
        # self.dc.logger.info('get_unique_neighs.')
        # walk outward one hop per GraphSage layer
        for i in range(self.num_layers):
            # sample the neighborhood of the current frontier
            lower_samp_neighs, lower_layer_nodes_dict, lower_layer_nodes = self._get_unique_neighs_list(lower_layer_nodes)
            # insert at the front so the outermost hop ends up first: the order
            # must be outermost neighbors, then their centers, then the batch,
            # i.e. nodes_batch_layers = [layer1, layer0, layer_center]
            nodes_batch_layers.insert(0, (lower_layer_nodes, lower_samp_neighs, lower_layer_nodes_dict))
        # num_layers hops plus the original batch
        assert len(nodes_batch_layers) == self.num_layers + 1

        # start from the raw features of all nodes
        pre_hidden_embs = self.raw_features
        for index in range(1, self.num_layers+1):
            # the nodes at this level of the hierarchy
            # (nodes_batch_layers = [layer1, layer0, layer_center])
            nb = nodes_batch_layers[index][0]
            # the sampled neighborhood one hop further out
            pre_neighs = nodes_batch_layers[index-1]
            # self.dc.logger.info('aggregate_feats.')
            # aggregate each node's neighbors together with the node itself
            aggregate_feats = self.aggregate(nb, pre_hidden_embs, pre_neighs)
            # pick the SageLayer for this level
            sage_layer = getattr(self, 'sage_layer'+str(index))
            if index > 1:
                # after the first layer, node ids must be remapped into row
                # indices of the previous layer's embedding matrix
                nb = self._nodes_map(nb, pre_hidden_embs, pre_neighs)
            # self.dc.logger.info('sage_layer.')
            # combine each node's own features with its aggregated neighborhood
            cur_hidden_embs = sage_layer(self_feats=pre_hidden_embs[nb],
                                         aggregate_feats=aggregate_feats)
            # e.g. after the first layer: [2157, 128], where 2157 is the number of involved nodes
            pre_hidden_embs = cur_hidden_embs
        return pre_hidden_embs

    def _nodes_map(self, nodes, hidden_embs, neighs):
        layer_nodes, samp_neighs, layer_nodes_dict = neighs
        assert len(samp_neighs) == len(nodes)
        # map node ids to the row indices used by the previous layer
        index = [layer_nodes_dict[x] for x in nodes]
        return index

    def _get_unique_neighs_list(self, nodes, num_sample=10):  # num_sample: neighbors sampled per node
        # nodes: the (extended) batch, e.g. 1024 node ids
        _set = set
        # adj_lists maps each node to the set of its neighbors
        to_neighs = [self.adj_lists[int(node)] for node in nodes]
        # num_sample caps how many neighbors are drawn; not all neighbors are used
        if not num_sample is None:
            _sample = random.sample
            # for each neighbor set: sample num_sample neighbors if the node has
            # at least that many, otherwise keep all of its neighbors
            # (note: random.sample on a set is deprecated in newer Python versions)
            samp_neighs = [_set(_sample(to_neigh, num_sample)) if len(to_neigh) >= num_sample else to_neigh for to_neigh in to_neighs]
        else:
            samp_neighs = to_neighs
        # add each node itself to its neighbor set, so aggregation includes the node's own information
        samp_neighs = [samp_neigh | set([nodes[i]]) for i, samp_neigh in enumerate(samp_neighs)]
        # union of all sampled neighbors in the batch, deduplicated
        _unique_nodes_list = list(set.union(*samp_neighs))
        # a fresh consecutive numbering
        i = list(range(len(_unique_nodes_list)))
        # index map: node id -> position in _unique_nodes_list
        unique_nodes = dict(list(zip(_unique_nodes_list, i)))
        # samp_neighs: sampled neighbor set per node;
        # unique_nodes: node-id-to-index dict;
        # _unique_nodes_list: all involved node ids
        return samp_neighs, unique_nodes, _unique_nodes_list

    def aggregate(self, nodes, pre_hidden_embs, pre_neighs, num_sample=10):
        # unpack the outer hop produced by _get_unique_neighs_list
        unique_nodes_list, samp_neighs, unique_nodes = pre_neighs
        assert len(nodes) == len(samp_neighs)
        # every node should be contained in its own sampled neighbor set
        indicator = [(nodes[i] in samp_neighs[i]) for i in range(len(samp_neighs))]
        assert (False not in indicator)
        if not self.gcn:
            # drop the center node itself from its neighbor set before aggregating
            samp_neighs = [(samp_neighs[i]-set([nodes[i]])) for i in range(len(samp_neighs))]
        # self.dc.logger.info('2')
        # if every node is involved, keep the whole embedding matrix;
        # otherwise slice out only the rows that are needed
        if len(pre_hidden_embs) == len(unique_nodes):
            embed_matrix = pre_hidden_embs
        else:
            embed_matrix = pre_hidden_embs[torch.LongTensor(unique_nodes_list)]
        # self.dc.logger.info('3')
        # zero matrix of shape [nodes at this hop, involved nodes one hop out]
        # (an adjacency restricted to the sampled nodes, not the full graph)
        mask = torch.zeros(len(samp_neighs), len(unique_nodes))
        column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
        row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))]
        # set connected entries to 1
        mask[row_indices, column_indices] = 1
        # self.dc.logger.info('4')
        if self.agg_func == 'MEAN':
            # per-row neighbor counts, keeping the input's dimensionality
            num_neigh = mask.sum(1, keepdim=True)
            # row-normalize so each row sums to 1
            mask = mask.div(num_neigh).to(embed_matrix.device)
            # matrix product = mean over each node's neighbor embeddings
            aggregate_feats = mask.mm(embed_matrix)
        elif self.agg_func == 'MAX':
            # print(mask)
            indexs = [x.nonzero() for x in mask==1]
            aggregate_feats = []
            # self.dc.logger.info('5')
            # element-wise max over each node's neighbor embeddings
            for feat in [embed_matrix[x.squeeze()] for x in indexs]:
                if len(feat.size()) == 1:
                    aggregate_feats.append(feat.view(1, -1))
                else:
                    aggregate_feats.append(torch.max(feat, 0)[0].view(1, -1))
            aggregate_feats = torch.cat(aggregate_feats, 0)
        # self.dc.logger.info('6')
        # return the aggregated features
        return aggregate_feats
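            The core of the MEAN branch is the row-normalized mask times the embedding matrix. A toy example with made-up numbers, mirroring the code above:

    # Toy illustration of the MEAN aggregation via the 0/1 mask.
    import torch

    # 2 batch nodes; their sampled neighbors index into 3 unique nodes
    embed_matrix = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
    mask = torch.tensor([[1., 1., 0.],    # node 0 aggregates unique nodes {0, 1}
                         [0., 1., 1.]])   # node 1 aggregates unique nodes {1, 2}
    mask = mask / mask.sum(1, keepdim=True)  # row-normalize -> mean weights
    aggregate_feats = mask.mm(embed_matrix)  # [[0.5, 0.5], [0.5, 1.0]]
    print(aggregate_feats)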

    5 The SageLayer: src/models.py

            Defining __init__():

    class SageLayer(nn.Module):
        """
        Encodes a node's representation using the 'convolutional' GraphSage approach
        """
        def __init__(self, input_size, out_size, gcn=False):
            super(SageLayer, self).__init__()
            self.input_size = input_size  # 1433
            self.out_size = out_size      # 128
            self.gcn = gcn
            # the learnable weight; with gcn=False the input is the concatenation
            # of a node's own features and its aggregated neighbor features
            # (see the forward below), hence 2 * input_size: torch.Size([128, 2866])
            self.weight = nn.Parameter(torch.FloatTensor(out_size, self.input_size if self.gcn else 2 * self.input_size))
            self.init_params()  # initialize the parameters

        def init_params(self):
            for param in self.parameters():
                nn.init.xavier_uniform_(param)

     6 The classifier model: src/models.py

            Defining the initializer __init__():

    class Classification(nn.Module):
        def __init__(self, emb_size, num_classes):
            super(Classification, self).__init__()
            # self.weight = nn.Parameter(torch.FloatTensor(emb_size, num_classes))
            self.layer = nn.Sequential(
                nn.Linear(emb_size, num_classes)
                # nn.ReLU()
            )
            self.init_params()

        def init_params(self):
            for param in self.parameters():
                if len(param.size()) == 2:
                    nn.init.xavier_uniform_(param)
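            The post omits Classification's forward(). Since the supervised loss in apply_model() below indexes log-probabilities directly, the classifier must end in a log_softmax; a sketch of what it looks like in the repo:

    # Sketch of Classification.forward(): the linear layer's output is turned
    # into log-probabilities, which the supervised loss indexes directly.
    def forward(self, embeds):
        logists = torch.log_softmax(self.layer(embeds), 1)
        return logists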

            Defining forward() (note: this forward belongs to SageLayer, not Classification; it combines a node's own features with the aggregated neighbor features):

    def forward(self, self_feats, aggregate_feats, neighs=None):
        """
        Generates embeddings for a batch of nodes.
        nodes -- list of nodes
        """
        if not self.gcn:
            # concatenate each node's own features with its aggregated neighbor features
            combined = torch.cat([self_feats, aggregate_feats], dim=1)
        else:
            combined = aggregate_feats
        # multiply by the learnable weight and apply ReLU
        combined = F.relu(self.weight.mm(combined.t())).t()
        return combined
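            A minimal usage sketch of SageLayer (shapes are hypothetical; it assumes the class definition and imports above):

    import torch

    layer = SageLayer(input_size=1433, out_size=128, gcn=False)
    self_feats = torch.randn(20, 1433)       # 20 batch nodes' own features
    aggregate_feats = torch.randn(20, 1433)  # their aggregated neighbor features
    out = layer(self_feats, aggregate_feats)
    print(out.shape)                         # torch.Size([20, 128])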

    7 The training function apply_model() in src/utils.py

    def apply_model(dataCenter, ds, graphSage, classification, unsupervised_loss, b_sz, unsup_loss, device, learn_method):
        # node index splits and labels
        test_nodes = getattr(dataCenter, ds+'_test')
        val_nodes = getattr(dataCenter, ds+'_val')
        train_nodes = getattr(dataCenter, ds+'_train')
        labels = getattr(dataCenter, ds+'_labels')

        # settings for the unsupervised loss; with supervised learning this
        # choice does not affect the loss itself
        if unsup_loss == 'margin':
            num_neg = 6
        elif unsup_loss == 'normal':
            num_neg = 100
        else:
            print("unsup_loss can be only 'margin' or 'normal'.")
            sys.exit(1)

        # shuffle the training set
        train_nodes = shuffle(train_nodes)

        # collect the trainable parameters (weights and biases) of both models
        models = [graphSage, classification]
        params = []
        for model in models:
            for param in model.parameters():
                if param.requires_grad:
                    params.append(param)

        optimizer = torch.optim.SGD(params, lr=0.7)
        optimizer.zero_grad()
        # reset model gradients
        for model in models:
            model.zero_grad()

        # number of iterations per epoch; b_sz defaults to 20
        batches = math.ceil(len(train_nodes) / b_sz)
        visited_nodes = set()
        for index in range(batches):
            # the nodes in this batch
            nodes_batch = train_nodes[index*b_sz:(index+1)*b_sz]
            # extend nodes batch for unsupervised learning (negative sampling);
            # no conflicts with supervised learning, it only enlarges the
            # batch, e.g. to 1024 nodes
            nodes_batch = np.asarray(list(unsupervised_loss.extend_nodes(nodes_batch, num_neg=num_neg)))
            visited_nodes |= set(nodes_batch)
            # get ground-truth labels for the nodes batch
            labels_batch = labels[nodes_batch]
            # feed the nodes batch to graphSAGE, returning the node embeddings,
            # e.g. [1024, 128]
            embs_batch = graphSage(nodes_batch)

            if learn_method == 'sup':
                # supervised learning: log-probabilities, e.g. [1024, 7]
                logists = classification(embs_batch)
                # negative log-likelihood of the true labels
                loss_sup = -torch.sum(logists[range(logists.size(0)), labels_batch], 0)
                loss_sup /= len(nodes_batch)
                loss = loss_sup
            elif learn_method == 'plus_unsup':
                # supervised part
                logists = classification(embs_batch)
                loss_sup = -torch.sum(logists[range(logists.size(0)), labels_batch], 0)
                loss_sup /= len(nodes_batch)
                # unsupervised part
                if unsup_loss == 'margin':
                    loss_net = unsupervised_loss.get_loss_margin(embs_batch, nodes_batch)
                elif unsup_loss == 'normal':
                    loss_net = unsupervised_loss.get_loss_sage(embs_batch, nodes_batch)
                loss = loss_sup + loss_net
            else:
                if unsup_loss == 'margin':
                    loss_net = unsupervised_loss.get_loss_margin(embs_batch, nodes_batch)
                elif unsup_loss == 'normal':
                    loss_net = unsupervised_loss.get_loss_sage(embs_batch, nodes_batch)
                loss = loss_net

            print('Step [{}/{}], Loss: {:.4f}, Dealed Nodes [{}/{}] '.format(index+1, batches, loss.item(), len(visited_nodes), len(train_nodes)))
            loss.backward()
            for model in models:
                nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()
            optimizer.zero_grad()
            for model in models:
                model.zero_grad()
        return graphSage, classification
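            The supervised loss above is plain negative log-likelihood: pick each node's log-probability at its true label and average over the batch. A small equivalence check (a sketch with random inputs):

    import torch
    import torch.nn.functional as F

    logists = torch.log_softmax(torch.randn(4, 7), dim=1)  # fake batch of 4 nodes, 7 classes
    labels = torch.tensor([0, 3, 6, 2])
    loss_manual = -torch.sum(logists[range(4), labels], 0) / 4
    loss_builtin = F.nll_loss(logists, labels)              # same value
    assert torch.allclose(loss_manual, loss_builtin)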

    For more detail, see the Bilibili video that walks through the source code in a debugger; it explains things very clearly!

    16. 4.3_GraphSAGE代码_哔哩哔哩_bilibili

  • Original post: https://blog.csdn.net/qq_40671063/article/details/126757290