在DGL数据集上做few-shot任务时,需要将数据集划分成n-shot的train、val、test。要求划分出多个task,每个task的train、val、test比例一致。同一个task内部,train、val、test互不重叠;不同task之间,train、val、test可以重叠。
以MUTAG为例:
{'id': 'G_N22_E50_NL3_EL3_133', 'graph': Graph(num_nodes=22, num_edges=50,
ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
edata_schemes={}), 'label': tensor(0)}
其中,label字段代表该graph所属的类别。
# Three-way comparator used to sort samples by ascending label.
def cmp(a, b):
    """Compare two samples by their ``'label'`` entry.

    Returns -1 if ``a`` sorts before ``b``, 1 if it sorts after, and 0 when
    neither label is smaller than the other.
    """
    lhs, rhs = a['label'], b['label']
    if lhs < rhs:
        return -1
    return 1 if lhs > rhs else 0
def few_shot_split_graphlevl(dataset, train_shotnum, val_shotnum, classnum, tasknum):
    """Build ``tasknum`` few-shot train/val/test splits of a graph-level dataset.

    Within one task the three splits are disjoint; across tasks they are drawn
    independently, so the same sample may appear in different splits of
    different tasks.

    Args:
        dataset: Iterable of samples; each sample is a mapping whose
            ``'label'`` entry converts to an integer class id via ``int()``
            (a plain int or a 0-dim tensor both work).
        train_shotnum: Number of samples per class placed in each train split.
        val_shotnum: Number of samples per class placed in each val split.
        classnum: Total number of classes; labels must lie in
            ``[0, classnum)``.
        tasknum: Number of tasks (independent splits) to generate.

    Returns:
        A tuple ``(train, val, test)``. Each element is a list with one entry
        per task, and each entry is a list of samples. Train and val samples
        are grouped by ascending class; test holds every remaining sample in
        ascending label order.

    Raises:
        IndexError: If a sample's label is outside ``[0, classnum)``.
        ValueError: If some class has fewer than
            ``train_shotnum + val_shotnum`` samples.
    """
    shots_per_class = train_shotnum + val_shotnum

    # Sort by class first (stable, so ties keep their original order); every
    # index used below refers to a position in this sorted list.
    ordered = sorted(dataset, key=lambda sample: int(sample['label']))

    # Collect the positions belonging to each class explicitly. Deriving
    # ranges from "first index of the next class" is what previously dropped
    # the last sample of every class and broke on classes with no samples.
    class_indices = [[] for _ in range(classnum)]
    for position, sample in enumerate(ordered):
        label = int(sample['label'])
        if not 0 <= label < classnum:
            raise IndexError(
                f"label {label} is out of range for classnum={classnum}"
            )
        class_indices[label].append(position)

    # Fail early with a clear message instead of part-way through sampling.
    for label, members in enumerate(class_indices):
        if len(members) < shots_per_class:
            raise ValueError(
                f"class {label} has {len(members)} samples, but "
                f"{shots_per_class} are required (train_shotnum + val_shotnum)"
            )

    train, val, test = [], [], []
    for _ in range(tasknum):
        train_index = []
        val_index = []
        for members in class_indices:
            # Draw train and val samples in one call so they cannot overlap,
            # then split the draw into the two groups.
            picked = random.sample(members, shots_per_class)
            train_index.extend(picked[:train_shotnum])
            val_index.extend(picked[train_shotnum:])

        used = set(train_index)
        used.update(val_index)

        train.append([ordered[i] for i in train_index])
        val.append([ordered[i] for i in val_index])
        # Everything not drawn for train/val goes to test.
        test.append(
            [sample for i, sample in enumerate(ordered) if i not in used]
        )
    return train, val, test
# Load the preprocessed DGL dataset from the configured directory.
dataset = GraphAdjDataset([])
dataset_path = os.path.join(train_config["save_data_dir"], "train_dgl_dataset.pt")
dataset.load(dataset_path)

# Show the first sample and its label.
print(dataset[0])
print(dataset[0]['label'])

# Build 2 tasks over 2 classes with 4 train shots and 2 val shots per class.
trainset, valset, testset = few_shot_split_graphlevl(
    dataset, train_shotnum=4, val_shotnum=2, classnum=2, tasknum=2
)

# Print the train split of the first task, then the sizes of the second task's splits.
print(trainset[0])
for split in (trainset, valset, testset):
    print(len(split[1]))
>>>{'id': 'G_N21_E48_NL3_EL3_121', 'graph': Graph(num_nodes=21, num_edges=48,
ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
edata_schemes={}), 'label': tensor(1)}
tensor(1)
[{'id': 'G_N22_E50_NL3_EL3_133', 'graph': Graph(num_nodes=22, num_edges=50,
ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
edata_schemes={}), 'label': tensor(0)}, {'id': 'G_N17_E38_NL3_EL3_47', 'graph': Graph(num_nodes=17, num_edges=38,
ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
edata_schemes={}), 'label': tensor(0)}, {'id': 'G_N22_E50_NL3_EL3_162', 'graph': Graph(num_nodes=22, num_edges=50,
ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
edata_schemes={}), 'label': tensor(0)}, {'id': 'G_N11_E22_NL3_EL3_131', 'graph': Graph(num_nodes=11, num_edges=22,
ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
edata_schemes={}), 'label': tensor(0)}, {'id': 'G_N16_E34_NL3_EL3_145', 'graph': Graph(num_nodes=16, num_edges=34,
ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
edata_schemes={}), 'label': tensor(1)}, {'id': 'G_N13_E28_NL3_EL3_1', 'graph': Graph(num_nodes=13, num_edges=28,
ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
edata_schemes={}), 'label': tensor(1)}, {'id': 'G_N12_E24_NL3_EL3_80', 'graph': Graph(num_nodes=12, num_edges=24,
ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
edata_schemes={}), 'label': tensor(1)}, {'id': 'G_N13_E28_NL3_EL3_39', 'graph': Graph(num_nodes=13, num_edges=28,
ndata_schemes={'indeg': Scheme(shape=(), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'id': Scheme(shape=(), dtype=torch.int64), 'sample': Scheme(shape=(3,), dtype=torch.int64)}
edata_schemes={}), 'label': tensor(1)}]
8
4
176