• 在PyG上构建自己的数据集


    PyG构建自己数据集

    PyG简介

    PyG(PyTorch Geometric)是一个建立在 PyTorch 基础上的库,用于轻松编写和训练图神经网络(GNN),用于与结构化数据相关的广泛应用。

    它包括在图和其他不规则结构上进行深度学习的各种方法,也被称为几何深度学习,来自各种已发表的论文。此外,它还包括易于使用的迷你批量加载器(mini-batch loaders),用于在许多小型和单一的巨型图形上操作;多 GPU 支持、大量常见的基准数据集(基于简单的接口来创建你自己的数据集);以及有用的变换,既可以在任意图形上学习,也可以在 3D 网格或点云上学习。

    数据集介绍

    本部分用到的也是Cora数据集,但是不是官方版本的数据集,而是非常平易近人的风格,拿来就可以使用,格式如下:
    cora.cites
    在这里插入图片描述
    cora.cites文件格式非常简单,就是两列,代表两个具备边关系的节点。
    cora.content
    在这里插入图片描述
    在这里插入图片描述
    cora.content文件内容也很简单,第一列是节点id,最后一列是每个节点的标签,中间的数值是每个节点的特征值。

    代码实现

    PyG构建数据集,氛围两类,一种是针对小数据集的in_memory_dataset,这种形式可以直接将所用的数据集都加载到内存当中;另一种是针对大数据集的Dataset,这种形式主要是可以对大数据集进行索引,进行batch合并,减少每次内存的数据量。实际业务中,我们大多是用大数据集,因此,就以这个作为例子。

    from torch_geometric.data import Dataset, Data
    # 定义自己的数据集类
    class mydataset(Dataset):
        def __init__(self, root, transform=None, pre_transform=None):
            super(mydataset, self).__init__(root, transform, pre_transform)
    
        # 原始文件位置
        @property
        def raw_file_names(self):
            return ['cora.content', 'cora.cites']
    
        # 文件保存位置
        @property
        def processed_file_names(self):
            return 'data.pt'
    
        def download(self):
            pass
    
        # 数据处理逻辑
        def process(self):
            idx_features_labels = np.genfromtxt(self.raw_paths[0])
            x = idx_features_labels[:, 1:-1]
            x = torch.tensor(x, dtype=torch.float32)
            y, label_dict = self.encode_labels(np.genfromtxt(self.raw_paths[0], dtype='str', usecols=(-1,)))
            y = torch.tensor(y)
            idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
            id_node = {j: i for i, j in enumerate(idx)}
    
            edges_unordered = np.genfromtxt(self.raw_paths[1], dtype=np.int32)
            edge_str = [id_node[each[0]] for each in edges_unordered]
            edge_end = [id_node[each[1]] for each in edges_unordered]
            edge_index = torch.tensor([edge_str, edge_end], dtype=torch.long)
    
            data = Data(x=x, edge_index=edge_index, y=y)
    
            torch.save(data, os.path.join(self.processed_dir, f'data.pt'))
    
        def encode_labels(self, labels):
            classes = sorted(list(set(labels)))
            labels_id = [classes.index(i) for i in labels]
            label_dict = {i: c for i, c in enumerate(classes)}
            return labels_id, label_dict
    
        # 定义总数据长度
        def len(self):
            idx_features_labels = np.genfromtxt(self.raw_paths[0], dtype=np.int32)
            uid = idx_features_labels[:, 0:1]
            return len(uid)
    
        # 定义获取数据方法
        def get(self, idx):
            data = torch.load(os.path.join(self.processed_dir, f'data.pt'))
            return data
    dataset = mydataset('../data/')
    data = dataset[0].to(device)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56

    首先,我们定义了自己的一个类,mydataset类,其继承了一个父类-Dataset,这个Dataset类是PyG框架自己定义好的,其中包括数据集下载、数据预处理、数据文件保存、数据检索等等功能,大家可以详细了解一下,我们只对用到的进行解释。

    # 原始文件位置
    @property
    def raw_file_names(self):
        return ['cora.content', 'cora.cites']
    
    • 1
    • 2
    • 3
    • 4

    raw_file_names:指向自己的文件目录下的文件名,这个可以将你用到的文件按照列表的形式进行展现,如果用cora.content,那就是0,用cora.cites,那就是1;

    @property
    def processed_file_names(self):
        return 'data.pt'
    
    • 1
    • 2
    • 3

    processed_file_names:指向处理后的数据文件保存文件名称,可以在下次加载数据的时候,直接读取该文件;

    def download(self):
        pass
    
    • 1
    • 2

    download:该函数是需要去下载数据集的,因为我们是自建数据集,因此,不用;

    def process(self):
    	#读取cora.content文件
        idx_features_labels = np.genfromtxt(self.raw_paths[0])
        #获取节点特征
        x = idx_features_labels[:, 1:-1]
        #转为tensor,并指定数据类型
        x = torch.tensor(x, dtype=torch.float32)
        #获取每个节点的标签
        y, label_dict = self.encode_labels(np.genfromtxt(self.raw_paths[0], dtype='str', usecols=(-1,)))
        #tensor化
        y = torch.tensor(y)
        #获取每个节点
        idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
        #将每个节点映射为id(从0开始)
        id_node = {j: i for i, j in enumerate(idx)}
    	#读取cora.cites
        edges_unordered = np.genfromtxt(self.raw_paths[1], dtype=np.int32)
        #获取每个节点对应的id
        #第一列节点-->id
        edge_str = [id_node[each[0]] for each in edges_unordered]
        #第二列节点-->id
        edge_end = [id_node[each[1]] for each in edges_unordered]
        #将边转为tensor
        edge_index = torch.tensor([edge_str, edge_end], dtype=torch.long)
    	#将所有数据加载至Data对象中
        data = Data(x=x, edge_index=edge_index, y=y)
    	#保存处理好的图数据,下次可以直接加载
        torch.save(data, os.path.join(self.processed_dir, f'data.pt'))
    
    def encode_labels(self, labels):
        classes = sorted(list(set(labels)))
        labels_id = [classes.index(i) for i in labels]
        label_dict = {i: c for i, c in enumerate(classes)}
        return labels_id, label_dict
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    process:该函数是处理数据的逻辑函数,大家可以将处理数据的逻辑放在该函数中,主要是节点特征、节点标签、以及边的构成;
    self.raw_paths:这个是raw_file_names返回的列表和文件路径拼接之后的结果,就是将文件名扩展为路径+文件名;

    # 定义总数据长度
    def len(self):
        idx_features_labels = np.genfromtxt(self.raw_paths[0], dtype=np.int32)
        uid = idx_features_labels[:, 0:1]
        return len(uid)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    len:获取总数据的长度,为了进行数据分割做准备,可以自己定义;

    def get(self, idx):
        data = torch.load(os.path.join(self.processed_dir, f'data.pt'))
        return data
    
    • 1
    • 2
    • 3

    get:制定获取图数据的方式,可以自己定义。

    数据输出

    在这里插入图片描述
    我们可以看到,Data是一个包含所有属性的对象。
    x:是27081433的矩阵,即2708个节点,每个节点有1433维;
    edge_index:是一个2
    5429的矩阵,表示共有5429条边;
    y:表示节点的标签,共2708个节点。

    数据集划分

    我们构建好了自己的数据集格式,但是,进行训练的时候,必须有训练集、验证集和测试集,这块我曾经自己进行实现过,但是,实现起来比较复杂,这个时候发现,原来PyG框架,也把这块给实现了,还是很方便的。

    data = T.RandomNodeSplit()(data)
    
    • 1

    在这里插入图片描述
    我们可以看一下RandomNodeSplit,顾名思义,就是随机划分节点,是不是很简单,该函数可以自己划分数据集,自己也可以指定每个数据集的比例,替换其中的参数即可。
    在这里插入图片描述
    当我们加载完之后,可以看出Data对象中多出来三个,分别是train_mask、val_mask、test_mask,输出看的话,每个都是2708个,但是不同位置上有不同的bool值,就是为了表示该节点是否是训练集、验证集或者测试集。

    结语

    整体看下来,是不是对于PyG处理数据集有所了解呢,以上已经经过小编的实际运行啦,大家可以拿来改改,用在自己的开发数据集上。
    当然,如果有问题或者需要补充的地方,大家可以随时联系我,QQ:1143948594。

  • 相关阅读:
    【JVM技术专题】GC问题分析和故障排查规划指南「实战篇」
    Meta Llama 3本地部署
    我的Qt作品(15)使用Qt+OpenCV实现一个卡尺工具,具备找线和找圆的功能
    Linux环境变量
    【图像去噪】基于边缘增强扩散 (cEED) 和 Coherence Enhancing Diffusion (cCED) 滤波器实现图像去噪附matlab代码
    DL-24C/2A电流继电器
    Java版企业电子招标采购系统源码Spring Cloud + Spring Boot +二次开发+ MybatisPlus + Redis
    设计模式-12-策略模式
    校园小程序毕业设计,学校小程序设计与实现,毕设作品参考
    通义千问AI+Java
  • 原文地址:https://blog.csdn.net/qq_32113189/article/details/126663738