• UNet网络模型:数据集制作


    UNet网络模型:数据集制作

    一般语义分割的原始图片和标记图片在以下目录下:

    image-20230924235329548

    SegmentationClass:标记图片。

    JPEGImages:原始图片。

    数据集往往都是很多的图片等信息,对于数据集类来说,一个类里有所有数据的信息,并且可以用下标进行访问,就像访问数组一样。

    pytorch中有Dataset类,用于创建自定义数据集类。我们可以用继承Dataset类来实现数据集类。在一个数据集类中,需要确定数据集所在的位置,也要用__getitem__实现下标访问,用__len__实现数据集大小。

    对于标记图片和原始图片可能大小或者格式不同,需要将这两个图片的格式统一。

    代码参考UNet数据集制作及代码实现_哔哩哔哩_bilibili,根据该UP主的代码,加上我的个人整理和理解。(这个UP主的代码感觉很好,很规范

    目录的树形结构:

    data/ (path) 
     |
     |
     |------- SegmentationClass/
     |               |
     |               |-- 标记图像
     |               |-- ...
     |               |-- 标记图像
     |
     |
     |
     |
     |------- JPEGImages/
                    |
                    |-- 原始图像
                    |-- ...
                    |-- 原始图像
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    原始图像和标记图像处理

    from PIL import Image # 用于读取和处理图像
    
    # 用于加载图像并保存图像大小一致
    # path表示图像文件的路径,size表示需要调整的目标大小,默认为(256, 256)
    def keep_image_size_open(path, size = (256, 256)):
    
        # 通过Image.open(path)方法打开path路径的图像,并将返回的图像对象保存在变量img中。
        img = Image.open(path)
        # 获取输入图像的最大边长
        tmp = max(img.size)
        # 创建一个新的空白图像对象mask,大小为(tmp, tmp),颜色为(0, 0, 0),即黑色。
        mask = Image.new("RGB", (tmp, tmp), (0,0,0))
        # 将原始图像img粘贴到mask图像对象的左上角,使得原始图像位于黑色背景之上。
        mask.paste(img, (0,0))
        # 将图像调整为目标大小,并返回调整后的图像对象mask。
        mask = mask.resize(size)
        return mask
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    数据集制作

    import os # 路径操作
    from torch.utils.data import Dataset # 继承 pytorch中的Dataset
    from utils import * # 引入自己写的: 将原始图片和标记图片格式和规格统一
    from torchvision import transforms # 用于数据预处理
    
    ################################################
    #
    #   transforms.Compose用于将多个变换操作组合在一起。
    # 这里只有一个 ToTensor(),用于将图像数据转为张量格式
    #
    ###############################################
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    
    ##############################################
    #
    # 通过继承Dataset类,来实现自定义数据集类
    #
    ##############################################
    class MyDataset(Dataset):
    
        # path表示数据集的路径
        def __init__(self, path):
            self.path = path
            # path路径下的目录
            self.name = os.listdir(os.path.join(path, "SegmentationClass"))
        
        # 返回数据集长度
        def __len__(self):
            return len(self.name)
        
        # 实现下标访问
        def __getitem__(self, index):
            segment_name = self.name[index] # xx.png 得到的是:path/SegmentationClass/图像名字
            segment_path = os.path.join(self.path, "SegmentationClass", segment_name) # 得到标记图像路径
            # 得到原始图像路径。segment_name.replace(xxx,kkk) 表示如果标记图像和原始图像格式不同,进行后缀替代,进而得到正确的原始图像路径
            image_path = os.path.join(self.path, "JPEGImages", segment_name.replace("png", "jpg")) 
    
            # 将图像进行统一规格化
            segment_image = keep_image_size_open(segment_path)
            image = keep_image_size_open(image_path)
    
            # 返回经过transform处理后的图像张量
            return transform(image), transform(segment_image)
    
    # 测试是否正确
    if __name__ == "__main__":
        data = MyDataset(r"E:\Undergraduate\School\Scientific_research\ML\Machine_Learning\lab\图形分割\Unet模型总\data")
        print(data[0][0].shape)
        print(data[0][1].shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51

    UNet网络制作

    UNet数据集制作及代码实现_哔哩哔哩_bilibili

  • 相关阅读:
    自己实现扫描全盘文件的函数。
    全渠道客服体验:Rocket.Chat 的无缝互动 | 开源日报 No.41
    芯片安全(security)
    hive-学习微博日志分析
    入坑计算机视觉必备的图像基础
    vue前端中v-model与ref的区别
    使用Velodyne传感器生成的点云进行快速且稳健的聚类处理:一个C++实践指南
    Windows下DataGrip连接Hive
    【Python绘画】画笑脸简笔画
    元宇宙006 | 你不了解的“养成系”人工智能就此诞生
  • 原文地址:https://blog.csdn.net/qq_63432403/article/details/133256863