• PyTorch collate_fn 测试用例


    collate_fn 函数用于处理数据加载器(DataLoader)中的一批数据。在PyTorch中使用 DataLoader 时,通过设置collate_fn,我们可以决定如何将多个样本数据整合到一起成为一个 batch。在某些情况下,该函数需要由用户自定义以满足特定需求。

    1. import torch
    2. from torch.utils.data import Dataset, DataLoader
    3. import numpy as np
    class MyDataset(Dataset):
        """Map-style dataset that pairs each image with its label array.

        Images are stored as given and converted lazily in ``__getitem__``;
        labels are returned untouched, so different samples may carry a
        different number of label rows.
        """

        def __init__(self, imgs, labels):
            """Store the image and label sequences.

            Args:
                imgs: indexable sequence of NumPy arrays laid out as (H, W, C).
                labels: indexable sequence of per-image label arrays,
                    aligned with ``imgs`` by index.
            """
            self.imgs = imgs
            self.labels = labels

        def __len__(self):
            """Return the number of samples (the length of ``imgs``)."""
            return len(self.imgs)

        def __getitem__(self, idx):
            """Return the ``(image, label)`` pair stored at ``idx``.

            The image is cast to float32 and its axes are reordered from
            (H, W, C) to (C, H, W), e.g. (300, 150, 3) -> (3, 300, 150).
            The label is returned as stored, e.g. shape (4, 5) or (2, 5).
            """
            img = self.imgs[idx]
            out_img = img.astype(np.float32)
            # h, w, c -> c, h, w
            out_img = out_img.transpose(2, 0, 1)
            out_label = self.labels[idx]
            return out_img, out_label
    # Example with batch_size=3: ``batch`` is a list of 3 samples, each a
    # 2-tuple of NumPy arrays whose label row counts may differ:
    #   batch[0] -> (img [3, 300, 150], label [4, 5])
    #   batch[1] -> (img [3, 300, 150], label [2, 5])
    #   batch[2] -> (img [3, 300, 150], label [4, 5])
    def my_collate_fn(batch):
        """Collate samples whose images share a shape but whose labels do not.

        Images are stacked into one batch tensor. Labels are kept as a list
        of tensors because each image may have a different number of
        annotation rows (e.g. bounding boxes) and so cannot be stacked.

        Args:
            batch: sequence of ``(image, label)`` pairs as produced by the
                dataset; every image must have the same shape.

        Returns:
            A tuple ``(imgs_out, targets)`` where
            1) ``imgs_out`` is a float32 tensor of the images stacked on dim 0;
            2) ``targets`` is a list of float32 tensors, one per sample.
        """
        imgs = []
        targets = []
        for img, label in batch:
            # torch.as_tensor with an explicit dtype replaces the legacy
            # torch.FloatTensor(...) constructor.
            imgs.append(torch.as_tensor(img, dtype=torch.float32))
            targets.append(torch.as_tensor(label, dtype=torch.float32))
        imgs_out = torch.stack(imgs, 0)  # e.g. [3, 3, 300, 150]
        return imgs_out, targets
    # Demo: build a random dataset and iterate it with the custom collate_fn.
    img_data = []
    label_data = []
    nums = 34
    H = 300
    W = 150
    for _ in range(nums):
        # ``high`` is exclusive, so 256 is needed to cover the 0..255 range.
        random_img = np.random.randint(low=0, high=256, size=(H, W, 3))
        # Each image gets a random number of label rows (0..9), 5 values per row.
        nums_target = np.random.randint(low=0, high=10)
        random_xyxy_label = np.random.random((nums_target, 5))
        img_data.append(random_img)
        label_data.append(random_xyxy_label)

    dataset = MyDataset(img_data, label_data)
    dataloader = DataLoader(dataset, batch_size=3, collate_fn=my_collate_fn)

    for cnt, (img, label) in enumerate(dataloader):
        print("==>>", cnt, ", img shape=", img.shape)
        # ``label`` is a list with one tensor per sample in the batch.
        for target in label:
            print("label shape=", target.shape)

    打印如下(共 34 个样本、batch_size=3,完整运行会输出 12 个 batch,这里只列出前 3 个;每个样本的标签行数是随机生成的,每次运行结果都会不同):

    1. ==>> 0 , img shape= torch.Size([3, 3, 300, 150])
    2. label shape= torch.Size([8, 5])
    3. label shape= torch.Size([2, 5])
    4. label shape= torch.Size([5, 5])
    5. ==>> 1 , img shape= torch.Size([3, 3, 300, 150])
    6. label shape= torch.Size([3, 5])
    7. label shape= torch.Size([8, 5])
    8. label shape= torch.Size([5, 5])
    9. ==>> 2 , img shape= torch.Size([3, 3, 300, 150])
    10. label shape= torch.Size([7, 5])
    11. label shape= torch.Size([1, 5])
    12. label shape= torch.Size([8, 5])
  • 相关阅读:
    华为OD机考算法题:服务器广播
    七、鼎捷T100采购应付管理流程
    Go 语言编译环境
    Jenkins离线插件配置(二)
    【开发工具】gitee还不用会?我直接拿捏 >_>
    linux环境下的MySQL UDF提权
    《深入理解计算机系统》笔记
    Typescript给定一个由key值组成的数组keys,返回由数组项作为key值组成的对象
    论文阅读之《Learn to see in the dark》
    C++模板
  • 原文地址:https://blog.csdn.net/yang332233/article/details/134182432