• pytorch collate_fn测试用例


    collate_fn 函数用于处理数据加载器(DataLoader)中的一批数据。在 PyTorch 中使用 DataLoader 时,通过设置 collate_fn,我们可以决定如何将多个样本整合成一个 batch。在某些情况下,该函数需要由用户自定义以满足特定需求:例如下面的例子中,每张图片的标注(目标框)数量不同,标签无法直接堆叠成一个张量,因此用自定义的 collate_fn 把图片堆叠成一个张量,而把标签保留为张量列表。

    import torch
    from torch.utils.data import Dataset, DataLoader
    import numpy as np
    class MyDataset(Dataset):
        """Toy detection dataset pairing HWC images with per-image label arrays.

        Args:
            imgs: Indexable collection of numpy arrays laid out as
                (height, width, channels).
            labels: Indexable collection of label arrays, one per image. The
                number of rows may differ from image to image.
        """

        def __init__(self, imgs, labels):
            self.imgs = imgs
            self.labels = labels

        def __len__(self):
            """Return the number of images in the dataset."""
            return len(self.imgs)

        def __getitem__(self, idx):
            """Return ``(image, label)`` for sample ``idx``.

            The image is converted to ``float32`` and rearranged from
            (H, W, C) to (C, H, W), e.g. (300, 150, 3) -> (3, 300, 150).
            The label array is returned unchanged, e.g. shape (4, 5) or (2, 5).
            """
            img = self.imgs[idx]
            out_img = img.astype(np.float32)
            out_img = out_img.transpose(2, 0, 1)  # h, w, c -> c, h, w
            out_label = self.labels[idx]
            return out_img, out_label
    # Example with batch_size=3: ``batch`` is a list of 3 samples, each a
    # 2-tuple of numpy arrays (image, labels):
    #   batch[0] -> (image [3, 300, 150], labels [4, 5])
    #   batch[1] -> (image [3, 300, 150], labels [2, 5])
    #   batch[2] -> (image [3, 300, 150], labels [4, 5])
    def my_collate_fn(batch):
        """Collate samples whose images share a shape but whose labels do not.

        Images are stacked into a single tensor; the label arrays, which may
        have a different number of rows per image (e.g. bounding boxes), are
        kept as a list so no padding is required.

        Arguments:
            batch: sequence of samples, where ``sample[0]`` is the image array
                and ``sample[1]`` is that image's label array.
        Return:
            A tuple containing:
                1) (tensor) float32 batch of images stacked on dim 0,
                   e.g. [3, 3, 300, 150] for the example above
                2) (list of tensors) one float32 label tensor per image
        """
        imgs = []
        targets = []
        for sample in batch:
            # torch.stack below copies, so a zero-copy view of the image is safe.
            imgs.append(torch.as_tensor(sample[0], dtype=torch.float32))
            # torch.tensor always copies, so the returned labels never alias
            # the arrays held by the dataset.
            targets.append(torch.tensor(sample[1], dtype=torch.float32))
        imgs_out = torch.stack(imgs, 0)
        return imgs_out, targets
    # Build a synthetic dataset: random images plus a random number of
    # 5-value label rows per image.
    img_data = []
    label_data = []
    nums = 34
    H = 300
    W = 150
    for _ in range(nums):
        # ``high`` is exclusive, so 256 is needed to cover the full 0..255 range.
        random_img = np.random.randint(low=0, high=256, size=(H, W, 3))
        nums_target = np.random.randint(low=0, high=10)  # 0..9 rows, may be empty
        random_xyxy_label = np.random.random((nums_target, 5))
        img_data.append(random_img)
        label_data.append(random_xyxy_label)

    dataset = MyDataset(img_data, label_data)
    dataloader = DataLoader(dataset, batch_size=3, collate_fn=my_collate_fn)

    # Each batch yields one stacked image tensor and a list of label tensors.
    for cnt, (img, label) in enumerate(dataloader):
        print("==>>", cnt, ", img shape=", img.shape)
        for target in label:
            print("label shape=", target.shape)

    打印如下(34 个样本、batch_size=3,共 12 个 batch,这里只列出前 3 个;标注数量是随机生成的,每次运行结果会不同):

    1. ==>> 0 , img shape= torch.Size([3, 3, 300, 150])
    2. label shape= torch.Size([8, 5])
    3. label shape= torch.Size([2, 5])
    4. label shape= torch.Size([5, 5])
    5. ==>> 1 , img shape= torch.Size([3, 3, 300, 150])
    6. label shape= torch.Size([3, 5])
    7. label shape= torch.Size([8, 5])
    8. label shape= torch.Size([5, 5])
    9. ==>> 2 , img shape= torch.Size([3, 3, 300, 150])
    10. label shape= torch.Size([7, 5])
    11. label shape= torch.Size([1, 5])
    12. label shape= torch.Size([8, 5])
  • 相关阅读:
    Java“牵手”京东商品详情数据,京东商品详情API接口,京东API接口申请指南
    制作电子画册的有好帮手---FLBOOK
    Idea加载gradle项目问题小记Gradle‘s dependency cache may be corrupt
    Android BottomSheet总结
    翻页视图ViewPager
    low power-upf-vcsnlp(四)
    精分合并抑郁康复经历分享:如何从死亡边缘回到生的海洋?
    数据结构之栈:使用栈数据结构实现字符串中相邻两个字符不重复
    R语言ggplot2和gganimate包可视化动态动画气泡图(Animated Bubble chart):使用gganimate包创建可视化gif动图
    android pdf框架-4,分析vudroid源码2
  • 原文地址:https://blog.csdn.net/yang332233/article/details/134182432