情况一：以 label 作为文件夹名，该类别的图片都存放在这个文件夹里：
- from torch.utils.data import Dataset
- import os
- import cv2 as cv
-
-
-
class MyData(Dataset):
    """Dataset for a folder-per-label layout: ``root_dir/label_dir/<image file>``.

    Every entry in ``root_dir/label_dir`` is one sample, and the label of
    every sample is the folder name ``label_dir`` itself.

    Args:
        root_dir: Path of the training-set root folder.
        label_dir: Name of the sub-folder holding one class; it is also the
            label string returned for every sample of this dataset.
    """

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir  # training-set root folder
        self.label_dir = label_dir  # class sub-folder name, doubles as the label
        self.path = os.path.join(self.root_dir, self.label_dir)  # folder holding the images
        # os.listdir returns entries in arbitrary order; sort them so a given
        # index always maps to the same image across runs and machines.
        self.img_path = sorted(os.listdir(self.path))  # image file names, not full paths

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*.

        The image is whatever ``cv.imread`` decodes; the label is ``label_dir``.

        Raises:
            OSError: if OpenCV cannot read the file (``cv.imread`` returns
                ``None`` instead of raising).
        """
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        img = cv.imread(img_item_path)
        if img is None:
            raise OSError(f"cv.imread could not read image: {img_item_path}")
        return img, self.label_dir

    def __len__(self):
        """Return the number of files found in the class folder."""
        return len(self.img_path)
-
# Example: one MyData per class folder under the same training root.
root_dir = "dataset/train1"
ants_label_dir, bees_label_dir = "ants", "bees"

ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

# torch's Dataset.__add__ builds a ConcatDataset: the ants samples come
# first, followed by the bees samples.
train_dataset = ants_dataset + bees_dataset
-
情况二：label 和图片分别存放在不同的文件夹中：
-
- from torch.utils.data import Dataset
- import os
- import cv2 as cv
-
-
class MyData(Dataset):
    """Dataset whose images and text labels live in two separate folders.

    Layout::

        root_dir/img_dir/<stem>.<image ext>
        root_dir/label_dir/<stem>.<label ext>

    Each image is paired with the label file that has the same stem (file
    name without extension). If no label file shares the stem, the label at
    the same position in the name-sorted label listing is used instead.

    Args:
        root_dir: Path of the training-set root folder.
        img_dir: Name of the sub-folder holding the images.
        label_dir: Name of the sub-folder holding the label text files.
    """

    def __init__(self, root_dir, img_dir, label_dir):
        self.root_dir = root_dir  # training-set root folder
        self.label_dir = label_dir  # label sub-folder name
        self.img_dir = img_dir  # image sub-folder name

        self.img_path = os.path.join(self.root_dir, self.img_dir)  # image folder
        self.label_path = os.path.join(self.root_dir, self.label_dir)  # label folder

        # os.listdir order is arbitrary and the two folders are listed
        # independently, so position i in one list need not match position i
        # in the other. Sort both so indices are reproducible ...
        self.img_path_list = sorted(os.listdir(self.img_path))
        self.label_path_list = sorted(os.listdir(self.label_path))
        # ... and pair image/label by file stem, which does not depend on order.
        # If two label files share a stem, the last one in sorted order wins.
        self._label_by_stem = {
            os.path.splitext(name)[0]: name for name in self.label_path_list
        }

    def get_label(self, idx):
        """Return the UTF-8 text content of the label for sample *idx*."""
        stem = os.path.splitext(self.img_path_list[idx])[0]
        label_name = self._label_by_stem.get(stem)
        if label_name is None:
            # No label with a matching stem: fall back to sorted position.
            label_name = self.label_path_list[idx]
        label_item_path = os.path.join(self.label_path, label_name)
        # `with` guarantees the file handle is closed, even if read() raises.
        with open(label_item_path, encoding="utf-8") as label_file:
            return label_file.read()

    def __getitem__(self, idx):
        """Return ``(image, label_text)`` for sample *idx*.

        Raises:
            OSError: if OpenCV cannot read the image (``cv.imread`` returns
                ``None`` instead of raising).
        """
        img_name = self.img_path_list[idx]
        img_item_path = os.path.join(self.img_path, img_name)
        img = cv.imread(img_item_path)
        if img is None:
            raise OSError(f"cv.imread could not read image: {img_item_path}")
        return img, self.get_label(idx)

    def __len__(self):
        """Return the number of images; each image is one sample."""
        return len(self.img_path_list)
-
# Example: ants images and their text labels in sibling folders.
root_dir = "dataset2/train"
img_dir, label_dir = "ants_image", "ants_label"
ants_dataset = MyData(root_dir, img_dir, label_dir)
-
-
-