数据读入流程
使用Dataset+DataLoader完成Pytorch中数据读入
Dataset定义数据格式和数据变换形式
DataLoader用iterative的方式不断读入批次数据,实现将数据集分为小批量进行训练
使用PyTorch自带数据集
使用Dataset完成数据格式和数据变换的定义
import torch
from torchvision import datasets
train_data = datasets.ImageFolder(train_path, transform=data_transform)
val_data = datasets.ImageFolder(val_path, transform=data_transform)
参数说明:
transform实现对图像数据的变换处理
使用DataLoader完成按批次读取数据
from torch.utils.data import DataLoader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, num_workers=4, shuffle=False)
参数说明:
batch_size: 按批读入数据的批大小,即一次读入的样本数
num_workers:用于读取数据的进程数,Windows下为0,Linux下为4或8
shuffle: 表示是否将读入数据打乱,训练集中设置为True,验证集中设置为False
drop_last: 丢弃样本中最后一部分没有达到batch_size数量的数据
数据展示
import matplotlib.pyplot as plt
images, labels = next(iter(val_loader))
print(images.shape)
# 使用transpose()函数改变原始图像的表示形式,从(H,W,C)的表示转换为(C,H,W)的表示
plt.imshow(images[0].transpose(1,2,0))
plt.show()
自定义数据集方式
Dataset
类__init__
函数、__getitem__
函数、__len__
函数import os
import pandas as pd
from torchvision.io import read_image
class MyDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
"""
Args:
annotations_file (string): Path to the csv file with annotations.
img_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied on a sample.
target_transform (callable, optional): Optional transform to be applied on the target.
"""
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
"""
Args:
idx (int): Index
"""
# 使用path.join()函数构建图像路径,img_labels.iloc[行,列]用于通过行列索引访问DataFrame中的元素
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label