假设数据目录结构是data_dir/images包含图像文件,data_dir/labels包含对应的标签文件,并且图像和标签的文件名是匹配的。
- import torch
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms
- from PIL import Image
- import os
-
- # 定义一个名为CustomDataset的类,继承自torch.utils.data.Dataset,用于自定义数据集
- class CustomDataset(Dataset):
- def __init__(self, data_dir, transform=None):
- # 存储数据集的目录路径
- self.data_dir = data_dir
- # 存储图像和标签的预处理/变换操作
- self.transform = transform
-
- # 获取数据目录下"images"文件夹中的所有图像文件名,并存储在self.images列表中
- self.images = os.listdir(os.path.join(data_dir, "images"))
- # 获取数据目录下"labels"文件夹中的所有标签文件名,并存储在self.labels列表中
- self.labels = os.listdir(os.path.join(data_dir, "labels"))
- # 注意:这里假设图像和标签的文件名是一一对应的
-
- # 定义__len__方法,返回数据集的大小
- def __len__(self):
- # 返回self.images列表的长度,即图像的数量
- return len(self.images)
-
- # 定义__getitem__方法,根据索引idx返回一个数据样本(图像+对应的标签)
- def __getitem__(self, idx):
- # 根据索引idx从self.images和self.labels列表中获取图像和标签的文件名,并拼接成完整的文件路径
- image_path = os.path.join(self.data_dir, "images", self.images[idx])
- label_path = os.path.join(self.data_dir, "labels", self.labels[idx])
-
- # 使用PIL库加载图像文件,并将其转换为RGB格式(三通道彩色图像)
- image = Image.open(image_path).convert('RGB')
- # 使用PIL库加载标签文件,并将其转换为L格式(单通道灰度图像),这里假设标签是灰度图
- label = Image.open(label_path).convert('L')
-
- # 如果定义了预处理/变换操作,则对图像和标签应用这些操作
- # 注意:在实际应用中,图像和标签可能需要不同的预处理/变换操作
- if self.transform:
- image = self.transform(image)
- label = self.transform(label)
-
- # 返回变换后的图像和标签作为一个数据样本
- return image, label
接下来,我们使用CustomDataset类来创建训练集和数据加载器(DataLoader):
- # 定义变换
- transform = transforms.Compose([
- transforms.Resize((64, 64)), # 调整图像大小到64x64
- transforms.ToTensor(), # 将PIL图像转换为tensor
- # 添加其他必要的变换...
- ])
-
- # 创建训练集实例
- train_dataset = CustomDataset(data_dir="path_to_your_data", transform=transform)
-
- # 创建数据加载器
- train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
使用训练集来训练模型:
- # 定义你的模型
- model = ...
-
- # 定义损失函数和优化器
- criterion = ...
- optimizer = ...
-
- # 训练模型
- num_epochs = 10 # 设置训练的epoch数量
- for epoch in range(num_epochs):
- for images, labels in train_loader:
- # 将数据发送到设备(CPU或GPU)上
- images, labels = images.to(device), labels.to(device)
-
- # 前向传播
- outputs = model(images)
-
- # 计算损失
- loss = criterion(outputs, labels)
-
- # 反向传播和优化
- optimizer.zero_grad() # 清空之前的梯度
- loss.backward() # 反向传播,计算当前梯度
- optimizer.step() # 更新权重
-
- # 打印统计信息
- print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')