• 【GNN】用 GCN 预测 CoraGraphDataset 结点类别


    Cora Dataset 是 DGL 自带的一个论文数据集。

    word_attributes 是一个维度为 1433词向量,词向量的每个元素对应一个词,0表示该元素对应的词不在Paper中,1表示该元素对应的词在Paper中。

    class_label 是论文的类别,每篇 Paper 被映射到如下 7 个分类之一:
    Case_Based、Genetic_Algorithms、Neural_Networks、Probabilistic_Methods、Reinforcement_Learning、Rule_Learning、Theory。

    目标:利用 GCN,在数据集上进行训练,根据 图结构图结点的特征 g.ndata['feat'] 特征字段(1433 维词向量),对 label = g.ndata['label']代表论文所属类别)进行预测。

    查看数据集

    from dgl.data import CoraGraphDataset
    dataset = CoraGraphDataset()
    
    • 1
    • 2

    NumNodes: 2708 结点数量
    NumEdges: 10556 边数量
    NumFeats: 1433 特征维度
    NumClasses: 7 类别个数
    NumTrainingSamples: 140
    NumValidationSamples: 500
    NumTestSamples: 1000

    G = dataset[0]  # dgl的获取全图的方式
    G
    
    • 1
    • 2
    Graph(num_nodes=2708, num_edges=10556,
          ndata_schemes={'train_mask': Scheme(shape=(), dtype=torch.bool), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'feat': Scheme(shape=(1433,), dtype=torch.float32)}
          edata_schemes={})
    
    • 1
    • 2
    • 3

    可视化

    import networkx as nx
    
    nx_G = G.to_networkx().to_undirected() # 转无向图
    
    # pos = nx.kamada_kawai_layout(nx_G)       # 适合小图的布局
    pos = nx.fruchterman_reingold_layout(nx_G) # 适合于大图的布局(电子布局)
    
    nx.draw(nx_G, pos, with_labels=True, node_color=[[.7, .7, .7]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    在这里插入图片描述

    GCN 模型训练

    GCN.py 文件

    import torch
    import torch.nn as nn
    from dgl.nn.pytorch import GraphConv
    
    class GCN(nn.Module):
        def __init__(self,
                     g,
                     in_feats,
                     n_hidden,
                     n_classes,
                     n_layers,
                     activation,
                     dropout):
            super(GCN, self).__init__()
            self.g = g
            self.layers = nn.ModuleList()
            # input layer
            
            self.layers.append(GraphConv(in_feats, n_hidden, activation=activation))
            # hidden layers
            for i in range(n_layers - 1):
                self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation))
                
            # output layer
            self.layers.append(GraphConv(n_hidden, n_classes))
            
            self.dropout = nn.Dropout(p=dropout)
    
    
        # 前向传播 ###############################
    
        def forward(self, features):
            h = features
            for i, layer in enumerate(self.layers):
                if i != 0:
                    h = self.dropout(h) # 每经过一个卷积层都dropout
                                        # (除了第一个不drop)
                h = layer(self.g, h)
            return h
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39

    main.py

    import argparse
    import time
    import numpy as np
    import torch
    import torch.nn.functional as F
    import dgl
    from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
    
    from gcn import GCN          # 不使用消息传递
    # from gcn_mp import GCN      # 使用消息传递
    
    def evaluate(model, features, labels, mask):
        model.eval()
        with torch.no_grad():
            logits = model(features)
            logits = logits[mask]
            labels = labels[mask]
            _, indices = torch.max(logits, dim=1)
            correct = torch.sum(indices == labels)
            return correct.item() * 1.0 / len(labels)
    
    if __name__ == '__main__':
       
        # 获得图数据 ###################################
    
        data = CoraGraphDataset()
        g = data[0]
        cuda = False
    
        # 获得数据集的一些特征 ##########################
    
        features = g.ndata['feat']
        labels = g.ndata['label']
        train_mask = g.ndata['train_mask']
        val_mask = g.ndata['val_mask']
        test_mask = g.ndata['test_mask']
        in_feats = features.shape[1]
        n_classes = data.num_labels
        n_edges = data.graph.number_of_edges()
    
        n_edges = g.number_of_edges()
    
        # 正则化 #######################################
        
        degs = g.in_degrees().float()
        norm = torch.pow(degs, -0.5)      #   1/\sqrt(degreee) 
        norm[torch.isinf(norm)] = 0
    
        g.ndata['norm'] = norm.unsqueeze(1) # 升一维,变成二维torch.Tensor
    
        ################## 设置网络参数,准备建立网络 ####################
          
        n_hidden = 16
        n_layers = 1
        dropout = 0.5         # 最后一层Graph Conv的dropout概率
        lr = 1e-2
        weight_decay = 5e-4   # L2正则化权重 (Weight for L2 loss)
        n_epochs = 200
    
        model = GCN(g,
                    in_feats,
                    n_hidden,
                    n_classes,
                    n_layers,
                    F.relu,
                    dropout)
    
        loss_fcn = torch.nn.CrossEntropyLoss()
    
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)
    
        ################ 准备开始训练 ###################################
    
        # initialize graph
        dur = []
        for epoch in range(n_epochs):
            model.train()
            if epoch >= 3:
                t0 = time.time()
                
                
            # 前向计算
            logits = model(features) # logits就是最终的全连接层的输出
            
            # 计算loss
            # 因为数据集本身就是一个大图,所以没办法像传统的方法那样
            # 划分数据集,所以只能以mask列表(里面是True/False)
            # 的形式,来取得对应的预测值和真实标签list的一部分
        
            loss = loss_fcn(logits[train_mask], labels[train_mask])
    
            # 经典三连
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            # 训练时间计时
    
            if epoch >= 3:
                dur.append(time.time() - t0)
    
            # 在测试集上评估准确率 acc
        
            acc = evaluate(model, features, labels, val_mask)
            print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
                  "ETputs(KTEPS) {:.2f}". format(epoch, np.mean(dur), loss.item(),
                                                 acc, n_edges / np.mean(dur) / 1000))
    
        # 在训练集上评估准确率 acc
        
        print()
        acc = evaluate(model, features, labels, test_mask)
        print("Test accuracy {:.2%}".format(acc))
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    ...
    Epoch 00198 | Time(s) 0.0199 | Loss 0.3593 | Accuracy 0.7860 | ETputs(KTEPS) 529.47
    Epoch 00199 | Time(s) 0.0199 | Loss 0.3676 | Accuracy 0.7840 | ETputs(KTEPS) 529.60
    
    Test accuracy 80.00%
    
    • 1
    • 2
    • 3
    • 4
    • 5
  • 相关阅读:
    Windows 程序安装与更新方案: Clowd.Squirrel
    深度学习使用Keras进行迁移学习提升网络性能
    【U-Boot笔记整理】U-Boot 完全分析与移植
    计算摄影——图像美学评分
    《web课程设计》用HTML CSS做一个简洁、漂亮的个人博客网站
    Python编程从入门到实践 第二章:变量和简单数据类型 练习答案记录
    java技术:knife4j实现后端swagger文档
    Vue框架--Vue中el和data的两种写法
    fiddler 手机抓包
    常用正则表达式
  • 原文地址:https://blog.csdn.net/qq_18846849/article/details/127789314