• dataset.py篇


    dataset.py

    目录:

    • 前言
    • 观察数据
    • 书写代码
    • 函数解释

    前言

    在步骤中需要写自己的dataset类,并将label和image一一对应后返回。

    观察数据

    在书写dataset前最重要的就是要观察数据集,对数据集进行分析,比如了解图片大小,通道数目,他的ndarray的dtype类型等等。甚至可以自己书写一个脚本,对数据本身进行分析。

    如下以果蝇电镜图为例,我们观察数据集,知道Images中存放30张原图,Labels中存放30张已经分割好的图片,每张图片以.png方式进行存储;Images和Labels中的图片通过文件名进行一一对应,如Labels中0.png对应Images中0.png,进行共有30组对应数据;根据model.py文件了解需要传入的Tensor类型,思考如何将现有数据转为需要的Tensor返回。

    请添加图片描述

    知道以上信息后,我们书写dataset.py将image和label一一对应返回。

    书写代码

    在本步骤中,我们需要告诉程序如何读入你的数据,并且做一些预处理。我们的DIYDataset类继承自Dataset类,并重写__init____len____getitem__三个魔法方法。网络训练权重为float32,所以传入数据也一般要为float32。

    1. __init__方法向外索取3个输入数据:
    • 读取数据路径
    • 是训练集还是验证集(因为你训练集和验证集往往是在两个不同的文件夹)
    • 你使用的预处理方法(以transform为主,transform也可以根据验证集还是测试集调用不同的trans)
    1. __getitem__需要返回image和对应label的Tensor
    2. __len__用来返回集合中图片个数

    以S-BIAD25为例,代码如下:

    '''
    该类用于返回dataset
    input和label都是[512, 512]的图片,需要将其转换为[3, 512, 512]才能transforms
    '''
    # --- add path and DIY package
    import sys, os
    root_path = os.path.dirname(os.path.dirname(__file__))
    project_path = os.path.dirname(__file__)
    sys.path.append(root_path)
    sys.path.append(project_path)
    from _utils import tensor_info  # 耦合了
    # ---
    from torch.utils.data.dataset import Dataset
    from PIL import Image
    from utils.utils import cvtColor,resize_image
    from torchvision import transforms
    
    
    class CellDataset(Dataset):
        """返回细胞分隔的dataset"""
    
        def __init__(self, path:str, transforms:object=None) -> None:
            super().__init__()
            self.path       = path
            self.labels     = os.listdir(os.path.join(path,'Labels'))
            self.transforms = transforms
            
    
        def __len__(self):
            return len(self.labels)
    
        def __getitem__(self, index):
            yield_name  = self.labels[index]
            # 返回label图片
            label_path  = os.path.join(self.path, 'Labels', yield_name)
            label_image = Image.open(label_path)
            label_image = cvtColor(label_image)                     # 将单通道灰度图片, 四通道png转换为三通道RGB图片。因为此处torchvision.transforms只接受三通道图片,pytorch中模型一般只能训练dtype=float32的Tensor。
            label_image = resize_image(label_image,(512, 512))[0]   # 裁剪图片大小为512 * 512
            # 返回训练图片
            image_path  = os.path.join(self.path, 'Images', yield_name)
            image       = Image.open(image_path)
            image       = cvtColor(image)
            image       = resize_image(image,(512, 512))[0]
            # 返回单张input, label
            return self.transforms(image), self.transforms(label_image)
    
    
    class MouseDataset(Dataset):
        """返回细胞分隔的dataset"""
    
        def __init__(self, path) -> None:
            super().__init__()
            self.path       = path
            self.f_actin    = os.path.join(path, "F-actin") 
            self.labels_factin = os.listdir(self.f_actin)
            # self.labels     = os.listdir(os.path.join(path,'Labels'))
            self.transforms = transforms.Compose([transforms.ToTensor()])
    
        def __len__(self):
            return len(self.labels_factin)
    
        def __getitem__(self, index):
            yield_name  = self.labels_factin[index][-19:]
            # 返回label图片
            label_path  = os.path.join(self.path, 'F-actin', "img_568"+yield_name)
            label_image = Image.open(label_path)
            label_image = cvtColor(label_image)                     # 将单通道灰度图片, 四通道png转换为三通道RGB图片。因为此处torchvision.transforms只接受三通道图片,pytorch中模型一般只能训练dtype=float32的Tensor。
            label_image = resize_image(label_image,(512, 512))[0]   # 裁剪图片大小为512 * 512
            # 返回训练图片
            image_path  = os.path.join(self.path, 'retardance', "img_Retardance"+yield_name)
            image       = Image.open(image_path)
            image       = cvtColor(image)
            image       = resize_image(image,(512, 512))[0]
            # 返回单张input, label
            return self.transforms(image), self.transforms(label_image)
    
    
    if __name__ == "__main__":
        """test"""
        data_path = "/home/yingmuzhi/_data/S-BIAD25"
        cell_dataset = MouseDataset(data_path)
    
        input, label = cell_dataset[0]  # 同 input, label = cell_dataset.__getitem__(0)
        tensor_info.TensorInfo(input).show_info()
        tensor_info.TensorInfo(label).show_info()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85

    测试结果

    我们在if __name__ == "__main__":中进行测试,结果如下:

    请添加图片描述

    函数解释

    对一些常用函数进行解释,可以当做字典查看

    torchvision.transforms.ToTensor

    只接受PIL.Image类型的对象或者numpy.ndarray类型的对象,将上面两个对象转为torch.Tensor对象。函数定义如下:

    在这里插入图片描述

    参考链接https://pytorch.org/vision/stable/generated/torchvision.transforms.ToTensor.html?highlight=totensor#torchvision.transforms.ToTensor

    __init__

    在该方法中需要根据传入的path参数和train参数找到你的测试集或者训练集的物理地址,并将集合中的images和labels的物理地址存储在list中,以供后面方法使用。案例如下:

    def __init__(self, root: str, train: bool, transforms: object=None):
        super(DriveDataset, self).__init__()
        self.flag = "training" if train else "test" # 由 train: bool 的布尔值来判断是取train还是test
        data_root = os.path.join(root, "DRIVE", self.flag)
        assert os.path.exists(data_root), f"path '{data_root}' does not exists."
        self.transforms = transforms
        img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]
        self.img_list = [os.path.join(data_root, "images", i) for i in img_names]               # 返回所有images地址的list
        self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")  # 返回所有manuals地址的list
                       for i in img_names]
        # check manual files
        for i in self.manual:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")
        self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")    # 返回所有mask地址的list
                         for i in img_names]
        # check mask files
        for i in self.roi_mask:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    __len__

    返回测试集或者验证集中Image数量,因为Image和Label往往是一一对应的,所以返回哪个其实都一样。三维数据可能存在多对一情况。

    案例如下:

    def __len__(self):
    	return len(self.img_list)
    
    • 1
    • 2

    __getitem__

    该方法根据image和label的物理地址,用PIL打开图片,再用transforms处理Image返回Tensor,最后返回处理过的Tensor类型元组(image, label)。

    在该方法中,你可以使用PIL处理图像(mode),也可以将PIL转为numpy使用numpy处理图片(元素类型dtype),也可以使用Transforms处理图片(Normalization)等。

    __getitem__案例见下:

    def __getitem__(self, idx):
        """将Image转为RGB, 将label转为L"""
        img = Image.open(self.img_list[idx]).convert('RGB')
        manual = Image.open(self.manual[idx])
        manual = manual.convert('L')
        manual = np.array(manual) / 255
        roi_mask = Image.open(self.roi_mask[idx]).convert('L')
        roi_mask = 255 - np.array(roi_mask)
        # 将manual图片和Imae进行处理
        mask = np.clip(manual + roi_mask, a_min=0, a_max=255)
        # 这里转回PIL的原因是,transforms中是对PIL数据进行处理
        mask = Image.fromarray(mask)
        if self.transforms is not None:
            img, mask = self.transforms(img, mask)
        return img, mask
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    使用PIL处理

    参看链接https://blog.csdn.net/qq_43369406/article/details/127781871

    使用Numpy处理

    参看链接https://blog.csdn.net/qq_43369406/article/details/127781871

    使用transforms处理

    参看链接[coming soon]

    这段代码我们常写在train.py中。在进行transforms累加时候,我们常将所需要的transforms全部添加至一个list中,再将这个list给transforms.Compose掉,注意一定要添加transforms.ToTensor方法。transform的更多内容可以参考笔者的transforms博客。

    在调用transforms时完整逻辑如下:

    # 获取dataset
    train_dataset = DriveDataset(args.data_path,
                                     train=True,
                                     transforms=get_transform(train=True, mean=mean, std=std))
    # 获取tranforms
    def get_transform(train, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    	"""对要获取的transforms进行判断,看是测试集的dataset还是验证集"""
        base_size = 565
        crop_size = 480
    
        if train:
            return SegmentationPresetTrain(base_size, crop_size, mean=mean, std=std)
        else:
            return SegmentationPresetEval(mean=mean, std=std)
    # 测试集transforms
    class SegmentationPresetTrain:
        def __init__(self, base_size, crop_size, hflip_prob=0.5, vflip_prob=0.5,
                     mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
            min_size = int(0.5 * base_size)
            max_size = int(1.2 * base_size)
    
            trans = [T.RandomResize(min_size, max_size)]
            if hflip_prob > 0:
                trans.append(T.RandomHorizontalFlip(hflip_prob))
            if vflip_prob > 0:
                trans.append(T.RandomVerticalFlip(vflip_prob))
            trans.extend([
                T.RandomCrop(crop_size),
                T.ToTensor(),
                T.Normalize(mean=mean, std=std),
            ])
            self.transforms = T.Compose(trans)
    
        def __call__(self, img, target):
            return self.transforms(img, target)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35

    增加了魔法方法__call__只是为了对image/input,label/target一块进行transforms。

    torchvision.transforms.Compose()和torchvision.transforms.[functions]

    Compose()类可以将多个transforms对象合在一起给数据进行预处理,常以多个transforms对象的list形式传入Compose()中,如transforms.Compose([])。transforms.[functions]则可以对多个输入数据进行变换,Compose()函数原型如下:

    # 原型
    torchvision.transforms.Compose(transforms)
    # example
    data_transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]),
        "val": transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    torchvision.datasets.ImageFolder()类

    ImageFolder()用于加载数据集,它说到底还是继承自torch.utils.data.Dataset,后续可以作为Dataset对象直接传入torch.utils.data.DataLoader中。对于ImageFolder(),root是要加载数据集的路径,transforms是对数据进行预处理的方式,函数原型和example如下:

    # 原型
    torchvision.datasets.ImageFolder(root: str, transform: Optional[Callable] = None, 
    									target_transform: Optional[Callable] = None, loader: Callable[[str], Any] = <function default_loader>, is_valid_file: Optional[Callable[[str], bool]] = None)
    # example
    train_dataset = datasets.ImageFolder(root=project_path + "/flower_data/train", transform=data_transform["train"])
    
    • 1
    • 2
    • 3
    • 4
    • 5

    python 列表生成式和生成器

    python 常见列表生成式 和 生成器,其中列表生成式以[]圈起,生成器以()圈起,常见列表生成器如下:

    在这里插入图片描述

    而生成器则是把中括号改成小括号。

    # 列表生成式
    generate_list = [i for i in range(10)  if i < 5]
    # 生成器
    cla_dict = dict((val, key) for key, val in flower_list.items())	# 解包取出键值对,生成dict字典,赋值给cla_dict
    
    • 1
    • 2
    • 3
    • 4
  • 相关阅读:
    证书格式说明
    AWS IAM User assume IAM Role的示例代码
    pandas使用str函数和contains函数删除dataframe中单个指定字符串数据列包含特定字符串列表中的其中任何一个字符串的数据行
    Spring详解及具体使用
    Autojs 利用OpenCV识别棋子之天天象棋你马没了
    企业数字化建设有哪些路线可以选择?
    Vue3 学习-组件通讯(二)
    SpringBoot 接口访问频率限制
    历时三个月,史上最详细的Spring注解驱动开发系列教程终于出炉了,给你全新震撼
    【双目深度估计】——stereo net 分层细化实时网络
  • 原文地址:https://blog.csdn.net/qq_43369406/article/details/127932597