• 【PyG】理解MessagePassing过程,GCN demo详解


    参考:
    PyG利用MessagePassing实现GCN(了解pyG的底层逻辑)
    PyG官方demo GCN

    PyG的信息传递机制

    PyG提供了信息传递(邻居聚合) 操作的框架模型。

    x i k = γ k ( x i k − 1 , □ j ∈ N ( i ) ϕ ( x i k − 1 , x j k − 1 , e j , i ) ) x_i^k = \gamma^k(x_i^{k-1}, \square_{j \in \mathcal{N}(i)} \phi(x_i^{k-1},x_j^{k-1},e_{j,i})) xik=γk(xik1,jN(i)ϕ(xik1,xjk1,ej,i))
    其中,
    □ \square 表示 可微、排列不变 的函数,比如说summeanmax
    γ \gamma γ ϕ \phi ϕ 表示 可微 的函数,比如说 MLP

    propagate中,依次会调用messageaggregateupdate函数。
    其中,
    message为公式中 ϕ \phi ϕ 部分
    aggregate为公式中 □ \square 部分
    update为公式中 γ \gamma γ 部分

    MessagePassing Class

    PyG使用MessagePassing类作为实现 信息传递 机制的基类。我们只需要继承其即可。

    GCN demo

    GCN信息传递公式如下:
    x i k = ∑ j ∈ i ∪ { i } 1 d e g ( i ) ⋅ d e g ( j ) ⋅ ( Θ T ⋅ x j k − 1 ) x_i^k = \sum_{j \in \mathcal{i} \cup \{i\}} {1 \over \sqrt{\mathrm{deg}(i)} \cdot \sqrt{\mathrm{deg}(j)} } \cdot (\Theta^T \cdot x_j^{k-1}) xik=ji{i}deg(i) deg(j) 1(ΘTxjk1)

    注:GCN是运行在 无向图 上的。

    1. 导入头文件

    from typing import Optional
    from torch_scatter import scatter
    import torch
    import numpy as np
    import random
    import os
    from torch import Tensor
    from torch_geometric.nn import MessagePassing
    from torch_geometric.utils import add_self_loops, degree
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    2. 构造函数

    class GCNConv(MessagePassing):
        def __init__(self, in_channels, out_channels):
            super().__init__(aggr='add')  # "Add" aggregation (Step 5).
            self.lin = torch.nn.Linear(in_channels, out_channels)
    
    • 1
    • 2
    • 3
    • 4

    定义类GCNConv继承MessagePassing

    aggr定义了聚合函数的作用。这里add表示累加。
    当然,我们也可以通过重写aggregate方法,来自定义 聚合函数

    定义了线性变换层lin,也就是公式中的 Θ \Theta Θ。不过,与公式不同的是,这里的lin是有偏置bias的。

    3. 前向传播forward

        def forward(self, x, edge_index):
            # x.shape == [N, in_channels]
            # edge_index.shape == [2, E]
    
            # Step 1: Add self-loops to the adjacency matrix.
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
    
            # Step 2: Linearly transform node feature matrix.
            x = self.lin(x) # x = lin(x)
    
            # Step 3: Compute normalization.
            row, col = edge_index # row, col is the [out index] and [in index]
            deg = degree(col, x.size(0), dtype=x.dtype) # [in_degree] of each node, deg.shape = [N]
            deg_inv_sqrt = deg.pow(-0.5)
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
            norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # deg_inv_sqrt.shape = [E]
    
            # Step 4-5: Start propagating messages.
            return self.propagate(edge_index, x=x, norm=norm)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    定义 神经网络的 前向传播 过程。

    # Step 1: Add self-loops to the adjacency matrix.
    edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
    
    • 1
    • 2

    添加自环。

    # Step 2: Linearly transform node feature matrix.
    x = self.lin(x) # x = lin(x)
    
    • 1
    • 2

    计算 Θ ⋅ x \Theta \cdot x Θx

    # Step 3: Compute normalization.
    row, col = edge_index # row, col is the [out index] and [in index]
    deg = degree(col, x.size(0), dtype=x.dtype) # [in_degree] of each node, deg.shape = [N]
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # deg_inv_sqrt.shape = [E]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    计算 系数,也就是公式中的
    1 d e g ( i ) ⋅ d e g ( j ) {1 \over \sqrt{\mathrm{deg}(i)} \cdot \sqrt{\mathrm{deg}(j)} } deg(i) deg(j) 1

    这里有点难理解。可以根据 张量的形状 进行理解。

    row表示出边的顶点,col表示入边的顶点。

    注:PyG是支持有向图的,所以(0,1), (1,0)一起表示无向图中的一条边。

    degree计算 入顶点的度数。但,由于GCN运行在无向图上,其实 入顶点个数 == 顶点个数

    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 把度数为0的节点去掉,因为他们是无穷大。

    最后结果得到的norm 表示的含义是,边上两个节点度数乘积。即,每条边表示 1 d e g ( i ) ⋅ d e g ( j ) {1 \over \sqrt{\mathrm{deg}(i)} \cdot \sqrt{\mathrm{deg}(j)} } deg(i) deg(j) 1 一个权重系数。

    4. message

        def message(self, x_i, x_j, norm):
            # x_j ::= x[edge_index[0]] shape = [E, out_channels]
            # x_i ::= x[edge_index[1]] shape = [E, out_channels]
            # norm.view(-1, 1).shape = [E, 1]
            # Step 4: Normalize node features.
            return norm.view(-1, 1) * x_j
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    定义 信息传递函数。

    有同学会问,x_i, x_j哪里来的?
    PyG为我们提供的。
    其中,MessagePassing默认信息流向flowsource_to_target。若存在边(0,1),那么 信息流向0->1
    x_j就是source点,x_i就是target点。

    norm.view(-1, 1) * x_j,将 边上的权重 乘上 source点的特征。即完成了 1 d e g ( i ) ⋅ d e g ( j ) ⋅ ( Θ T ⋅ x j k − 1 ) {1 \over \sqrt{\mathrm{deg}(i)} \cdot \sqrt{\mathrm{deg}(j)} } \cdot (\Theta^T \cdot x_j^{k-1}) deg(i) deg(j) 1(ΘTxjk1)

    5. aggregate

        def aggregate(self, inputs: Tensor, index: Tensor,
                      ptr: Optional[Tensor] = None,
                      dim_size: Optional[int] = None) -> Tensor:
            # 第一个参数不能变化
            # index ::= edge_index[1]
            # dim_size ::= [number of node]
            # Step 5: Aggregate the messages.
            # out.shape = [number of node, out_channels]
            out = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size)
            return out
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    定义 聚合函数。
    其实,到这步 我们可以不用写了,因为之前的aggr="add"就已经足够了。

    index参数 由 PyG提供,为 入顶点的编号。
    torch_scatter.scatter函数 简单的说,就是把 编号相同 的属性[累加、求最大、求最小]聚集在一起

    下面这张图为,scatter求最大。

    详见:
    pytorch:torch_scatter.scatter_max
    torch.scatter与torch_scatter库使用整理

    6. update

        def update(self, inputs: Tensor, x_i, x_j) -> Tensor:
            # 第一个参数不能变化
            # inputs ::= aggregate.out
            # Step 6: Return new node embeddings.
            return inputs
    
    • 1
    • 2
    • 3
    • 4
    • 5

    使用得到的 信息,更新当前节点的信息。

    inputs为 更新得到的信息,其实就是 aggregate的输出。

    update 对应了 公式中的 γ \gamma γ

    注意:第一个参数 为aggregate的输出。可改名字,但不能换位置。

    完整GCN demo代码

    from typing import Optional
    from torch_scatter import scatter
    import torch
    import numpy as np
    import random
    import os
    from torch import Tensor
    from torch_geometric.nn import MessagePassing
    from torch_geometric.utils import add_self_loops, degree
    
    class GCNConv(MessagePassing):
        def __init__(self, in_channels, out_channels):
            super().__init__(aggr='add')  # "Add" aggregation (Step 5).
            self.lin = torch.nn.Linear(in_channels, out_channels)
    
        def forward(self, x, edge_index):
            # x has shape [N, in_channels]
            # edge_index has shape [2, E]
    
            # Step 1: Add self-loops to the adjacency matrix.
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
    
            # Step 2: Linearly transform node feature matrix.
            x = self.lin(x) # x = lin(x)
    
            # Step 3: Compute normalization.
            row, col = edge_index # row, col is the [out index] and [in index]
            deg = degree(col, x.size(0), dtype=x.dtype) # [in_degree] of each node, deg.shape = [N]
            deg_inv_sqrt = deg.pow(-0.5)
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
            norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # deg_inv_sqrt.shape = [E]
    
            # Step 4-6: Start propagating messages.
            return self.propagate(edge_index, x=x, norm=norm)
    
        def message(self, x_i, x_j, norm):
            # x_j ::= x[edge_index[0]] shape = [E, out_channels]
            # x_i ::= x[edge_index[1]] shape = [E, out_channels]
            print("x_j", x_j.shape, x_j)
            print("x_i: ", x_i.shape, x_i)
            # norm.view(-1, 1).shape = [E, 1]
            # Step 4: Normalize node features.
            return norm.view(-1, 1) * x_j
    
        def aggregate(self, inputs: Tensor, index: Tensor,
                      ptr: Optional[Tensor] = None,
                      dim_size: Optional[int] = None) -> Tensor:
            # 第一个参数不能变化
            # index ::= edge_index[1]
            # dim_size ::= [number of node]
            print("agg_index: ",index)
            print("agg_dim_size: ",dim_size)
            # Step 5: Aggregate the messages.
            # out.shape = [number of node, out_channels]
            out = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size)
            print("agg_out:",out.shape,out)
            return out
        
        def update(self, inputs: Tensor, x_i, x_j) -> Tensor:
            # 第一个参数不能变化
            # inputs ::= aggregate.out
            # Step 6: Return new node embeddings.
            print("update_x_i: ",x_i.shape,x_i)
            print("update_x_j: ",x_j.shape,x_j)
            print("update_inputs: ",inputs.shape, inputs)
            return inputs
    
    def set_seed(seed=1029):
    	random.seed(seed)
    	os.environ['PYTHONHASHSEED'] = str(seed) # 为了禁止hash随机化,使得实验可复现
    	np.random.seed(seed)
    	torch.manual_seed(seed)
    	torch.cuda.manual_seed(seed)
    	torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
    	torch.backends.cudnn.benchmark = False
    	torch.backends.cudnn.deterministic = True
    
    if __name__ == '__main__':
        set_seed(0)
        # x.shape = [5, 2]
        x = torch.tensor([[1,2], [3,4], [3,5], [4,5], [2,6]], dtype=torch.float)
        # edge_index.shape = [2, 6]
        edge_index = torch.tensor([[0,1,2,3,1,4], [1,0,3,2,4,1]])
        print("num_node: ",x.shape[0])
        print("num_edge: ",edge_index.shape[1])
        in_channels = x.shape[1]
        out_channels = 3
    
        gcn = GCNConv(in_channels, out_channels)
        out = gcn(x, edge_index)
        print(out)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91

    PyTorch固定随机数种子

    固定住 随机数种子 后,多次运行,比较好 比较 与 理解。

  • 相关阅读:
    【LeetCode】【剑指offer】【二叉搜索树的第k大节点】
    定制相亲交友系统如何提升用户体验
    学习大数据必须掌握哪些核心技术?
    基于STM32单片机一氧化碳(CO)气体监控系统proteus仿真设计
    Pytorch学习task02_待补
    IEC101、IEC103、IEC104的区别与应用场景
    使用Java Spring Boot构建高效的爬虫应用
    Stable Diffusion web UI 文档
    A Comprehensive Survey on Graph Neural Networks
    Jenkins 安装
  • 原文地址:https://blog.csdn.net/LittleSeedling/article/details/125020621