• 【Pytorch】深度学习之数据读取


    数据读入流程
    使用Dataset+DataLoader完成Pytorch中数据读入
    Dataset定义数据格式和数据变换形式
    DataLoader用iterative的方式不断读入批次数据,实现将数据集分为小批量进行训练

    使用PyTorch自带数据集
    使用Dataset完成数据格式和数据变换的定义

    import torch
    from torchvision import datasets
    train_data = datasets.ImageFolder(train_path, transform=data_transform)
    val_data = datasets.ImageFolder(val_path, transform=data_transform)
    
    • 1
    • 2
    • 3
    • 4

    参数说明:
    transform实现对图像数据的变换处理

    使用DataLoader完成按批次读取数据

    from torch.utils.data import DataLoader
    
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, num_workers=4, shuffle=False)
    
    • 1
    • 2
    • 3
    • 4

    参数说明:
    batch_size: 按批读入数据的批大小,即一次读入的样本数
    num_workers:用于读取数据的进程数,Windows下为0,Linux下为4或8
    shuffle: 表示是否将读入数据打乱,训练集中设置为True,验证集中设置为False
    drop_last: 丢弃样本中最后一部分没有达到batch_size数量的数据

    数据展示

    import matplotlib.pyplot as plt
    images, labels = next(iter(val_loader))
    print(images.shape)
    # 使用transpose()函数改变原始图像的表示形式,从(H,W,C)的表示转换为(C,H,W)的表示
    plt.imshow(images[0].transpose(1,2,0)) 
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    自定义数据集方式

    1. 自定义Dataset类继承Dataset
    2. 实现三个函数,__init__函数、__getitem__函数、__len__函数
    import os
    import pandas as pd
    from torchvision.io import read_image
    
    class MyDataset(Dataset):
        def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
            """
            Args:
                annotations_file (string): Path to the csv file with annotations.
                img_dir (string): Directory with all the images.
                transform (callable, optional): Optional transform to be applied on a sample.
                target_transform (callable, optional): Optional transform to be applied on the target.
            """
            self.img_labels = pd.read_csv(annotations_file)
            self.img_dir = img_dir
            self.transform = transform
            self.target_transform = target_transform
    
        def __len__(self):
            return len(self.img_labels)
    
        def __getitem__(self, idx):
            """
            Args:
                idx (int): Index
            """
            # 使用path.join()函数构建图像路径,img_labels.iloc[行,列]用于通过行列索引访问DataFrame中的元素
            img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) 
            image = read_image(img_path)
            label = self.img_labels.iloc[idx, 1]
            if self.transform:
                image = self.transform(image)
            if self.target_transform:
                label = self.target_transform(label)
            return image, label
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
  • 相关阅读:
    字符串 (3)--- KMP 算法的扩展
    HDFS架构设计理念以及优缺点
    js中this的原理详解(web前端开发javascript语法基础)
    Java 中经常被提到的 SPI 到底是什么?
    ajax day4
    微信小程序 | 游戏开发之接宝石箱子游戏
    预测评价指标
    电子元器件解析01——电阻
    无痛迁移:图解 Kubernetes 集群升级步骤
    DSPE-PEG-VIP,磷脂-聚乙二醇-血管活性肠肽VIP,VIP修饰脂质体供应
  • 原文地址:https://blog.csdn.net/m0_61819793/article/details/133747604