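In object detection we often subclass PyTorch's Dataset class, which mainly consists of three methods: __init__, __len__ and __getitem__.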
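from torch.utils.data import Dataset


class DetectionDataset(Dataset):  # illustrative class name
    def __init__(self, list_image_path, list_anno_path):
        """Initialization: paths to the images and to their annotation files."""
        self.images = list_image_path
        self.labels = list_anno_path

    def __len__(self):
        """Denotes the total number of samples."""
        return len(self.labels)

    def __getitem__(self, index):
        """Generates one sample of data."""
        # Select sample; read() is a placeholder for your own image loader
        # (it should return a C x H x W array of the same size for every image)
        image = read(self.images[index])
        # Load data and get label; parse_anno() is a placeholder for your own parser
        # (it should return an (N, 6) array, N = number of objects in this image)
        bbox = parse_anno(self.labels[index])
        return image, bbox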
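This is fine for classification, where every image has exactly one label. In detection, however, the number of objects differs from image to image, so the bbox arrays returned by __getitem__ have different shapes. The DataLoader's default collate_fn stacks the samples of a batch into one tensor, which requires identical shapes, so building a batch fails. In that case we have to pass our own collate_fn: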
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=my_collate)
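A minimal sketch of the failure, assuming a hypothetical ToyDetectionDataset whose samples are random 3×8×8 images plus a growing number of 6-column box rows. With the default collate_fn, fetching the first batch (2 boxes and 3 boxes) raises a RuntimeError because the box tensors cannot be stacked.

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset


class ToyDetectionDataset(Dataset):
    """Hypothetical dataset: sample i has i + 2 boxes, each row holding 6 values."""

    def __len__(self):
        return 4

    def __getitem__(self, index):
        image = np.random.rand(3, 8, 8).astype(np.float32)  # same C x H x W for all samples
        bbox = np.zeros((index + 2, 6), dtype=np.float32)    # row count differs per sample
        return image, bbox


def try_default_collate():
    """Fetch one batch with the default collate_fn; return the error text if it fails."""
    loader = DataLoader(ToyDetectionDataset(), batch_size=2, shuffle=False)
    try:
        next(iter(loader))
    except RuntimeError as err:
        # The images stack fine, but (2, 6) and (3, 6) box tensors cannot be stacked
        return str(err)
    return None


message = try_default_collate()
print(message)
```

The exact wording of the error depends on the PyTorch version, but it comes from stacking tensors of unequal size.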
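Following the YOLOv3 design of collate_fn, we only need to reserve one extra field in each bbox row (the first column) to record which image in the batch the box belongs to, for example [image_index, class, x, y, w, h]. All boxes of the batch can then be concatenated into one 2-D tensor of shape (total number of boxes, 6), while the images form a 4-D tensor of shape (B, C, H, W). The latter assumes every image has already been resized to the same C×H×W.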
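import numpy as np
import torch


def my_collate(batch):
    """Batch (image, bbox) pairs whose bbox arrays have different numbers of rows.
    Images must share one C x H x W shape; column 0 of each bbox is overwritten."""
    images = []
    bboxes = []
    for i, (img, box) in enumerate(batch):
        images.append(img)
        box[:, 0] = i  # tag every box with the index of its image in this batch
        bboxes.append(box)
    images = torch.from_numpy(np.stack(images)).float()           # (B, C, H, W)
    bboxes = torch.from_numpy(np.concatenate(bboxes, 0)).float()  # (sum of N_i, 6)
    return images, bboxes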
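To check the result end to end, the sketch below plugs my_collate into a DataLoader over a hypothetical ToyDetectionDataset (random 3×8×8 images, sample i carrying i + 2 boxes). It then shows one simple way, assumed here rather than taken from YOLOv3, for a loss function to recover each image's boxes by filtering on column 0.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDetectionDataset(Dataset):
    """Hypothetical dataset: sample i has i + 2 boxes, rows [img_idx, cls, x, y, w, h]."""

    def __len__(self):
        return 4

    def __getitem__(self, index):
        image = np.random.rand(3, 8, 8).astype(np.float32)
        bbox = np.full((index + 2, 6), 0.5, dtype=np.float32)
        bbox[:, 1] = index  # store the sample index in the class column, for illustration
        return image, bbox


def my_collate(batch):
    images, bboxes = [], []
    for i, (img, box) in enumerate(batch):
        images.append(img)
        box[:, 0] = i  # index of the image inside this batch
        bboxes.append(box)
    images = torch.from_numpy(np.stack(images)).float()
    bboxes = torch.from_numpy(np.concatenate(bboxes, 0)).float()
    return images, bboxes


loader = DataLoader(ToyDetectionDataset(), batch_size=2, shuffle=False, collate_fn=my_collate)
images, bboxes = next(iter(loader))
print(images.shape)  # torch.Size([2, 3, 8, 8])
print(bboxes.shape)  # torch.Size([5, 6]), i.e. 2 + 3 boxes

# Recover the boxes of each image from the flat target tensor via column 0
per_image = [bboxes[bboxes[:, 0] == i] for i in range(images.shape[0])]
print([t.shape[0] for t in per_image])  # [2, 3]
```

Keeping the targets as one flat tensor with an index column avoids padding every image to the maximum box count, at the cost of a mask when per-image boxes are needed.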