• PyTorch Geometric Temporal 介绍 —— 数据结构和RGCN的概念


    Introduction

    PyTorch Geometric Temporal is a temporal graph neural network extension library for PyTorch Geometric.

    PyTorch Geometric Temporal 是基于PyTorch Geometric的对时间序列图数据的扩展。

    Data Structures: PyTorch Geometric Temporal Signal

    定义:在PyTorch Geometric Temporal中,边、边特征、节点被归为图结构Graph,节点特征被归为信号Single,对于特定时间切片或特定时间点的时间序列图数据被称为快照Snapshot。

    PyTorch Geometric Temporal定义了数个Temporal Signal Iterators用于时间序列图数据的迭代。

    Temporal Signal Iterators数据迭代器的参数是由描述图的各个对象(edge_index,node_feature,...)的列表组成,列表的索引对应各时间节点。

    按照图结构的时间序列中的变换部分不同,图结构包括但不限于为以下几种:

    • Static Graph with Temporal Signal
      静态的边和边特征,静态的节点,动态的节点特征
    • Dynamic Graph with Temporal Signal
      动态的边和边特征,动态的节点和节点特征
    • Dynamic Graph with Static Signal
      动态的边和边特征,动态的节点,静态的节点特征

    理论上来说,任意描述图结构的对象都可以根据问题定为静态或动态,所有对象都为静态则为传统的GNN问题。

    实际上,在PyTorch Geometric Temporal定义的数据迭代器中,静态和动态的差别在于是以数组的列表还是以单一数组的形式输入,以及在输出时是按索引从列表中读取还是重复读取单一数组。

    如在StaticGraphTemporalSignal的源码中_get_edge_index _get_features分别为:

    # https://pytorch-geometric-temporal.readthedocs.io/en/latest/_modules/torch_geometric_temporal/signal/static_graph_temporal_signal.html#StaticGraphTemporalSignal
    
    def _get_edge_index(self):
    	if self.edge_index is None:
    		return self.edge_index
    	else:
    		return torch.LongTensor(self.edge_index)
    		
    def _get_features(self, time_index: int):
    	if self.features[time_index] is None:
        	return self.features[time_index]
        else:
        	return torch.FloatTensor(self.features[time_index])
    

    对于Heterogeneous Graph的数据迭代器,其与普通Graph的差异在于对于每个类别建立键值对组成字典,其中的值按静态和动态定为列表或单一数组。

    Recurrent Graph Convolutional Layers

    Define $\ast_G $ as graph convolution, \(\odot\) as Hadamard product

    \[\begin{aligned} &z = \sigma(W_{xz}\ast_Gx_t+W_{hz}\ast_Gh_{t-1}),\\ &r = \sigma(W_{xr}\ast_Gx_t+W_{hr}\ast_Gh_{t-1}),\\ &\tilde h = \text{tanh}(W_{xh}\ast_Gx_t+W_{hh}\ast_G(r\odot h_{t-1})),\\ &h_t = z \odot h_{t-1} + (1-z) \odot \tilde h \end{aligned} \]

    From https://arxiv.org/abs/1612.07659

    具体的函数实现见 https://pytorch-geometric-temporal.readthedocs.io/en/latest/modules/root.html#

    与RNN的比较

    \[\begin{aligned} &z_t = \sigma(W_{xz}x_t+b_{xz}+W_{hz}h_{t-1}+b_{hz}),\\ &r_t = \sigma(W_{xr}x_t+b_{xr}+W_{hr}h_{t-1}+b_{hr}),\\ &\tilde h_t = \text{tanh}(W_{xh}x_t+b_{xh}+r_t(W_{hh}h_{t-1}+b_{hh})),\\ &h_t = z*h_{t-1} + (1-z)*\tilde h \end{aligned} \]

    From https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU

    对于传统GRU的解析 https://zhuanlan.zhihu.com/p/32481747

    在普通数据的Recurrent NN中,对于每一条时间序列数据会独立的计算各时间节点会根据上一时间节点计算hidden state。但在时间序列图数据中,每个snapshot被视为一个整体计算Hidden state matrix \(H \in \mathbb{R}^{\text{Num(Nodes)}\times \text{Out_Channels}_H}\) 和Cell state matrix(对于LSTM)\(C \in \mathbb{R}^{\text{Num(Nodes)}\times \text{Out_Channels}_C}\)

    与GCN的比较

    相较于传统的Graph Convolution Layer,RGCN将图卷积计算的扩展到RNN各个状态的计算中替代原本的参数矩阵和特征的乘法计算。

  • 相关阅读:
    7、2pc、3pc协议
    基于阿里云 Serverless 快速部署 function 的极致体验
    MeshLab相关&纹理贴图
    Java —— String类
    kubernetes技术
    2311rust,到38版本更新
    【web-攻击用户】(9.4)跨域捕获数据——通过注入HTML捕获数据、注入CSS捕获数据、JavaScript劫持
    滚珠螺杆应如何存放避免受损
    05 - 雷达的发展与应用
    通过电商API接口分析电商平台客户消费行为
  • 原文地址:https://www.cnblogs.com/yc0806/p/16931208.html