`collate_fn` 函数用于处理数据加载器(DataLoader)取出的一批样本。在 PyTorch 中使用 DataLoader 时,通过设置 `collate_fn` 参数,我们可以决定如何将多个样本整合成一个 batch。在某些情况下(例如每张图片的标注框数量不同、标签无法直接堆叠成一个张量时),需要由用户自定义该函数以满足特定需求。
- import torch
- from torch.utils.data import Dataset, DataLoader
- import numpy as np
-
class MyDataset(Dataset):
    """Map-style dataset pairing each image with its annotation array.

    Images are expected to be 3-D NumPy arrays in (H, W, C) layout. Labels are
    returned exactly as stored, so their first dimension may differ from one
    sample to the next (e.g. shape [4, 5] for one image and [2, 5] for another).
    """

    def __init__(self, imgs, labels):
        """Store the parallel collections of images and labels.

        Args:
            imgs: Indexable collection of (H, W, C) NumPy arrays.
            labels: Indexable collection of per-image annotations, aligned
                with ``imgs`` by index.
        """
        self.imgs = imgs
        self.labels = labels

    def __len__(self):
        """Return the number of samples, i.e. the number of stored images."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for sample ``idx``.

        Returns:
            A tuple of the image as a float32 array in (C, H, W) layout and
            the label exactly as it was stored.
        """
        # astype() yields a float32 copy; transpose(2, 0, 1) turns
        # (H, W, C) into the channels-first (C, H, W) layout.
        out_img = self.imgs[idx].astype(np.float32).transpose(2, 0, 1)
        out_label = self.labels[idx]
        return out_img, out_label
-
# Example with batch_size=3: ``batch`` is a list of 3 samples, each a 2-tuple
#   batch[0] -> (np[3, 300, 150], np[4, 5])
#   batch[1] -> (np[3, 300, 150], np[2, 5])
#   batch[2] -> (np[3, 300, 150], np[4, 5])
def my_collate_fn(batch):
    """Collate samples whose images share a shape but whose labels do not.

    Images are stacked into a single batch tensor. Annotations are kept as a
    list with one tensor per image, because images may carry different numbers
    of annotation rows (bounding boxes) and therefore cannot be stacked.

    Args:
        batch: Sequence of ``(image, label)`` pairs. All images must have the
            same shape; labels may differ in their first dimension.

    Returns:
        A tuple ``(imgs_out, targets)`` where ``imgs_out`` is a float32 tensor
        of the images stacked along a new dimension 0, and ``targets`` is a
        list of float32 tensors, one per sample, in batch order.
    """
    # torch.as_tensor replaces the legacy torch.FloatTensor constructor and
    # converts to float32 regardless of the incoming dtype.
    imgs = [torch.as_tensor(sample[0], dtype=torch.float32) for sample in batch]
    targets = [torch.as_tensor(sample[1], dtype=torch.float32) for sample in batch]

    imgs_out = torch.stack(imgs, 0)  # e.g. [3, 3, 300, 150] for the example above
    return imgs_out, targets
-
-
-
-
# Build a synthetic dataset: random images, each with a random number of
# 5-column annotation rows, so label shapes differ between samples.
img_data = []
label_data = []

nums = 34
H = 300
W = 150
for _ in range(nums):
    # randint's upper bound is exclusive, so 256 is needed to cover 0..255.
    random_img = np.random.randint(low=0, high=256, size=(H, W, 3))
    nums_target = np.random.randint(low=0, high=10)  # 0..9 rows per image
    random_xyxy_label = np.random.random((nums_target, 5))
    img_data.append(random_img)
    label_data.append(random_xyxy_label)

dataset = MyDataset(img_data, label_data)
dataloader = DataLoader(dataset, batch_size=3, collate_fn=my_collate_fn)

# The custom collate function returns the images as one stacked tensor and
# the labels as a list holding one tensor per image.
for cnt, (img, label) in enumerate(dataloader):
    print("==>>", cnt, ", img shape=", img.shape)
    for target in label:
        print("label shape=", target.shape)
运行后打印如下(仅展示前 3 个 batch 的输出;每张图片的标注数量是随机生成的,因此每次运行的 label shape 会不同):
- ==>> 0 , img shape= torch.Size([3, 3, 300, 150])
- label shape= torch.Size([8, 5])
- label shape= torch.Size([2, 5])
- label shape= torch.Size([5, 5])
- ==>> 1 , img shape= torch.Size([3, 3, 300, 150])
- label shape= torch.Size([3, 5])
- label shape= torch.Size([8, 5])
- label shape= torch.Size([5, 5])
- ==>> 2 , img shape= torch.Size([3, 3, 300, 150])
- label shape= torch.Size([7, 5])
- label shape= torch.Size([1, 5])
- label shape= torch.Size([8, 5])