• Transductive link prediction on a homogeneous graph with PyG (PyTorch Geometric)


    诸神缄默不语: index of personal CSDN blog posts
    Study notes on the PyTorch Geometric (PyG) documentation and official code examples (continuously updated…)

    The code in this post is based on the official PyG example: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/link_pred.py

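    1. Getting the data

    This post uses the Cora dataset that ships with PyG. If your environment can reach the download servers directly, the code below runs as is. If it cannot, see my earlier post on downloading the data manually: "3 ways to work around PyG's Planetoid failing to download Cora and similar datasets".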

    2. Data preprocessing

    All preprocessing is done while loading the data, by passing PyG transform classes to the dataset:

    1. Row-normalize the node features (T.NormalizeFeatures(); docs: https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.NormalizeFeatures; source: torch_geometric.transforms.normalize_features) so that each row sums to 1. Concretely, the minimum value is subtracted from every element, and each row is then divided by its sum, with the sum clamped to a minimum of 1.
    2. Move the dataset to the target device, the GPU if one is available (T.ToDevice(device); docs: https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.ToDevice).
    3. Split the dataset for link prediction:
      T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, add_negative_train_samples=False)
      Docs: https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.RandomLinkSplit
      The training set contains none of the validation or test edges, and the validation set contains none of the test edges. Note that this code is transductive, so the three resulting splits all share the same nodes.
      Each element of the returned dataset is a tuple of three Data objects (train_data, val_data, test_data).
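
    To make the arithmetic of the row normalization in step 1 concrete, here is a plain-Python sketch on a toy feature matrix. It is not the PyG implementation: the function name normalize_rows and the toy matrix are made up for this example, and it assumes the minimum is taken over the whole matrix.

```python
def normalize_rows(features):
    """Subtract the overall minimum, then divide each row by its sum.

    The row sum is clamped to a minimum of 1, so an all-zero row stays
    all-zero instead of causing a division by zero.
    """
    lowest = min(value for row in features for value in row)
    shifted = [[value - lowest for value in row] for row in features]
    return [[value / max(sum(row), 1.0) for value in row] for row in shifted]


# Toy bag-of-words features: 3 nodes, 4 binary word indicators.
toy_x = [
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 1.0, 1.0, 1.0],
]
normalized = normalize_rows(toy_x)
for row in normalized:
    print(row)
```

    The first and third rows end up summing to 1, while the empty second row is left untouched thanks to the clamp.
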
    import torch
    
    import torch_geometric.transforms as T
    from torch_geometric.datasets import Planetoid
    
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    transform = T.Compose([
        T.NormalizeFeatures(),
        T.ToDevice(device),
        T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                          add_negative_train_samples=False),
    ])
    dataset = Planetoid('pyg_data/Planetoid', name='Cora', transform=transform)
    print(type(dataset))
    train_data, val_data, test_data = dataset[0]
    print(type(train_data))
    print(train_data.num_nodes)
    print(val_data.num_nodes)
    print(test_data.num_nodes)
    

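    Output:

    <class 'torch_geometric.datasets.planetoid.Planetoid'>
    <class 'torch_geometric.data.data.Data'>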
    2708
    2708
    2708
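
    All three splits report the same 2708 nodes because a transductive split partitions only the edges. The plain-Python sketch below illustrates that idea on a toy ring graph. It is a simplification and not how RandomLinkSplit is implemented; split_edges, the seed, and the ring graph are assumptions made for this example.

```python
import random


def split_edges(edges, num_val=0.05, num_test=0.1, seed=0):
    """Partition undirected edges into train/val/test supervision sets."""
    shuffled = list(edges)
    random.Random(seed).shuffle(shuffled)
    n_val = int(num_val * len(shuffled))
    n_test = int(num_test * len(shuffled))
    val = shuffled[:n_val]
    test = shuffled[n_val:n_val + n_test]
    train = shuffled[n_val + n_test:]
    return train, val, test


# Toy undirected graph: a ring over 40 nodes, one (u, v) pair per edge.
num_nodes = 40
edges = [(i, (i + 1) % num_nodes) for i in range(num_nodes)]
train, val, test = split_edges(edges)

# Edges each split may use for message passing: validation only sees
# training edges, and testing sees training plus validation edges.
message_edges = {
    "train": train,
    "val": train,
    "test": train + val,
}
print(len(train), len(val), len(test))  # 34 2 4
```

    Every split still refers to the same 40 node ids; only the set of edges available for message passing and for supervision changes.
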
    

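    3. Building the link prediction model

    1. encode(): computes node representations with a 2-layer GCN and a ReLU activation in between. No other tricks.
    2. decode(): used during training and evaluation. It scores only the node pairs listed in edge_label_index, computing each dot product as the sum of an element-wise product.
    3. decode_all(): used for inference after training. It scores every node pair in the graph at once, computing the dot products with a matrix multiplication. Any pair whose score (logit) is greater than 0, i.e. whose sigmoid probability is above 0.5, is treated as an edge, and the function returns the edge list of all such pairs.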
    import torch
    from torch_geometric.nn import GCNConv
    
    class Net(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GCNConv(in_channels, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, out_channels)
    
        def encode(self, x, edge_index):
            x = self.conv1(x, edge_index).relu()
            return self.conv2(x, edge_index)
    
        def decode(self, z, edge_label_index):
            return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)
    
        def decode_all(self, z):
            prob_adj = z @ z.t()
            return (prob_adj > 0).nonzero(as_tuple=False).t()
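
    The two decoders can be mimicked without any tensors. The sketch below is a plain-Python illustration on hypothetical 2-dimensional embeddings (the values of z are invented); it shows that thresholding the logit at 0 is the same as thresholding the sigmoid probability at 0.5.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def decode(z, edge_label_index):
    """Score only the listed (source, target) pairs."""
    sources, targets = edge_label_index
    return [dot(z[s], z[t]) for s, t in zip(sources, targets)]


def decode_all(z):
    """Return every ordered pair whose logit is positive."""
    n = len(z)
    return [(i, j) for i in range(n) for j in range(n) if dot(z[i], z[j]) > 0]


# Hypothetical embeddings for 3 nodes.
z = [[1.0, 0.5], [0.8, -0.2], [-1.0, 0.3]]
logits = decode(z, [[0, 0], [1, 2]])  # pairs (0, 1) and (0, 2)
probs = [sigmoid(x) for x in logits]
print(logits, probs)
print(decode_all(z))
```

    Note that the all-pairs version returns both directions of each pair as well as self-pairs such as (0, 0), because a non-zero embedding always has a positive dot product with itself.
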
    

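    4. Instantiating the model, optimizer, and loss function

    Link prediction is usually modeled as a binary classification task (whether an edge exists or not), so torch.nn.BCEWithLogitsLoss() is a common choice.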

    model = Net(dataset.num_features, 128, 64).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
    criterion = torch.nn.BCEWithLogitsLoss()
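
    For readers who want to see what a "binary cross-entropy with logits" loss computes, here is a plain-Python sketch of the mean loss in a numerically stable form. It illustrates the formula only and is not PyTorch's code; both function names are made up for this example.

```python
import math


def bce_with_logits(logits, labels):
    """Mean binary cross-entropy computed directly from raw logits.

    Uses the stable form max(x, 0) - x * y + log(1 + exp(-|x|)), which is
    algebraically equal to -[y * log(p) + (1 - y) * log(1 - p)] with
    p = sigmoid(x).
    """
    total = 0.0
    for x, y in zip(logits, labels):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def bce_from_probabilities(logits, labels):
    """Textbook form: apply the sigmoid first, then the cross-entropy."""
    total = 0.0
    for x, y in zip(logits, labels):
        p = 1.0 / (1.0 + math.exp(-x))
        total -= y * math.log(p) + (1.0 - y) * math.log(1.0 - p)
    return total / len(logits)


logits = [2.0, -1.0, 0.5, -3.0]
labels = [1.0, 0.0, 0.0, 1.0]
print(bce_with_logits(logits, labels))
print(bce_from_probabilities(logits, labels))

# A model that outputs logit 0 for every pair scores ln(2).
print(bce_with_logits([0.0, 0.0], [1.0, 0.0]))
```

    The last value, ln 2 ≈ 0.693, is what an uninformative model scores on balanced labels, which is consistent with the loss of about 0.69 reported for the first training epochs in the output of section 7.
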
    

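    5. Building the training function

    The training function is called once per epoch.
    On the training set, it first computes node representations with the GNN, then calls negative_sampling (docs: https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.negative_sampling) to sample as many negative edges as there are positive edges, and computes the loss over both.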

    from torch_geometric.utils import negative_sampling
    
    def train():
        model.train()
        optimizer.zero_grad()
        z = model.encode(train_data.x, train_data.edge_index)
    
        # We perform a new round of negative sampling for every training epoch:
        neg_edge_index = negative_sampling(
            edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
            num_neg_samples=train_data.edge_label_index.size(1), method='sparse')
    
        edge_label_index = torch.cat(
            [train_data.edge_label_index, neg_edge_index],
            dim=-1,
        )
        edge_label = torch.cat([
            train_data.edge_label,
            train_data.edge_label.new_zeros(neg_edge_index.size(1))
        ], dim=0)
    
        out = model.decode(z, edge_label_index).view(-1)
        loss = criterion(out, edge_label)
        loss.backward()
        optimizer.step()
        return loss
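
    The idea behind negative sampling can be sketched in plain Python: draw random node pairs, reject the ones that are already edges, and label the survivors 0. This is a simplified rejection sampler, not the algorithm PyG uses; sample_negative_edges, the seed, and the toy graph are assumptions made for illustration.

```python
import random


def sample_negative_edges(edges, num_nodes, num_neg_samples, seed=0):
    """Draw distinct node pairs that are not edges of an undirected graph."""
    rng = random.Random(seed)
    existing = set()
    for u, v in edges:
        existing.add((u, v))
        existing.add((v, u))
    # Ordered pairs without self-loops that are still free to be sampled.
    available = num_nodes * (num_nodes - 1) - sum(1 for u, v in existing if u != v)
    if num_neg_samples > available:
        raise ValueError("not enough non-edges to sample from")
    negatives = set()
    while len(negatives) < num_neg_samples:
        u = rng.randrange(num_nodes)
        v = rng.randrange(num_nodes)
        if u != v and (u, v) not in existing:
            negatives.add((u, v))
    return sorted(negatives)


# Toy path graph on 6 nodes with 4 positive edges.
pos_edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
neg_edges = sample_negative_edges(pos_edges, num_nodes=6,
                                  num_neg_samples=len(pos_edges))

# Positives are labeled 1 and the freshly sampled negatives 0.
edge_label_pairs = pos_edges + neg_edges
edge_label = [1.0] * len(pos_edges) + [0.0] * len(neg_edges)
print(neg_edges)
print(edge_label)
```

    Calling the sampler with a different seed each epoch corresponds to the "new round of negative sampling for every training epoch" mentioned in the comment of the training function.
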
    

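    6. Building the per-epoch evaluation function

    Personally, I prefer with torch.no_grad() to the decorator.

    Called once per epoch.
    It scores the edges in data.edge_label_index (positive and negative), passes the scores through a sigmoid to obtain edge probabilities, and uses them to compute the ROC AUC.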

    from sklearn.metrics import roc_auc_score
    
    @torch.no_grad()
    def test(data):
        model.eval()
        z = model.encode(data.x, data.edge_index)
        out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
        return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())
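
    ROC AUC can be read as the probability that a randomly chosen positive edge receives a higher score than a randomly chosen negative one. The plain-Python sketch below computes it that way on made-up labels and scores; it is only meant to illustrate the metric (the function roc_auc is hypothetical), while the tutorial itself relies on sklearn's roc_auc_score.

```python
import math


def roc_auc(labels, scores):
    """Fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a win. The O(P * N) double loop is fine for a toy
    example but too slow for real evaluation sets.
    """
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("need at least one positive and one negative label")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(roc_auc(labels, scores))  # 0.75

# The sigmoid is monotonic, so raw logits give the same ranking and AUC.
logits = [-2.0, 0.3, -0.1, 1.5]
probs = [1.0 / (1.0 + math.exp(-x)) for x in logits]
print(roc_auc(labels, logits) == roc_auc(labels, probs))  # True
```

    Because the metric depends only on the ranking of the scores, applying the sigmoid before computing AUC does not change the result; it simply makes the scores readable as probabilities.
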
    

    7. Training and testing

    Train for 100 epochs, then obtain every edge that the model predicts to exist, using the test split's graph.

    best_val_auc = final_test_auc = 0
    for epoch in range(1, 101):
        loss = train()
        val_auc = test(val_data)
        test_auc = test(test_data)
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            final_test_auc = test_auc
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
              f'Test: {test_auc:.4f}')
    
    print(f'Final Test: {final_test_auc:.4f}')
    
    z = model.encode(test_data.x, test_data.edge_index)
    final_edge_index = model.decode_all(z)
    
    print(final_edge_index.size())
    

    Output:

    Epoch: 001, Loss: 0.6930, Val: 0.6729, Test: 0.7026
    Epoch: 002, Loss: 0.6820, Val: 0.6589, Test: 0.6913
    Epoch: 003, Loss: 0.7065, Val: 0.6619, Test: 0.6967
    Epoch: 004, Loss: 0.6766, Val: 0.6686, Test: 0.7069
    Epoch: 005, Loss: 0.6842, Val: 0.6716, Test: 0.7128
    Epoch: 006, Loss: 0.6876, Val: 0.6637, Test: 0.7132
    Epoch: 007, Loss: 0.6881, Val: 0.6471, Test: 0.7009
    Epoch: 008, Loss: 0.6867, Val: 0.6317, Test: 0.6859
    Epoch: 009, Loss: 0.6829, Val: 0.6240, Test: 0.6767
    Epoch: 010, Loss: 0.6765, Val: 0.6223, Test: 0.6720
    Epoch: 011, Loss: 0.6715, Val: 0.6208, Test: 0.6684
    Epoch: 012, Loss: 0.6759, Val: 0.6204, Test: 0.6640
    Epoch: 013, Loss: 0.6687, Val: 0.6272, Test: 0.6656
    Epoch: 014, Loss: 0.6621, Val: 0.6488, Test: 0.6778
    Epoch: 015, Loss: 0.6593, Val: 0.6748, Test: 0.6907
    Epoch: 016, Loss: 0.6534, Val: 0.6824, Test: 0.6923
    Epoch: 017, Loss: 0.6477, Val: 0.6796, Test: 0.6867
    Epoch: 018, Loss: 0.6389, Val: 0.6847, Test: 0.6888
    Epoch: 019, Loss: 0.6332, Val: 0.7155, Test: 0.7115
    Epoch: 020, Loss: 0.6217, Val: 0.7487, Test: 0.7430
    Epoch: 021, Loss: 0.6060, Val: 0.7645, Test: 0.7582
    Epoch: 022, Loss: 0.5993, Val: 0.7650, Test: 0.7574
    Epoch: 023, Loss: 0.5837, Val: 0.7632, Test: 0.7550
    Epoch: 024, Loss: 0.5719, Val: 0.7612, Test: 0.7530
    Epoch: 025, Loss: 0.5654, Val: 0.7565, Test: 0.7518
    Epoch: 026, Loss: 0.5697, Val: 0.7574, Test: 0.7534
    Epoch: 027, Loss: 0.5676, Val: 0.7610, Test: 0.7576
    Epoch: 028, Loss: 0.5551, Val: 0.7629, Test: 0.7634
    Epoch: 029, Loss: 0.5446, Val: 0.7682, Test: 0.7723
    Epoch: 030, Loss: 0.5422, Val: 0.7774, Test: 0.7848
    Epoch: 031, Loss: 0.5259, Val: 0.7896, Test: 0.7988
    Epoch: 032, Loss: 0.5277, Val: 0.8005, Test: 0.8127
    Epoch: 033, Loss: 0.5218, Val: 0.8135, Test: 0.8245
    Epoch: 034, Loss: 0.5156, Val: 0.8234, Test: 0.8342
    Epoch: 035, Loss: 0.5057, Val: 0.8285, Test: 0.8414
    Epoch: 036, Loss: 0.4981, Val: 0.8314, Test: 0.8462
    Epoch: 037, Loss: 0.4984, Val: 0.8302, Test: 0.8459
    Epoch: 038, Loss: 0.4960, Val: 0.8332, Test: 0.8489
    Epoch: 039, Loss: 0.4873, Val: 0.8381, Test: 0.8555
    Epoch: 040, Loss: 0.4883, Val: 0.8418, Test: 0.8609
    Epoch: 041, Loss: 0.4993, Val: 0.8427, Test: 0.8615
    Epoch: 042, Loss: 0.4852, Val: 0.8452, Test: 0.8616
    Epoch: 043, Loss: 0.4718, Val: 0.8474, Test: 0.8640
    Epoch: 044, Loss: 0.4768, Val: 0.8492, Test: 0.8679
    Epoch: 045, Loss: 0.4708, Val: 0.8472, Test: 0.8688
    Epoch: 046, Loss: 0.4726, Val: 0.8457, Test: 0.8680
    Epoch: 047, Loss: 0.4729, Val: 0.8500, Test: 0.8698
    Epoch: 048, Loss: 0.4726, Val: 0.8517, Test: 0.8705
    Epoch: 049, Loss: 0.4730, Val: 0.8527, Test: 0.8722
    Epoch: 050, Loss: 0.4715, Val: 0.8521, Test: 0.8734
    Epoch: 051, Loss: 0.4667, Val: 0.8547, Test: 0.8756
    Epoch: 052, Loss: 0.4609, Val: 0.8577, Test: 0.8784
    Epoch: 053, Loss: 0.4632, Val: 0.8607, Test: 0.8829
    Epoch: 054, Loss: 0.4612, Val: 0.8626, Test: 0.8862
    Epoch: 055, Loss: 0.4591, Val: 0.8646, Test: 0.8878
    Epoch: 056, Loss: 0.4568, Val: 0.8644, Test: 0.8874
    Epoch: 057, Loss: 0.4569, Val: 0.8656, Test: 0.8874
    Epoch: 058, Loss: 0.4568, Val: 0.8688, Test: 0.8897
    Epoch: 059, Loss: 0.4516, Val: 0.8721, Test: 0.8929
    Epoch: 060, Loss: 0.4567, Val: 0.8729, Test: 0.8942
    Epoch: 061, Loss: 0.4625, Val: 0.8742, Test: 0.8938
    Epoch: 062, Loss: 0.4547, Val: 0.8729, Test: 0.8919
    Epoch: 063, Loss: 0.4479, Val: 0.8723, Test: 0.8927
    Epoch: 064, Loss: 0.4517, Val: 0.8728, Test: 0.8962
    Epoch: 065, Loss: 0.4517, Val: 0.8719, Test: 0.8972
    Epoch: 066, Loss: 0.4538, Val: 0.8726, Test: 0.8962
    Epoch: 067, Loss: 0.4532, Val: 0.8718, Test: 0.8944
    Epoch: 068, Loss: 0.4540, Val: 0.8725, Test: 0.8937
    Epoch: 069, Loss: 0.4542, Val: 0.8734, Test: 0.8953
    Epoch: 070, Loss: 0.4487, Val: 0.8726, Test: 0.8967
    Epoch: 071, Loss: 0.4497, Val: 0.8727, Test: 0.8973
    Epoch: 072, Loss: 0.4539, Val: 0.8694, Test: 0.8949
    Epoch: 073, Loss: 0.4478, Val: 0.8703, Test: 0.8937
    Epoch: 074, Loss: 0.4449, Val: 0.8737, Test: 0.8945
    Epoch: 075, Loss: 0.4486, Val: 0.8770, Test: 0.8968
    Epoch: 076, Loss: 0.4491, Val: 0.8724, Test: 0.8970
    Epoch: 077, Loss: 0.4431, Val: 0.8678, Test: 0.8957
    Epoch: 078, Loss: 0.4447, Val: 0.8688, Test: 0.8952
    Epoch: 079, Loss: 0.4540, Val: 0.8704, Test: 0.8943
    Epoch: 080, Loss: 0.4548, Val: 0.8741, Test: 0.8955
    Epoch: 081, Loss: 0.4468, Val: 0.8746, Test: 0.8985
    Epoch: 082, Loss: 0.4495, Val: 0.8727, Test: 0.8994
    Epoch: 083, Loss: 0.4473, Val: 0.8708, Test: 0.8990
    Epoch: 084, Loss: 0.4464, Val: 0.8715, Test: 0.8976
    Epoch: 085, Loss: 0.4376, Val: 0.8755, Test: 0.8977
    Epoch: 086, Loss: 0.4455, Val: 0.8762, Test: 0.8993
    Epoch: 087, Loss: 0.4442, Val: 0.8727, Test: 0.9004
    Epoch: 088, Loss: 0.4411, Val: 0.8726, Test: 0.9009
    Epoch: 089, Loss: 0.4445, Val: 0.8760, Test: 0.9010
    Epoch: 090, Loss: 0.4474, Val: 0.8780, Test: 0.9002
    Epoch: 091, Loss: 0.4468, Val: 0.8754, Test: 0.9009
    Epoch: 092, Loss: 0.4470, Val: 0.8712, Test: 0.9015
    Epoch: 093, Loss: 0.4467, Val: 0.8680, Test: 0.9006
    Epoch: 094, Loss: 0.4454, Val: 0.8720, Test: 0.9019
    Epoch: 095, Loss: 0.4355, Val: 0.8761, Test: 0.9028
    Epoch: 096, Loss: 0.4486, Val: 0.8749, Test: 0.9013
    Epoch: 097, Loss: 0.4418, Val: 0.8695, Test: 0.8999
    Epoch: 098, Loss: 0.4396, Val: 0.8651, Test: 0.9002
    Epoch: 099, Loss: 0.4365, Val: 0.8684, Test: 0.9034
    Epoch: 100, Loss: 0.4428, Val: 0.8720, Test: 0.9050
    Final Test: 0.9002
    torch.Size([2, 3262820])
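
    Two numbers in this output can be checked by hand. The sketch below replays the model-selection rule of the training loop on a few rows copied from the log, and then works out what share of all node pairs the final edge list covers. It is only a sanity check in plain Python; the variable names are chosen for this example.

```python
# (epoch, val_auc, test_auc) rows copied from the log above.
log_rows = [
    (88, 0.8726, 0.9009),
    (89, 0.8760, 0.9010),
    (90, 0.8780, 0.9002),
    (91, 0.8754, 0.9009),
    (92, 0.8712, 0.9015),
    (100, 0.8720, 0.9050),
]

best_val_auc = final_test_auc = 0.0
best_epoch = None
for epoch, val_auc, test_auc in log_rows:
    # Same rule as the training loop: keep the test AUC of the epoch
    # with the highest validation AUC seen so far.
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
        best_epoch = epoch
print(best_epoch, final_test_auc)  # 90 0.9002

# Share of all ordered node pairs (self-pairs included) predicted as edges.
num_nodes = 2708
num_predicted = 3262820
fraction = num_predicted / num_nodes ** 2
print(f"{fraction:.3f}")  # 0.445
```

    The reported "Final Test" value is therefore the test AUC at epoch 90, the epoch with the best validation AUC, not the highest test AUC in the log. The second figure shows that thresholding the logits at 0 marks roughly 44% of all node pairs as edges, far more than the 10,556 directed edges of Cora in PyG, so decode_all() is better read as a demonstration than as a calibrated predictor.
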
    

    8. Full code

    import torch
    from sklearn.metrics import roc_auc_score
    
    import torch_geometric.transforms as T
    from torch_geometric.datasets import Planetoid
    from torch_geometric.nn import GCNConv
    from torch_geometric.utils import negative_sampling
    
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    transform = T.Compose([
        T.NormalizeFeatures(),
        T.ToDevice(device),
        T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                          add_negative_train_samples=False),
    ])
    dataset = Planetoid('/data/pyg_data/Planetoid', name='Cora', transform=transform)
    train_data, val_data, test_data = dataset[0]
    
    
    class Net(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GCNConv(in_channels, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, out_channels)
    
        def encode(self, x, edge_index):
            x = self.conv1(x, edge_index).relu()
            return self.conv2(x, edge_index)
    
        def decode(self, z, edge_label_index):
            return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)
    
        def decode_all(self, z):
            prob_adj = z @ z.t()
            return (prob_adj > 0).nonzero(as_tuple=False).t()
    
    
    model = Net(dataset.num_features, 128, 64).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
    criterion = torch.nn.BCEWithLogitsLoss()
    
    
    def train():
        model.train()
        optimizer.zero_grad()
        z = model.encode(train_data.x, train_data.edge_index)
    
        # We perform a new round of negative sampling for every training epoch:
        neg_edge_index = negative_sampling(
            edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
            num_neg_samples=train_data.edge_label_index.size(1), method='sparse')
    
        edge_label_index = torch.cat(
            [train_data.edge_label_index, neg_edge_index],
            dim=-1,
        )
        edge_label = torch.cat([
            train_data.edge_label,
            train_data.edge_label.new_zeros(neg_edge_index.size(1))
        ], dim=0)
    
        out = model.decode(z, edge_label_index).view(-1)
        loss = criterion(out, edge_label)
        loss.backward()
        optimizer.step()
        return loss
    
    
    @torch.no_grad()
    def test(data):
        model.eval()
        z = model.encode(data.x, data.edge_index)
        out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
        return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())
    
    
    best_val_auc = final_test_auc = 0
    for epoch in range(1, 101):
        loss = train()
        val_auc = test(val_data)
        test_auc = test(test_data)
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            final_test_auc = test_auc
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
              f'Test: {test_auc:.4f}')
    
    print(f'Final Test: {final_test_auc:.4f}')
    
    z = model.encode(test_data.x, test_data.edge_index)
    final_edge_index = model.decode_all(z)
    
    print(final_edge_index.size())
    
  • Original article: https://blog.csdn.net/PolarisRisingWar/article/details/126939852