• paddle 自定义数据集和预处理


    自定义数据集 

    1. import paddle
    2. from matplotlib import pyplot as plt
    3. import os
    4. import cv2
    5. import numpy as np
    6. from paddle.io import Dataset
    7. from paddle.vision.transforms import Normalize
    8. print('计算机视觉(CV)相关数据集:', paddle.vision.datasets.__all__)
    9. print('自然语言处理(NLP)相关数据集:', paddle.text.__all__)
    10. #图像数据集
    11. # vis_dataset=paddle.vision.datasets.MNIST(mode='train',transform=paddle.vision.transforms.ToTensor())
    12. # print(len(vis_dataset))
    13. # image,label=vis_dataset[0]
    14. # print(type(image))
    15. # print(image.shape)
    16. # print(label)
    17. #文字数据集
    18. # text_dataset=paddle.text.Imdb()
    19. # text,label=text_dataset[1]
    20. # print(type(text))
    21. #
    22. # print(label)
    23. #图像显示
    24. # for data in vis_dataset:
    25. # image,label=data
    26. # print('图片的shape',image.shape)
    27. # plt.title(str(label))
    28. # plt.imshow(image[0])
    29. # plt.show()
    30. #自定义数据集
    31. class MyDataset(Dataset):
    32. """
    33. 步骤一:继承 paddle.io.Dataset 类
    34. """
    35. def __init__(self, data_dir, label_path, transform=None):
    36. """
    37. 步骤二:实现 __init__ 函数,初始化数据集,将样本和标签映射到列表中
    38. """
    39. super().__init__()
    40. self.data_list = []
    41. with open(label_path,encoding='utf-8') as f:
    42. for line in f.readlines():
    43. image_path, label = line.strip().split('\t')
    44. image_path = os.path.join(data_dir, image_path)
    45. self.data_list.append([image_path, label])
    46. # 传入定义好的数据处理方法,作为自定义数据集类的一个属性
    47. self.transform = transform
    48. def __getitem__(self, index):
    49. """
    50. 步骤三:实现 __getitem__ 函数,定义指定 index 时如何获取数据,并返回单条数据(样本数据、对应的标签)
    51. """
    52. # 根据索引,从列表中取出一个图像
    53. image_path, label = self.data_list[index]
    54. # 读取灰度图
    55. image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    56. # 飞桨训练时内部数据格式默认为float32,将图像数据格式转换为 float32
    57. image = image.astype('float32')
    58. # 应用数据处理方法到图像上
    59. if self.transform is not None:
    60. image = self.transform(image)
    61. # CrossEntropyLoss要求label格式为int,将Label格式转换为 int
    62. label = int(label)
    63. # 返回图像和对应标签
    64. return image, label
    65. def __len__(self):
    66. """
    67. 步骤四:实现 __len__ 函数,返回数据集的样本总数
    68. """
    69. return len(self.data_list)
    70. # 定义图像归一化处理方法,这里的CHW指图像格式需为 [C通道数,H图像高度,W图像宽度]
    71. transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
    72. # 打印数据集样本数
    73. train_custom_dataset = MyDataset('mnist/train','mnist/train/label.txt', transform)
    74. print('train_custom_dataset images: ',len(train_custom_dataset))
    75. train_dataloader=paddle.io.DataLoader(dataset=train_custom_dataset,batch_size=64,shuffle=True,num_workers=0,drop_last=True)
    76. #迭代dataloader并且显示图片
    77. for batch_data in train_dataloader:
    78. image, label = batch_data
    79. col=10
    80. row=7
    81. for i in range(len(image)):
    82. plt.subplot(row,col ,i+1)
    83. plt.title(str(label[i].numpy()))
    84. plt.xticks([])
    85. plt.yticks([])
    86. plt.imshow(image[i][0])
    87. plt.show()
    88. break
    89. #batchsamper
    90. # from paddle.io import BatchSampler
    91. #
    92. # bs=BatchSampler(train_custom_dataset,batch_size=8,shuffle=True,drop_last=True)
    93. # print('batchsamer每轮返回一个索引列表')
    94. # for batch_indices in bs:
    95. # print(batch_indices)
    96. # break

    数据预处理

    1. import cv2
    2. import numpy as np
    3. from PIL import Image
    4. from matplotlib import pyplot as plt
    5. from paddle.vision.transforms import CenterCrop,RandomHorizontalFlip,Compose,ColorJitter
    6. transform =Compose([CenterCrop(20),
    7. RandomHorizontalFlip(0.5),#基于概率来执行图片的水平翻转
    8. ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)#随机调整图像的亮度、对比度、饱和度和色调
    9. ])
    10. #1、opencv读取图片
    11. # image = cv2.imread('0.jpg')
    12. # print(image.shape)
    13. # image_after_transform = transform(image)
    14. # print(image_after_transform.shape)
    15. # plt.subplot(1,2,1)
    16. # plt.title('origin image')
    17. # plt.imshow(image[:,:,::-1])
    18. # plt.subplot(1,2,2)
    19. # plt.title('transform image')
    20. # plt.imshow(image_after_transform[:,:,::-1])
    21. # plt.show()
    22. # 2、PIL读取图片
    23. image=Image.open('0.jpg')
    24. image_after_transform=transform(image)
    25. plt.subplot(1,2,1)
    26. plt.title('origin image')
    27. plt.imshow(image)
    28. plt.subplot(1,2,2)
    29. plt.title('transform image')
    30. plt.imshow(image_after_transform)
    31. plt.show()

  • 相关阅读:
    【STL巨头】set、map、multiset、multimap的介绍及使用
    Scss--@mixin--使用/实例
    恒星的正方形问题
    聊一下C#中的lock
    springboot校园失物招领网站系统在线视频点播系统毕业设计毕设作品开题报告开题答辩PPT
    TikTok 推荐引擎强大的秘密
    SpringBoot SpringBoot 开发实用篇 4 数据层解决方案 4.11 SpringBoot 整合 MongoDB
    实战 | 基于YOLOv10的车辆追踪与测速实战【附源码+步骤详解】
    母婴行业探秘:千万级会员体量下的精准营销
    接口幂等性(防止接口重复提交)
  • 原文地址:https://blog.csdn.net/qq_40107571/article/details/134077206