• Pytorch技法:继承Subset类完成自定义数据集拆分


    我们在《torch.utils.data.DataLoader与迭代器转换》中介绍了如何使用Pytorch内置的数据集进行论文实现,如torchvision.datasets。下面是加载内置训练数据集的常见操作:

    from torchvision.datasets import FashionMNIST
    from torchvision.transforms import Compose, ToTensor, Normalize
    RAW_DATA_PATH = './rawdata'
    transform = Compose(
            [ToTensor(),
             Normalize((0.1307,), (0.3081,))
             ]
        )
    train_data = FashionMNIST(
            root=RAW_DATA_PATH,
            download=True,
            train=True,
            transform=transform
        )
    

    这里的train_data做为dataset对象,它拥有许多熟悉,我们可以通过以下方法获取样本数据的分类类别集合、样本的特征维度、样本的标签集合等信息。

    classes = train_data.classes
    num_features = train_data.data[0].shape[0]
    train_labels = train_data.targets
    
    print(classes)
    print(num_features)
    print(train_labels)
    

    输出如下:

    ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    28
    tensor([9, 0, 0,  ..., 3, 0, 5])
    

    但是,我们常常会在训练集的基础上拆分出验证集(或者只用部分数据来进行训练)。我们想到的第一个方法是使用torch.utils.data.random_splitdataset进行划分,下面我们假设划分10000个样本做为训练集,其余样本做为验证集:

    from torch.utils.data import random_split
    k = 10000
    train_data, valid_data = random_split(train_data, [k, len(train_data)-k])
    

    注意我们如果打印train_datavalid_data的类型,可以看到显示:

    <class 'torch.utils.data.dataset.Subset'>
    

    已经不再是torchvision.datasets.mnist.FashionMNIST对象,而是一个所谓的Subset对象!此时Subset对象虽然仍然还存有data属性,但是内置的targetclasses属性已经不复存在,比如如果我们强行访问valid_datatarget属性:

    valid_target = valid_data.target
    

    就会报如下错误:

    'Subset' object has no attribute 'target'
    

    但如果我们在后续的代码中常常会将拆分后的数据集也默认为dataset对象,那么该如何做到代码的一致性呢?

    这里有一个trick,那就是以继承SubSet类的方式的方式定义一个新的CustomSubSet类,使新类在保持SubSet类的基本属性的基础上,拥有和原本数据集类相似的属性,如targetsclasses等:

    from torch.utils.data import Subset
    class CustomSubset(Subset):
        '''A custom subset class'''
        def __init__(self, dataset, indices):
            super().__init__(dataset, indices)
            self.targets = dataset.targets # 保留targets属性
            self.classes = dataset.classes # 保留classes属性
    
        def __getitem__(self, idx): #同时支持索引访问操作
            x, y = self.dataset[self.indices[idx]]      
            return x, y 
    
        def __len__(self): # 同时支持取长度操作
            return len(self.indices)
    

    然后就引出了第二种划分方法,即通过初始化CustomSubset对象的方式直接对数据集进行划分(这里为了简化省略了shuffle的步骤):

    import numpy as np
    from copy import deepcopy
    origin_data = deepcopy(train_data)
    train_data = CustomSubset(origin_data, np.arange(k))
    valid_data = CustomSubset(origin_data, np.arange(k, len(origin_data))-k)
    

    注意,CustomSubset类的初始化方法的第二个参数indices为样本索引,我们可以通过np.arange()的方法来创建。

    然后,我们再访问valid_data对应的classestarges属性:

    print(valid_data.classes)
    print(valid_data.targets)
    

    此时,我们发现可以成功访问这些属性了:

    ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    tensor([9, 0, 0,  ..., 3, 0, 5])
    

    当然,CustomSubset的作用并不只是添加数据集的属性,我们还可以自定义一些数据预处理操作。我们将类的结构修改如下:

    class CustomSubset(Subset):
        '''A custom subset class with customizable data transformation'''
        def __init__(self, dataset, indices, subset_transform=None):
            super().__init__(dataset, indices)
            self.targets = dataset.targets
            self.classes = dataset.classes
            self.subset_transform = subset_transform
    
        def __getitem__(self, idx):
            x, y = self.dataset[self.indices[idx]]
            
            if self.subset_transform:
                x = self.subset_transform(x)
          
            return x, y   
        
        def __len__(self): 
            return len(self.indices)
    

    我们可以在使用样本前设置好数据预处理算子:

    from torchvision import transforms
    valid_data.subset_transform = transforms.Compose(\
        [transforms.RandomRotation((180,180))])
    

    这样,我们再像下列这样用索引访问取出数据集样本时,就会自动调用算子完成预处理操作:

    print(valid_data[0])
    

    打印结果缩略如下:

    
    (tensor([[[-0.4242, -0.4242, -0.4242, ......-0.4242, -0.4242, -0.4242, -0.4242, -0.4242]]]), 9)
    

    引用


    __EOF__

  • 本文作者: 猎户座
  • 本文链接: https://www.cnblogs.com/orion-orion/p/15906086.html
  • 关于博主: 本科CS系蒟蒻,机器学习半吊子,并行计算混子。
  • 版权声明: 欢迎您对我的文章进行转载,但请务必保留原始出处哦(*^▽^*)。
  • 声援博主: 如果您觉得文章对您有帮助,可以点击文章右下角推荐一下。
  • 相关阅读:
    pom文件引用本地对象
    安卓案例:选项菜单
    java毕业设计民宿管理平台mybatis+源码+调试部署+系统+数据库+lw
    安卓自动化之minicap截图
    Postman使用
    Ubuntu源码编译Mysql常见的错误
    linux 学习 day 07 进程信号
    22 年国内最牛的 Java 面试八股文合集(全彩版),不接受反驳
    手机直播提词器哪个软件好?这两款软件值得收藏
    性能分析5部曲:瓶颈分析与问题定位,如何快速解决瓶颈?
  • 原文地址:https://www.cnblogs.com/orion-orion/p/15906086.html