- import torch
- from torch_geometric.data import Data
-
- x = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.float) # 节点特征矩阵(三个节点,每个节点两个特征)
- edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) # 边索引矩阵(四条边,每条边包含两个节点索引)
- y = torch.tensor([0, 1, 0], dtype=torch.long) # 每个节点的目标标签
-
- train_mask = torch.tensor([True, False, True]) # 训练掩膜(三个节点)
- test_mask = torch.tensor([False, True, False]) # 测试掩膜(三个节点)
-
- data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, test_mask=test_mask)
- print(data)