• A detailed guide to torchvision.datasets.ImageFolder


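    1. How the dataset is organized

    ImageFolder is a generic dataset class. It expects the training, validation, or test images to be organized in the following layout, with one subfolder per class: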

    root/1/xxx.png
    root/1/xxy.png
    root/1/xxz.png
    . . . 
    root/2/12.png
    . . .
    root/3/123.png
    . . .
    root/4/356.png
    . . .
    

    For the root above, assuming the data folder sits in the same directory as the .py file, root usually looks like ./data/train or ./data/valid.
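
    To make the layout concrete, here is a minimal sketch that uses only the standard library. It builds a throwaway root with empty placeholder files and derives class indices by sorting the subfolder names, which is how ImageFolder assigns them. The helper names find_classes_sketch and make_demo_root, and the folder names, are assumptions made for illustration.

```python
import os
import tempfile


def find_classes_sketch(root):
    """Map each class subfolder under root to an index, in sorted name order."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


def make_demo_root(class_names, files_per_class=2):
    """Create root/<class>/<n>.png placeholders (empty files, not real images)."""
    root = tempfile.mkdtemp()
    for name in class_names:
        os.makedirs(os.path.join(root, name))
        for n in range(files_per_class):
            open(os.path.join(root, name, f"{n}.png"), "wb").close()
    return root


demo_root = make_demo_root(["1", "2", "10"])
classes, class_to_idx = find_classes_sketch(demo_root)
print(classes)       # names are sorted as strings, so "10" comes before "2"
print(class_to_idx)
```

    Because the sort is lexicographic, numeric folder names such as 1, 2, 10 do not receive indices in numeric order. Zero-padding the names (01, 02, 10) is one way to avoid the surprise.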

    2. ImageFolder parameters in detail

    dataset=torchvision.datasets.ImageFolder(
    root, transform=None,
    target_transform=None,
    loader=datasets.folder.default_loader,
    is_valid_file=None)

    Parameters:

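    • root: the root directory of the images, i.e. the directory that directly contains the class folders.
    • transform: a preprocessing function applied to each image; it takes the loaded image as input and returns the transformed image.
    • target_transform: a preprocessing function applied to the class label; it takes the target as input and returns the transformed target. If it is not passed, the target is returned unchanged, as the sequential class index 0, 1, 2, …
    • loader: the function used to load an image from its path; the default loader is usually sufficient.
    • is_valid_file: a function that takes the path of a file and returns whether it is a valid file (useful for skipping corrupt files).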
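
    As a rough illustration of the two callable parameters, the sketch below defines one candidate for is_valid_file and one for target_transform. The function names, the extension list, and the "non-empty file" rule are assumptions chosen for this example; a stricter check could try to open each file instead.

```python
import os

IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")  # assumed set for this sketch


def is_nonempty_image(path):
    """is_valid_file candidate: keep image-like files that are not zero bytes."""
    return path.lower().endswith(IMAGE_EXTENSIONS) and os.path.getsize(path) > 0


def one_based(target):
    """target_transform candidate: shift the 0-based class index to 1-based."""
    return target + 1


def build_dataset(root):
    """Pass both callables to ImageFolder (requires torchvision)."""
    # Imported here so the two helpers above stay usable without torchvision.
    from torchvision.datasets import ImageFolder

    return ImageFolder(root,
                       target_transform=one_based,
                       is_valid_file=is_nonempty_image)
```

    With these in place, build_dataset('./data/train') would skip zero-byte files and return labels starting at 1 instead of 0.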

    The returned dataset has the following three attributes:

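    • self.classes: a list of the class names
    • self.class_to_idx: a dict mapping each class name to its index; these indices are the targets returned when no target_transform is applied
    • self.imgs: a list of (image path, class index) tuples (the same list is also available as self.samples)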

    3. Example program

    from torchvision.datasets import ImageFolder
    from torchvision import transforms
     
    # Add transforms
    normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
    transform=transforms.Compose([
        transforms.RandomCrop(180),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), # convert the image to a Tensor scaled to [0, 1]
        normalize
    ])
     
    dataset=ImageFolder('./data/train',transform=transform)
    

    The resulting dataset behaves like a sequence of (img_data, class_id) pairs: [(img_data, class_id), (img_data, class_id), …]. Let's print the first element:

    print(dataset[0])
    
    Output:
    (tensor([[[-0.5137, -0.4667, -0.4902,  ..., -0.0980, -0.0980, -0.0902],
             [-0.5922, -0.5529, -0.5059,  ..., -0.0902, -0.0980, -0.0667],
             [-0.5373, -0.5294, -0.4824,  ..., -0.0588, -0.0824, -0.0196],
             ...,
             [-0.3098, -0.3882, -0.3725,  ..., -0.4353, -0.4510, -0.4196],
             [-0.2863, -0.3647, -0.3725,  ..., -0.4431, -0.4118, -0.4196],
             [-0.3412, -0.3569, -0.3882,  ..., -0.4667, -0.4588, -0.4196]],
            [[-0.6157, -0.5686, -0.5922,  ..., -0.2863, -0.2784, -0.2706],
             [-0.6941, -0.6549, -0.6078,  ..., -0.2784, -0.2784, -0.2471],
             [-0.6392, -0.6314, -0.5843,  ..., -0.2471, -0.2706, -0.2078],
             ...,
             [-0.4431, -0.5059, -0.5059,  ..., -0.5608, -0.5765, -0.5451],
             [-0.4196, -0.4824, -0.5059,  ..., -0.5686, -0.5373, -0.5451],
             [-0.4745, -0.4902, -0.5294,  ..., -0.5922, -0.5843, -0.5451]],
            [[-0.6627, -0.6157, -0.6549,  ..., -0.5059, -0.5216, -0.5137],
             [-0.7412, -0.7020, -0.6706,  ..., -0.4980, -0.5216, -0.4902],
             [-0.6863, -0.6784, -0.6471,  ..., -0.4667, -0.4902, -0.4275],
             ...,
             [-0.6000, -0.6549, -0.6627,  ..., -0.6784, -0.6941, -0.6627],
             [-0.5765, -0.6314, -0.6471,  ..., -0.6863, -0.6549, -0.6627],
             [-0.6314, -0.6314, -0.6392,  ..., -0.7098, -0.7020, -0.6627]]]), 0)
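
    The values above fall in [-1, 1] rather than [0, 1] because Normalize is applied after ToTensor. A minimal sketch of the per-pixel arithmetic in plain Python follows; the function name normalize_pixel is an assumption for illustration. ToTensor divides an 8-bit value by 255, then Normalize computes (x - mean) / std.

```python
def normalize_pixel(value_8bit, mean=0.5, std=0.5):
    """Reproduce ToTensor followed by Normalize for one 8-bit channel value."""
    scaled = value_8bit / 255.0      # ToTensor: [0, 255] -> [0.0, 1.0]
    return (scaled - mean) / std     # Normalize: (x - mean) / std


print(round(normalize_pixel(0), 4))    # -1.0
print(round(normalize_pixel(255), 4))  # 1.0
print(round(normalize_pixel(62), 4))   # -0.5137
```

    The last line shows that the first entry of the printed tensor, -0.5137, is consistent with an 8-bit channel value of 62.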
    

    Next, let's look at the three attributes of the dataset:

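    print(dataset.classes)  # class names, taken from the subfolder names
    print(dataset.class_to_idx) # indices 0, 1, ... assigned to the classes in sorted name order
    print(dataset.imgs) # (path, class index) for every image found in the subfolders
    '''
    Output: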
    ['cat', 'dog']
    {'cat': 0, 'dog': 1}
    [('./data/train\\cat\\1.jpg', 0), 
     ('./data/train\\cat\\2.jpg', 0), 
     ('./data/train\\dog\\1.jpg', 1), 
     ('./data/train\\dog\\2.jpg', 1)]
    '''
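
    These attributes are plain Python containers, so they are easy to post-process. The sketch below, using the values printed above as stand-in data, inverts class_to_idx to turn an index back into a class name and counts the images per class. The variable names idx_to_class and counts are assumptions for illustration.

```python
from collections import Counter

# Values copied from the output above; with a real dataset, use
# dataset.class_to_idx and dataset.imgs instead.
class_to_idx = {'cat': 0, 'dog': 1}
imgs = [('./data/train\\cat\\1.jpg', 0),
        ('./data/train\\cat\\2.jpg', 0),
        ('./data/train\\dog\\1.jpg', 1),
        ('./data/train\\dog\\2.jpg', 1)]

# Invert the mapping so a predicted index can be turned back into a name.
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

# Count how many images each class contributes.
counts = Counter(idx_to_class[idx] for _, idx in imgs)

print(idx_to_class)   # {0: 'cat', 1: 'dog'}
print(dict(counts))   # {'cat': 2, 'dog': 2}
```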
    

    Writing a custom datasets.ImageFolder subclass

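    import numpy as np
    from PIL import Image
    from torchvision import datasets, transforms
    from torchvision.transforms import functional as F

    # NOTE: low_res_augmentation() is called in augment() below but is not
    # defined in this post; it has to be supplied separately.

    class CustomImageFolderDataset(datasets.ImageFolder):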
        def __init__(self,
                     root,
                     transform=None,
                     target_transform=None,
                     loader=datasets.folder.default_loader,
                     is_valid_file=None,
                     low_res_augmentation_prob=0.0,
                     crop_augmentation_prob=0.0,
                     photometric_augmentation_prob=0.0,
                     ):
            super(CustomImageFolderDataset, self).__init__(root,
                                                           transform=transform,
                                                           target_transform=target_transform,
                                                           loader=loader,
                                                           is_valid_file=is_valid_file)
            self.root = root
            self.low_res_augmentation_prob = low_res_augmentation_prob
            self.crop_augmentation_prob = crop_augmentation_prob
            self.photometric_augmentation_prob = photometric_augmentation_prob
            self.random_resized_crop = transforms.RandomResizedCrop(size=(112, 112),
                                                                    scale=(0.2, 1.0),
                                                                    ratio=(0.75, 1.3333333333333333))
            self.photometric = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0)
    
            self.tot_rot_try = 0
            self.rot_success = 0
    
        def __getitem__(self, index):
            """
            Args:
                index (int): Index
    
            Returns:
                tuple: (sample, target) where target is class_index of the target class.
            """
            path, target = self.samples[index]
            sample = self.loader(path)
    
            if 'WebFace' in self.root:
                # WebFace images are loaded as RGB; reverse the channel axis to get BGR
                sample = Image.fromarray(np.asarray(sample)[:,:,::-1])
    
            sample, _ = self.augment(sample)
            if self.transform is not None:
                sample = self.transform(sample)
            if self.target_transform is not None:
                target = self.target_transform(target)
    
            return sample, target
    
        def augment(self, sample):
    
            # crop with zero padding augmentation
            if np.random.random() < self.crop_augmentation_prob:
                # RandomResizedCrop augmentation
                new = np.zeros_like(np.array(sample))
                orig_W, orig_H = sample.size  # PIL reports (width, height); avoids the private F._get_image_size helper
                i, j, h, w = self.random_resized_crop.get_params(sample,
                                                                self.random_resized_crop.scale,
                                                                self.random_resized_crop.ratio)
                cropped = F.crop(sample, i, j, h, w)
                new[i:i+h,j:j+w, :] = np.array(cropped)
                sample = Image.fromarray(new.astype(np.uint8))
                crop_ratio = min(h, w) / max(orig_H, orig_W)
            else:
                crop_ratio = 1.0
    
            # low resolution augmentation
            if np.random.random() < self.low_res_augmentation_prob:
                # low res augmentation
                img_np, resize_ratio = low_res_augmentation(np.array(sample))
                sample = Image.fromarray(img_np.astype(np.uint8))
            else:
                resize_ratio = 1
    
            # photometric augmentation
            if np.random.random() < self.photometric_augmentation_prob:
                fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
                self.photometric.get_params(self.photometric.brightness, self.photometric.contrast,
                                                      self.photometric.saturation, self.photometric.hue)
                for fn_id in fn_idx:
                    if fn_id == 0 and brightness_factor is not None:
                        sample = F.adjust_brightness(sample, brightness_factor)
                    elif fn_id == 1 and contrast_factor is not None:
                        sample = F.adjust_contrast(sample, contrast_factor)
                    elif fn_id == 2 and saturation_factor is not None:
                        sample = F.adjust_saturation(sample, saturation_factor)
    
            information_score = resize_ratio * crop_ratio
            return sample, information_score
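
    The class above calls low_res_augmentation, which this post does not define. The block below is a hypothetical stand-in, not the original implementation: it assumes the function takes an H x W x C array, degrades it by downsampling and then upsampling with nearest-neighbour indexing in NumPy, and returns the image together with the side ratio, matching the (img_np, resize_ratio) pair expected at the call site. The min_ratio and rng parameters are additions made for this sketch.

```python
import numpy as np


def low_res_augmentation(img, min_ratio=0.2, rng=None):
    """Hypothetical stand-in: downsample then upsample an H x W x C array.

    Returns the degraded image (same shape as the input) and the side ratio
    that was used, so the caller can fold it into information_score.
    """
    rng = np.random.default_rng() if rng is None else rng
    height, width = img.shape[:2]
    side_ratio = float(rng.uniform(min_ratio, 1.0))
    small_h = max(1, int(height * side_ratio))
    small_w = max(1, int(width * side_ratio))

    # Nearest-neighbour downsample: pick evenly spaced source rows and columns.
    rows = np.arange(small_h) * height // small_h
    cols = np.arange(small_w) * width // small_w
    small = img[rows][:, cols]

    # Nearest-neighbour upsample back to the original size.
    up_rows = np.arange(height) * small_h // height
    up_cols = np.arange(width) * small_w // width
    restored = small[up_rows][:, up_cols]
    return restored, side_ratio
```

    A real pipeline would more likely resize with an image library and a randomly chosen interpolation mode; nearest-neighbour indexing is used here only to keep the sketch dependency-light.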
    
    
  • Original article: https://blog.csdn.net/weixin_54546190/article/details/126123675