• 针对DGL的few-shot数据集划分方法


    目标

    针对DGL数据集做few-shot问题时候,需要将数据集划分成nshot的train,val,test。要求划分出多个task,每个task的train,val,test比例一致。同一个task内部,train,val,test无交叉;task间,train,val,test可以交叉。

    Graph-Level

    数据格式

    以MUTAG为例:

    {'id': 'G_N22_E50_NL3_EL3_133', 'graph': Graph(num_nodes=22, num_edges=50,
          ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
          edata_schemes={}), 'label': tensor(0)}
    
    • 1
    • 2
    • 3

    label代表graph的类别

    代码

    #按照label做升序
    def cmp(a,b):
        if a['label']<b['label']:
            return -1
        if a['label']>b['label']:
            return 1
        return 0
    
    def few_shot_split_graphlevl(dataset,train_shotnum,val_shotnum,classnum,tasknum):
        #task中的train,val,test无交叉;但不同task之间的train,val,test可以交叉
        #train_shotnum代表train中的shotnum,val_shotnum同理
        #classnum代表总共有几类
        #先将dataset中的数据按照class排序,然后随机从中选出数据来放入train val,剩下的放进test即可
        train=[]
        val=[]
        test=[]
        dataset=sorted(dataset,key=functools.cmp_to_key(cmp))
        length=len(dataset)
        #统计每类各有多少张图
        classcount=torch.zeros(classnum)
        #统计每类的第一个元素在dataset中的索引位置
        class_start_index=torch.zeros(classnum)
        label_before=1e6
        count=0
        for data in dataset:
            classcount[data['label']]+=1
            if label_before != data['label']:
                label_before=data['label']
                class_start_index[data['label']]=count
            count+=1
        #print('classcount:',classcount)
        #print(class_start_index)
        class_start_index=class_start_index.int()
        for task in range(tasknum):
            train_index=[]
            val_index=[]
            test_index=list(range(0,length))
            for c in range(classnum):
                if c!=classnum-1:
                    index=random.sample(range(class_start_index[c],class_start_index[c+1]-1), train_shotnum+val_shotnum)
                else:
                    #从每类中的元素中选出train_shotnum+val_shotnum的元素,再将这些元素按照所需数目分别加入train和val中
                    index=random.sample(range(class_start_index[c],length-1), train_shotnum+val_shotnum)
                train_index=train_index+index[0:train_shotnum]
                val_index=val_index+index[train_shotnum:len(index)]
            train.append([dataset[i]for i in train_index])
            val.append([dataset[i]for i in val_index])
            #剩下的元素全部加入test
            train_val_index=train_index+val_index
            train_val_index.sort(reverse=True)
            for i in train_val_index:
                test_index.pop(i)
            test.append([dataset[i]for i in test_index])
        return train,val,test
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54

    输出部分结果

        dataset = GraphAdjDataset(list())
        dataset.load(os.path.join(train_config["save_data_dir"], "train_dgl_dataset.pt"))
        print(dataset[0])
        print(dataset[0]['label'])
        trainset,valset,testset=few_shot_split_graphlevl(dataset,4,2,2,2)
        print(trainset[0])
        print(len(trainset[1]))
        print(len(valset[1]))
        print(len(testset[1]))
    >>>{'id': 'G_N21_E48_NL3_EL3_121', 'graph': Graph(num_nodes=21, num_edges=48,
          ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
          edata_schemes={}), 'label': tensor(1)}
    tensor(1)
    [{'id': 'G_N22_E50_NL3_EL3_133', 'graph': Graph(num_nodes=22, num_edges=50,
          ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
          edata_schemes={}), 'label': tensor(0)}, {'id': 'G_N17_E38_NL3_EL3_47', 'graph': Graph(num_nodes=17, num_edges=38,
          ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
          edata_schemes={}), 'label': tensor(0)}, {'id': 'G_N22_E50_NL3_EL3_162', 'graph': Graph(num_nodes=22, num_edges=50,
          ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
          edata_schemes={}), 'label': tensor(0)}, {'id': 'G_N11_E22_NL3_EL3_131', 'graph': Graph(num_nodes=11, num_edges=22,
          ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
          edata_schemes={}), 'label': tensor(0)}, {'id': 'G_N16_E34_NL3_EL3_145', 'graph': Graph(num_nodes=16, num_edges=34,
          ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
          edata_schemes={}), 'label': tensor(1)}, {'id': 'G_N13_E28_NL3_EL3_1', 'graph': Graph(num_nodes=13, num_edges=28,
          ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
          edata_schemes={}), 'label': tensor(1)}, {'id': 'G_N12_E24_NL3_EL3_80', 'graph': Graph(num_nodes=12, num_edges=24,
          ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
          edata_schemes={}), 'label': tensor(1)}, {'id': 'G_N13_E28_NL3_EL3_39', 'graph': Graph(num_nodes=13, num_edges=28,
          ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
          edata_schemes={}), 'label': tensor(1)}]
    8
    4
    176
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
  • 相关阅读:
    编程中常用的加密算法
    微信小程序关闭首页广告
    世界前沿技术发展报告2023《世界航空技术发展报告》(二)军用飞机技术
    GPC规范-SCP02
    照身帖、密钥,看古代人做实名认证有哪些招数?
    【Linux】之Centos7卸载KVM虚拟化服务
    接口自动化测试
    在Linux上使用yum安装MySQL
    国外JAVA相关学习网站
    医疗信息管理系统(HIS)——>业务介绍
  • 原文地址:https://blog.csdn.net/StarfishCu/article/details/126733544