• 【GNN】采样算法总结


    目录

    背景:

    模型:

    1. 按层采样的方法(FastGCN & AS-GCN)

    2. 子图采样的方法(ClusterGCN & GraphSAINT)

    3. 控制变量的方式(S-GCN)

    实验结果:

    参考文献:


    背景:

    许多领域的数据都可以自然地表达成图结构,比如社交网络、蛋白质相互作用网、化学分子图和3D点云等。这些图数据的复杂性对现有的机器学习算法提出了重大挑战。近几年图卷积网络(GCN)把深度学习中卷积神经网络的思想用到图的学习上,在点分类、边预测以及推荐系统上都取得了非常好的效果。

    在GCN中,为了得到根节点的表达,需要从它上一层的邻居节点里收集信息。然后同样的,这些邻居节点也需要再从他们上一层的邻居节点里收集信息。这样随着层数的加深,需要计算的多跳邻居数量就会指数级上升。这种“邻居数量指数增加”的问题严重影响了GCN在大规模数据场景下的应用。为了解决上述问题,近两年很多加速图卷积网络训练的方法被提出来,我们将其归纳成三类:1. 按层采样的方法,2. 子图采样的方法,3. 控制变量(control variate)的方法。

    下面会具体介绍这三类加速训练方法。

    模型:

    1. 按层采样的方法(FastGCN & AS-GCN)

    按层采样方法的核心思想是,限制每一层中采样邻居的总数量,这样多跳邻居数量随着层数增加只会线性上升,从而加速GCN模型的训练速度。具体采样方法如下图所示,例如一个2层的GCN模型,在 [公式] 层中我们随机选取batch_size个根节点,在之后的 [公式] 和 [公式] 层中我们限制了每层的采样点数,并保留相邻两层采样点之前的边。

    接着我们需要设计按层采样的策略,以保证采样过程是无偏的,同时尽量减小采样带来的方差。具体来讲对于一个 [公式] 层 的GCN模型,在第 [公式] 层上节点 [公式] 的聚合过程如下

    [公式]

    其中 [公式] 是归一化的邻接矩阵,如果 [公式][公式]之间有边相连则 [公式] ,否则 [公式] ;[公式] 是节点总数; [公式] 是第 [公式] 层的聚合参数。聚合计算公式改写成概率期望的形式,

    [公式]

    其中 [公式] ,[公式] 是节点 [公式]均匀采样邻居节点的概率。

    假设我们在第 [公式] 层上有放回地均匀采样 [公式] 个节点[公式]。采样后节点的聚合过程为

    [公式]

    为了保证按层采样是无偏的,我们使用重要性采样(Importance Sampling)的技巧,

    [公式]

    [公式]

    其中 [公式] 是给定上一层 [公式] 个节点 [公式] 然后采样下一层节点的概率。

    至止我们已经得到了一个无偏的按层采样策略,下面我们需要找到最优的按层采样概率,使得采样方差 [公式] 尽量小,

    [公式]

    根据上述公式,我们可以求得最优的采样概率 [公式] 为

    [公式]

    但是这里 [公式] 是无法得到的,因为计算 [公式] 需要使用 [公式] 的信息,但是实际采样的时候,是先采样得到 [公式] 层再采样得到 [公式] 层,这样就会遇到chicken-and-egg的问题。

    为了解决这个问题,FastGCN假设[公式]的值与 [公式] 正相关,这样采样概率变为

    [公式]

    为了得到按层采样概率,分子分母同时对 [公式] 求和得到采样概率 [公式] 。

    论文中完整的算法描述如下

    AS-GCN采样另一种方式来解决[公式] 无法得到的问题,他定义了一个线性函数 [公式] 来近似[公式]的真实值。然后按层采样概率变为

    [公式]

    然后为了优化线性函数的参数,在将方差 [公式] 加入到GCN模型的loss中一起训练优化。

    2. 子图采样的方法(ClusterGCN & GraphSAINT)

    与按层采样的思路不同,子图采样的方法通过限制子图的大小来解决“邻居数量指数增加”的问题。子图采样的代表方法ClusterGCN使用图聚类算法(比如Metis和Graclus)将全图切割成若干个小的cluster,训练时随机选取q个clusters组成子图,再在采样子图上进行full GCN计算。

    如上图所示的4层GCN模型,左边是不做采样的方式,红色根节点在扩展4层之后邻居数量会指数上升。右边是ClusterGCN的方式,红色根节点扩展的邻居数量最多为它所在的子图大小。

    论文中完整的算法描述如下

    ClusterGCN的思路相同,GraphSAINT也是先采样子图,然后在得到的子图上进行full GCN的计算。不同的是,后者显式地考虑了子图采样对GCN计算带来的偏差,可以保证采样后节点的聚合过程是无偏的,并且使采样带来的方差尽量小。

    具体来说,在采样子图 [公式] 中,GraphSAINT设计节点 [公式] 的聚合过程如下

    [公式] ,

    其中 [公式] 是聚合归一化参数 [公式] 是归一化的邻接矩阵,[公式] (当 [公式] 时,[公式] ;当 [公式] 时, [公式] ;当 [公式] 时, [公式] 未定义), [公式] 是 [公式] 层的聚合参数。

    论文证明如果 [公式] ,那么 [公式] 是节点 [公式] 聚合结果的无偏估计。其中 [公式] 是采样边 [公式] 的概率, [公式] 是采样点 [公式] 的概率。

    得到无偏估计 [公式] 之后,我们希望这个估计的方差尽量小。论文证明在采样包含 [公式] 条边的子图时,每条边 [公式] 按照概率 [公式] 进行采样可以让方差最小,其中 [公式] 。由于 [公式] 需要计算 [公式] ,为了简化采样复杂度,我们忽略 [公式] 项,则采样概率 [公式] .

    论文中完整的算法描述如下

    3. 控制变量的方式(S-GCN)

    这里控制变量(control variate)是蒙特卡洛估计中常用的一种减小估计方差的方法。其核心思想是找到一个合适的随机变量,使得原估计减去这个随机变量之后的方差可以变小。在S-GCN方法中,我们保存每个节点在每一层上的历史embedding值 [公式] ,作为其真实值 [公式] 的一个合理近似。因为如果训练过程中模型参数没有变化太快的话, [公式] 会比较小。

    具体来说,我们在节点 [公式] 的聚合过程中,使用 [公式] 替换 [公式] 得到
    [公式]

    其中 [公式] 是节点 [公式] 的邻居集合, [公式] 是采样的邻居集合, [公式] 的节点个数为 [公式] 。论文中证明了这个估计方式是无偏的,并且方差比随机采样的方式小。

    论文中完整的算法描述如下

    实验结果:

    论文[4]对比了上述五种加速训练方法,下面表格是这些方法在五个开源数据测试集上F1-micro分数。从测试结果可以看出GraphSAINT的模型效果最佳。

    其次论文[4]比较了这些方法的模型收敛速度,S-GCN在PPI上收敛速度最快,GraphSAINT在其余4个数据上有最快的收敛速度。

    参考文献:

    [1] Jie Chen, Tengfei Ma, and Cao Xiao. Fastgcn: Fast learning with graph convolutional networks via importance sampling. In International Conference on Learning Representations (ICLR) , 2018b.

    [2] Wenbing Huang, Tong Zhang, Yu Rong, and Junzhou Huang. Adaptive sampling towards fast graph representation learning. In Advances in Neural Information Processing Systems , pp. 4558–4567, 2018.

    [3] Wei-Lin Chiang, Xuanqing Liu, Si Si, Yang Li, Samy Bengio, and Cho-Jui Hsieh. Cluster-gcn: An efficient algorithm for training deep and large graph convolutional networks. CoRR , abs/1905.07953, 2019.

    [4] Zeng, Hanqing, Hongkuan Zhou, Ajitesh Srivastava, Rajgopal Kannan, and Viktor Prasanna. "Graphsaint: Graph sampling based inductive learning method."arXiv preprint arXiv:1907.04931(2019).

    [5] Jianfei Chen, Jun Zhu, and Le Song. Stochastic training of graph convolutional networks with variance reduction. In ICML , pp. 941–949, 2018a.

  • 相关阅读:
    慢 SQL 分析与优化
    linux中操作服务器常用命令
    二叉树的实现(C语言数据结构)
    Excel INDEX MATCH教程之 什么是INDEX MATCH,有什么用(教程含案例)
    Python 进阶 - 日常工作中使用过的简单Trick
    vivo 制品管理在 CICD 落地实践
    复习Day07:链表part03:21. 合并两个有序链表、2. 两数相加
    RabbitMq安装(Erlang前置安装)
    基于k近邻算法的干豆品种分类
    Django框架之python后端框架介绍
  • 原文地址:https://blog.csdn.net/lj2048/article/details/106540805