GNN关键的步骤就是消息传递、聚合、更新。
pytorch geometric提供了一个MessagePassing基类,它已经通过MessagePassing.propagate()实现了以上三步对应的计算过程。
我们只需定义一个继承了MessagePassing基类的class,然后根据具体的图算法来更新函数 message() 的邻域聚合方式aggr=“add”, aggr=“mean” or aggr=“max”,以及函数update(),并在自定义的图算法卷积层中的forward函数里面调用progagate函数就可以。
大致流程如下:
import torch
from torch_geometric.nn import MessagePassing
class MyConv(MessagePassing): # 定义继承了MessagePassing基类的class
def __init__(self, in_channels, out_channels, **kwargs):
kwargs.setdefault('aggr', 'add') # 邻域聚合方式
super(MyConv, self).__init__(**kwargs)
...
def forward(self, x, edge_index):
...
return self.propagate(edge_index, **kwargs)
def message(self, **kwargs):
...
def __init__(self, aggr: Optional[str] = "add",
flow: str = "source_to_target", node_dim: int = -2,
decomposed_layers: int = 1):
aggr
: 邻域聚合方式,默认add,还可以是mean, max。
flow
: 消息传递方向,默认从source_to_target,也可以设置为target_to_source。
node_dim
: 定义沿着哪个维度进行消息传递,默认-2,因为-1是特征维度。
MessagePassing.propagate(edge_index, size=None, **kwargs)
progagate会依次调用message
、aggregate
、update
方法。如果edge_index是SparseTensor,会优先调用message_and_aggregate
方法来代替message
和aggregate
方法。
edge_index
:它有两种形式Tensor和SparseTensor。Tensor形式下的edge_index的shape是(2, N);SparseTensor则可以理解为稀疏矩阵的形式存储边信息。
size
:当size为None的时候,默认邻接矩阵是方形[N, N]。如果是异构图(如bipartite图),图中的两类点的特征和index是相互独立的。通过传入size=(N, M),x=(x_N, x_M)时,propagate可以处理这种情况。
kwargs
:图卷积计算过程的额外所需的信息,都可以通过kwargs传入。
这个方法在 flow=“source_to_target” 的设置下,计算了邻居节点 j 到中心节点 i 的消息。传给propagate()所有参数都可以传递给message(),而且传递给propagate()的tensors可以通过加上_i或_j的后缀来mapping到对应的节点。
def message(self, x_j: Tensor) -> Tensor:
return x_j
x_j
:代表了邻居的特征,通过edge_index中邻居节点去索引对应位置的x得到
当edge_index的shape是(2, N_edges),x的shape是(N_nodes, N_features),则得到的x_j的shape是(N_edges, N_features)
例如:
edge_index
:tensor([[1, 2, 3, 3], [0, 0, 0, 1]])
x
:tensor([[0, 1], [2, 3], [4, 5], [6, 7]])
邻居节点 j 的index是edge_index的第一个元素[1,2,3,3],根据节点 j 的index [1, 2, 3, 3],去索引x对应的位置,则得到
x_j = x[index(j)]=x[[1,2,3,3]] = tensor([[2,3],[4,5],[6,7],[6,7]])
这个方法实现了邻域的聚合,pytorch geometric通过scatter共实现了三种方式mean、 add、max。一般来说,比较通用的图算法,GCN、GraphSAGE、GAT都不需要自己再额外定义aggregate方法。
之前传入propagate的参数也都传入update。对应每个中心节点 i ,根据aggregate的邻域结果和传入propagate的参数中选择所需信息,更新节点 i 的embedding。
前面提到pytorch geometric中的边信息有Tensor和SparseTensor两种形式。
SparseTensor
提供了矩阵存储形式,以稀疏矩阵方式存储,message_and_aggregate
则提供了邻域聚合的矩阵计算方式〈不是所有的图卷积都可以用矩阵计算)。
当边是以SparseTensor
存储的时候,propagate
会优先去查找是否实现了message_and_aggregate
如果已经实现了,就会调用message_and_aggregate
来代替message
和aggregate
。如果没有实现, propagate
需要将边信息转换为Tensor,然后再调用message
和aggregate
。
message_and_aggregate
是需要自己实现的,只有实现了它,才可以发挥矩阵计算的优势。
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
x = self.lin(x)
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='dataset/Cora', name='Cora')
data = dataset[0]
net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
print(h_nodes.shape)