PyG(PyTorch Geometric)是一个基于PyTorch的图神经网络框架,包含图神经网络训练中的数据集处理、多GPU训练、多个经典的图神经网络模型、多个常用的图神经网络训练数据集而且支持自建数据集,主要包含以下几个模块
图对节点和边进行建模。PyG 用 torch_geometric.data.Data 可以描述保存图结构数据,默认情况下包含以下属性:
以上属性不是必要的,但是Data 对象也不限于这些属性。
我们使用PyG表示下面这个图
import torch
from torch_geometric.data import Data
# 边的连接信息 注意,无向图的边要定义两次
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
# edge_index = torch.tensor([[0, 1],[1, 0],
# [1, 2],[2, 1]], dtype=torch.long)
# 节点的属性信息
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
# 实例化为一个图结构的数据
data = Data(x=x, edge_index=edge_index)
# data = Data(x=x, edge_index=edge_index.t().contiguous())
PyG包含大量常见的基准数据集,例如
接下来拿ENZYMES数据集(包含600个图,每个图分为6个类别,图级别的分类)举例如何使用PyG的公共数据集
from torch_geometric.datasets import TUDataset
# 导入数据集
dataset = TUDataset(
# 指定数据集的存储位置 如果指定位置没有相应的数据集 PyG会自动下载
root='/tmp/ENZYMES',
# 要使用的数据集名称
name='ENZYMES')
# 数据集的长度
print(len(dataset))
# 数据集的类别数
print(dataset.num_classes)
# 数据集中节点属性向量的维度
print(dataset.num_node_features)
# 600个图,我们可以根据索引选择要使用哪个图
data = dataset[0]
# 是否为无向图
data.is_undirected()
# 随机打乱数据集的一种方法
perm = torch.randperm(len(dataset))
dataset = dataset[perm]
# 随机打乱数据集的另一种方法
dataset = dataset.shuffle()
PyG 通过创建稀疏块对角邻接矩阵并在节点维度中串联特征和标签矩阵,将数据集为我们指定的batch以批处理方式进行训练,且允许在一个批处理中的图拥有不同数量的节点和边。
PyG中的torch_geometric.loader.DataLoader已经实现了过程,可以直接调用。例如:
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
batch
>>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])
batch.num_graphs
>>> 32
# 计算每个图的节点维度中的平均节点特征
x = scatter_mean(data.x, data.batch, dim=0)
x.size()
>>> torch.Size([32, 21])
batch是一个列向量,它将每个节点映射到批处理中图索引:
Transforms是图转换和图增强的常用方法。
PyG 自带的转换需要 Data 对象作为输入,并返回新的转换后的 Data 对象。
可以使用torch_geometric.transforms.Compose将转换链接在一起,torchvision pre_transform transform。
例如,我们对ShapeNet 数据集(包含 17000 个 3D 形状点云和每个点来自16个形状类别的标签)应用转换
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
# 我们可以通过变换从点云生成最近邻图,从而将点云数据集转换为图形数据集:
pre_transform=T.KNNGraph(k=6),
# 将每个节点位置转换一个小数
transform=T.RandomJitter(0.01))
dataset[0]
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
可以在torch_geometric.transforms上找到PyG中已实现转换的所有方法。
空域图卷积可以看作是相邻节点(和边)之间进行信息传递、融合的过程,计算公式可以一般化为
PyG 提供了 MessagePassing 基类,用户只需定义函数 message() 和 update() 以及聚合方式 aggr=“add”/“mean”/“max”。
对于以上计算过程,PyG利用MessagePassing进行实现。接下来以两篇经典图神经网络论文为例,介绍MessagePassing的使用。
Kipf和Welling的GCN层
Wang等人的EdgeConv层
在第一篇论文中,作者提出的GCN
其中,相邻节点特征先进行线性变换,按其度数规范化,最后将所有信息相加,再将偏置向量应用于聚合输出得到当前节点的特征表示。
此公式可分为以下步骤:
①将自环添加到邻接矩阵。②线性变换节点特征矩阵。③计算归一化系数。
④规范化中的节点特征。⑤汇总相邻节点特征(add聚合)。⑥偏置向量相加。
步骤 ①-③ 通常在消息传递发生之前计算。步骤 ④-⑤ 可以使用 MessagePassing 基类处理。完整层实现如下所示:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
# 定义GCN空域图卷积神经网络
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # "Add" 聚合操作 (Step 5).
self.lin = Linear(in_channels, out_channels, bias=False) # W'T
self.bias = Parameter(torch.Tensor(out_channels)) # 偏置项
self.reset_parameters()
def reset_parameters(self):
self.lin.reset_parameters()
self.bias.data.zero_()
def forward(self, x, edge_index):
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) # ①添加自环
x = self.lin(x) # ②线性变换
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # ③ 标准化参数
# ④-⑤传递聚合信息 propagate会自动调用self.message函数,并将参数传递给它
out = self.propagate(edge_index, x=x, norm=norm)
out += self.bias # 使用偏置项
return out
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
在第二篇论文中,边卷积层处理图形或点云,并在数学上定义为
h
θ
h_{\theta}
hθ表示MLP,类似于GCN层,我们使用 MessagePassing 类实现该层,使用’sum‘聚合函数。
import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
class EdgeConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='max') # "Max" aggregation.
self.mlp = Seq(Linear(2 * in_channels, out_channels),
ReLU(),
Linear(out_channels, out_channels))
def forward(self, x, edge_index):
return self.propagate(edge_index, x=x)
def message(self, x_i, x_j): # 信息汇聚函数
tmp = torch.cat([x_i, x_j - x_i], dim=1)
return self.mlp(tmp)
在 message() 函数中,我们用于转换每个边的目标节点特征和相对源节点特征。
边卷积实际上是一种动态卷积,它使用特征空间中的最近邻重新计算每个图层的图形。幸运的是,PyG附带了一个名为torch_geometric.nn.pool.knn_graph()的GPU加速批量k-NN图生成方法:
from torch_geometric.nn import knn_graph
class DynamicEdgeConv(EdgeConv):
def __init__(self, in_channels, out_channels, k=6):
super().__init__(in_channels, out_channels)
self.k = k
def forward(self, x, batch=None):
edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
return super().forward(x, edge_index)
conv = DynamicEdgeConv(3, 128, k=6)
x = conv(x, batch)
PyG将自建数据集分为两个文件夹:raw_dir、processed_dir。
row_dir是原始的数据集,processed_dir是PyG处理之后的数据集
对于数据集PyG有三种过滤方法—transform、pre_transform、pre_filter。
PyG为数据集提供了两个抽象类:
此外,每个数据集都可以传递transformpre_transformpre_filter函数,用于动态转换数据对象,默认为None。
包含四种方法(可以在torch_geometric.data中找到下载和提取数据的有用方法。)
通用模板:
import torch
from torch_geometric.data import InMemoryDataset, download_url
# 实现In Memory Dataset的通用模板
class MyDataset(InMemoryDataset):
# 初始化
def __init__(self, root, transfrom=None, pre_transform=None):
# root是数据集的根目录
super(MyDataset, self).__init__(root, transfrom, pre_transform)
# 加载数据集
self.data, self.slices = torch.load(self.processed_paths[0])
def raw_file_names(self) -> Union[str, List[str], Tuple]:
return ['file_1', 'file_2', ...]
def processed_file_names(self) -> Union[str, List[str], Tuple]:
return ['data.pt']
def download(self):
# 将数据集下载到raw_dir文件夹中
download_url(url, self.raw_dir)
def process(self):
data_list = [...]
# 进行数据过滤
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
# self.collate将所有数据组合在一起,加速存储
# data是组合之后的数据
# slices是分割方式,告诉PyG如何将data还原为原先的数据
data, slices = self.collate(data_list)
# 保存数据
torch.save((data, slices), self.processed_paths[0])
上面需要做的几件事的基础上还需要实现
通用模板
import os.path as osp
import torch
from torch_geometric.data import Dataset, download_url
class MyDataset(Dataset):
# 初始化
def __init__(self, root, transform=None, pre_transform=None):
super(MyDataset, self).__init__(root, transform, pre_transform)
def raw_file_names(self) -> Union[str, List[str], Tuple]:
return ['file_1', 'file_2', ...]
def processed_file_names(self) -> Union[str, List[str], Tuple]:
return ['data_1.pt', ...]
def download(self):
path = download_url(url, self.raw_dir)
def process(self):
i = 0
for raw_path in self.raw_paths:
# 读取数据
data = Data(...)
# 过滤数据集
if self.pre_filter is not None and not self.pre_filter(data):
pass
if self.pre_transform is not None:
data = self.pre_transform(data)
# 保存数据
torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
i += 1
def len(self):
return len(self.processed_file_names)
def get(self,idx):
data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
return data
在这里,每个图形数据对象都单独保存在 process() 中,并手动加载到 get() 中。
如何跳过执行download()和process()?
您可以通过不重写download()和process()。
我真的需要使用这些数据集接口吗?
不!您不必使用数据集,例如,当您想要动态创建合成数据而不将其保存时。在这种情况下,只需传递一个包含torch_geometric.data.Data 对象的常规 python 列表,并将它们传递给 torch_geometric.loader.DataLoader:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)
自己的例子
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
super().__init__(root, transform, pre_transform, pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ['DD_A.txt','DD_graph_indicator.txt','DD_graph_labels.txt','DD_node_labels.txt']
@property
def processed_file_names(self):
return ['data.pt']
def download(self):
pass
return
def process(self):
# Read data into huge `Data` list.
self.data, self.slices = read_tu_data(self.raw_dir,'DD')
if self.pre_filter is not None:
data_list = [self.get(idx) for idx in range(len(self))]
data_list = [data for data in data_list if self.pre_filter(data)]
data, slices = self.collate(data_list)
if self.pre_transform is not None:
data_list = [self.get(idx) for idx in range(len(self))]
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
torch.save((self.data, self.slices), self.processed_paths[0])
path = osp.join(osp.dirname(osp.abspath('')), 'DD')
dataset = MyOwnDataset(path)
data_loader = DataLoader(dataset, batch_size=128)