• Pyg消息传递源码(MESSAGE PASSING)+实例


    1. MessagePassing基类

    GNN关键的步骤就是消息传递、聚合、更新。
    pytorch geometric提供了一个MessagePassing基类,它已经通过MessagePassing.propagate()实现了以上三步对应的计算过程。
    我们只需定义一个继承了MessagePassing基类的class,然后根据具体的图算法来更新函数 message() 的邻域聚合方式aggr=“add”, aggr=“mean” or aggr=“max”,以及函数update(),并在自定义的图算法卷积层中的forward函数里面调用progagate函数就可以。
    大致流程如下:

    import torch
    from torch_geometric.nn import MessagePassing
    
    class MyConv(MessagePassing): # 定义继承了MessagePassing基类的class
        def __init__(self, in_channels, out_channels, **kwargs):
            kwargs.setdefault('aggr', 'add')  # 邻域聚合方式
            super(MyConv, self).__init__(**kwargs)
            ...
    
        def forward(self, x, edge_index):
        	...
            return self.propagate(edge_index, **kwargs)
    
        def message(self, **kwargs):
        	...
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    2. Message 源码

    2.1 MessagePassing初始化

    def __init__(self, aggr: Optional[str] = "add",
               flow: str = "source_to_target", node_dim: int = -2,
               decomposed_layers: int = 1):
    
    • 1
    • 2
    • 3

    aggr: 邻域聚合方式,默认add,还可以是mean, max。
    flow: 消息传递方向,默认从source_to_target,也可以设置为target_to_source。
    node_dim: 定义沿着哪个维度进行消息传递,默认-2,因为-1是特征维度。

    2.2 MessagePassing.propagate

     MessagePassing.propagate(edge_index, size=None, **kwargs)
    
    • 1

    progagate会依次调用messageaggregateupdate方法。如果edge_index是SparseTensor,会优先调用message_and_aggregate方法来代替messageaggregate方法。

    edge_index:它有两种形式Tensor和SparseTensor。Tensor形式下的edge_index的shape是(2, N);SparseTensor则可以理解为稀疏矩阵的形式存储边信息。
    size:当size为None的时候,默认邻接矩阵是方形[N, N]。如果是异构图(如bipartite图),图中的两类点的特征和index是相互独立的。通过传入size=(N, M),x=(x_N, x_M)时,propagate可以处理这种情况。
    kwargs:图卷积计算过程的额外所需的信息,都可以通过kwargs传入。

    2.3 MessagePassing.message()

    这个方法在 flow=“source_to_target” 的设置下,计算了邻居节点 j 到中心节点 i 的消息。传给propagate()所有参数都可以传递给message(),而且传递给propagate()的tensors可以通过加上_i或_j的后缀来mapping到对应的节点

    def message(self, x_j: Tensor) -> Tensor:
        return x_j
    
    • 1
    • 2

    x_j:代表了邻居的特征,通过edge_index中邻居节点去索引对应位置的x得到

    当edge_index的shape是(2, N_edges),x的shape是(N_nodes, N_features),则得到的x_j的shape是(N_edges, N_features)

    例如:
    edge_index:tensor([[1, 2, 3, 3], [0, 0, 0, 1]])
    x:tensor([[0, 1], [2, 3], [4, 5], [6, 7]])
    邻居节点 j 的index是edge_index的第一个元素[1,2,3,3],根据节点 j 的index [1, 2, 3, 3],去索引x对应的位置,则得到

    x_j = x[index(j)]=x[[1,2,3,3]] = tensor([[2,3],[4,5],[6,7],[6,7]])
    
    • 1

    2.4 MessagePassing.aggregate(inputs, index, …)

    这个方法实现了邻域的聚合,pytorch geometric通过scatter共实现了三种方式mean、 add、max。一般来说,比较通用的图算法,GCN、GraphSAGE、GAT都不需要自己再额外定义aggregate方法。

    2.5 MessagePassing.update(aggr_out, …)

    之前传入propagate的参数也都传入update。对应每个中心节点 i ,根据aggregate的邻域结果和传入propagate的参数中选择所需信息,更新节点 i 的embedding。

    2.6 MessagePassing.message_and_aggregate(adj_t, …)

    前面提到pytorch geometric中的边信息有Tensor和SparseTensor两种形式。
    SparseTensor提供了矩阵存储形式,以稀疏矩阵方式存储,message_and_aggregate则提供了邻域聚合的矩阵计算方式〈不是所有的图卷积都可以用矩阵计算)。

    当边是以SparseTensor存储的时候,propagate会优先去查找是否实现了message_and_aggregate如果已经实现了,就会调用message_and_aggregate来代替messageaggregate。如果没有实现, propagate需要将边信息转换为Tensor,然后再调用messageaggregate
    message_and_aggregate是需要自己实现的,只有实现了它,才可以发挥矩阵计算的优势。

    3 实例

    import torch
    from torch_geometric.nn import MessagePassing
    from torch_geometric.utils import add_self_loops, degree
    
    class GCNConv(MessagePassing):
        def __init__(self, in_channels, out_channels):
            super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
            self.lin = torch.nn.Linear(in_channels, out_channels)
        def forward(self, x, edge_index):
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
            x = self.lin(x)
            row, col = edge_index
            deg = degree(col, x.size(0), dtype=x.dtype)
            deg_inv_sqrt = deg.pow(-0.5)
            norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
            return self.propagate(edge_index, x=x, norm=norm)
        def message(self, x_j, norm):
            return norm.view(-1, 1) * x_j
    
    from torch_geometric.datasets import Planetoid
    dataset = Planetoid(root='dataset/Cora', name='Cora')
    data = dataset[0]
    net = GCNConv(data.num_features, 64)
    h_nodes = net(data.x, data.edge_index)
    print(h_nodes.shape)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
  • 相关阅读:
    Java面向对象进阶7——代码块
    软件测试需要学习什么 3分钟带你了解软测的学习内容
    专利申请流程专利下证要多长时间实用新型专利申请
    【Web安全】注入攻击
    html+css+js实现简单的交互效果
    Python基础——文件系统(os模块和os.path模块)
    一文解决什么是Docker。如何使用Docker。Docker能做什么。
    工具 - markdown编辑器常用方法
    OpenCV轻松入门(九)——使用第三方库imgaug自定义数据增强器
    小项目-词法分析器
  • 原文地址:https://blog.csdn.net/weixin_45928096/article/details/126805227