Common Sampling Strategies in Machine Learning, with PyTorch Implementations


Simple sampling strategies

We first introduce three simple sampling strategies:

1. Instance-balanced sampling
2. Class-balanced sampling
3. Square-root sampling

All three can be written in a single form:

$$p_j = \frac{n_j^q}{\sum_{i=1}^{C} n_i^q}$$

where $p_j$ is the probability of sampling a data point from class $j$, $C$ is the number of classes, $n_j$ is the number of samples in class $j$, and $q \in \{1, 0, \tfrac{1}{2}\}$.
Instance-balanced sampling
The most common way to sample data: every training sample is selected with equal probability ($q = 1$). The probability $p_j^{IB}$ that class $j$ is sampled is proportional to its sample count $n_j$, i.e. $p_j^{IB} = \frac{n_j}{\sum_{i=1}^{C} n_i}$.

Class-balanced sampling
Instance-balanced sampling often performs poorly on imbalanced datasets. Class-balanced sampling instead gives every class the same probability of being sampled: $p_j^{CB} = \frac{1}{C}$. Sampling can be seen as a two-stage process: first select a class uniformly from the set of classes, then sample an instance of that class uniformly.
Square-root sampling
The most common variant of the above is square-root sampling, which sets $q = \frac{1}{2}$.
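
As a quick sanity check that the single formula above really covers all three strategies, here is a minimal sketch; the count vector [100, 10, 1] is a made-up long-tailed example:

import numpy as np

def class_probs(cls_counts, q):
    # p_j = n_j^q / sum_i n_i^q
    powered = np.asarray(cls_counts, dtype=float) ** q
    return powered / powered.sum()

counts = [100, 10, 1]            # hypothetical long-tailed class counts
print(class_probs(counts, 1))    # instance-balanced: ~[0.901, 0.090, 0.009]
print(class_probs(counts, 0.5))  # square-root:       ~[0.706, 0.223, 0.071]
print(class_probs(counts, 0))    # class-balanced:    ~[0.333, 0.333, 0.333]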

Since all three strategies only adjust the per-class sampling probability (weight), they can be implemented with PyTorch's WeightedRandomSampler:

import numpy as np
from torch.utils.data.sampler import WeightedRandomSampler

def get_sampler(sampling_type, targets):
    cls_counts = np.bincount(targets)
    if sampling_type == 'instance-balanced':
        cls_weights = cls_counts / np.sum(cls_counts)
    elif sampling_type == 'class-balanced':
        cls_num = len(cls_counts)
        cls_weights = [1. / cls_num] * cls_num
    elif sampling_type == 'square-root':
        sqrt_and_sum = np.sum([num ** 0.5 for num in cls_counts])
        cls_weights = [num ** 0.5 / sqrt_and_sum for num in cls_counts]
    else:
        raise ValueError('sampling_type should be instance-balanced, class-balanced or square-root')
    cls_weights = np.array(cls_weights)
    # cls_weights[j] is the desired probability p_j of drawing class j, but
    # WeightedRandomSampler weights individual samples. Dividing by the class
    # size spreads p_j over the n_j samples of class j, so that class j as a
    # whole is drawn with probability proportional to p_j.
    sample_weights = cls_weights[targets] / cls_counts[targets]
    return WeightedRandomSampler(sample_weights, len(targets), replacement=True)

In WeightedRandomSampler, the first argument gives the weight of each individual sample, the second the number of samples to draw, and the third whether to sample with replacement.

Testing on a simulated long-tailed dataset:

import torch
from torch.utils.data import Dataset, DataLoader

torch.manual_seed(0)
np.random.seed(0)

class LongTailDataset(Dataset):
    def __init__(self, num_classes, max_samples_per_class):
        self.num_classes = num_classes
        self.max_samples_per_class = max_samples_per_class
        # Generate number of samples for each class inversely proportional to class index
        self.samples_per_class = [self.max_samples_per_class // (i + 1) for i in range(self.num_classes)]
        self.total_samples = sum(self.samples_per_class)
        # Generate targets for the dataset
        self.targets = torch.cat([torch.full((samples,), i, dtype=torch.long)
                                  for i, samples in enumerate(self.samples_per_class)])

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        # For simplicity, just return the index as the data
        return idx, self.targets[idx]

# Parameters
num_classes = 25
max_samples_per_class = 1000

# Create dataset
dataset = LongTailDataset(num_classes, max_samples_per_class)

# Create one dataloader per sampling strategy
batch_size = 64
sampler1 = get_sampler('instance-balanced', dataset.targets.numpy())
sampler2 = get_sampler('class-balanced', dataset.targets.numpy())
sampler3 = get_sampler('square-root', dataset.targets.numpy())
dataloader1 = DataLoader(dataset, batch_size=batch_size, sampler=sampler1)
dataloader2 = DataLoader(dataset, batch_size=batch_size, sampler=sampler2)
dataloader3 = DataLoader(dataset, batch_size=batch_size, sampler=sampler3)

# Compare the class distribution of one batch from each strategy
for (_, target1), (_, target2), (_, target3) in zip(dataloader1, dataloader2, dataloader3):
    print('Instance-balanced:')
    cls_idx, cls_counts = np.unique(target1.numpy(), return_counts=True)
    print(f'Class indices: {cls_idx}')
    print(f'Class counts: {cls_counts}')
    print('-' * 20)
    print('Class-balanced:')
    cls_idx, cls_counts = np.unique(target2.numpy(), return_counts=True)
    print(f'Class indices: {cls_idx}')
    print(f'Class counts: {cls_counts}')
    print('-' * 20)
    print('Square-root:')
    cls_idx, cls_counts = np.unique(target3.numpy(), return_counts=True)
    print(f'Class indices: {cls_idx}')
    print(f'Class counts: {cls_counts}')
    break  # just show one batch

    Output:

    Instance-balanced:
    Class indices: [ 0 1 2 3 5 16 22 23]
    Class counts: [43 9 5 2 2 1 1 1]
    --------------------
    Class-balanced:
    Class indices: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 20 21 23]
    Class counts: [21 8 6 4 2 1 2 2 3 3 1 2 1 1 1 1 2 1 1 1]
    --------------------
    Square-root:
    Class indices: [ 0 1 2 3 4 5 6 9 10 21 22 23]
    Class counts: [37 8 3 6 3 1 1 1 1 1 1 1]

Mixed sampling strategies

The earliest form of mixed sampling uses instance-balanced sampling while $0 \le \text{epoch} < t$ and class-balanced sampling while $t \le \text{epoch} \le T$, which requires choosing a suitable hyperparameter $t$. In [1], the authors propose a soft version of mixed sampling, progressively-balanced sampling, in which each class's sampling probability (weight) $p_j$ changes as training progresses:

$$p_j^{PB}(t) = \left(1 - \frac{t}{T}\right) p_j^{IB} + \frac{t}{T}\, p_j^{CB}$$

where $t$ is the current epoch and $T$ is the total number of epochs.
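
Below is a minimal sketch of one way to implement this schedule, reusing the class-probability-to-sample-weight conversion from get_sampler and the simulated dataset from earlier. The epoch budget T = 90 and the rebuild-the-dataloader-each-epoch loop are illustrative choices, not taken from [1]:

import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

def get_progressive_sampler(targets, t, T):
    # Interpolate between instance-balanced and class-balanced
    # class probabilities according to the current epoch t.
    cls_counts = np.bincount(targets)
    p_ib = cls_counts / cls_counts.sum()
    p_cb = np.full(len(cls_counts), 1. / len(cls_counts))
    p_pb = (1 - t / T) * p_ib + (t / T) * p_cb
    # Same conversion as in get_sampler: class probability -> per-sample weight.
    sample_weights = p_pb[targets] / cls_counts[targets]
    return WeightedRandomSampler(sample_weights, len(targets), replacement=True)

T = 90  # hypothetical total number of epochs
targets = dataset.targets.numpy()
for t in range(T):
    # Rebuild the sampler each epoch so the weights follow the schedule.
    dataloader = DataLoader(dataset, batch_size=64,
                            sampler=get_progressive_sampler(targets, t, T))
    # ... train for one epoch on dataloader ...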

Sampling strategies for imbalanced datasets

On imbalanced datasets, and long-tailed ones in particular, a common way to favor the tail classes is to set each class's sampling probability (weight) to the inverse of its sample count, i.e. $p_j = \frac{1}{n_j}$:

...
elif sampling_type == 'inverse':
    cls_weights = 1. / cls_counts
...

[3] introduces the concept of the effective number of samples: the denominator is no longer the raw sample count but a quantity computed from it. We give the result directly here; for the proof, see the original paper. The effective number is computed as:

$$E_n = \frac{1 - \beta^n}{1 - \beta}, \quad \text{where } \beta = \frac{N - 1}{N}$$

Here $N$ is the total number of samples in the dataset, and $n$ is the sample count of the class in question.

The corresponding code:

...
elif sampling_type == 'effective':
    beta = (len(targets) - 1) / len(targets)
    cls_weights = (1.0 - beta) / (1.0 - np.power(beta, cls_counts))
...

Output:

    Effective number:
    Class indices: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 18 20 21 22 23 24]
    Class counts: [2 1 2 3 1 1 4 2 3 4 4 2 3 5 2 4 1 3 1 4 5 6 1]

On the same simulated long-tailed dataset as above, the sampled batch is noticeably more balanced.
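
To get a feel for how the effective number differs from the raw count, the sketch below computes $E_n$ for the simulated class counts used above, with the same $\beta$ as in the code:

import numpy as np

# Class counts of the simulated dataset: 1000 // (i + 1) for 25 classes.
cls_counts = np.array([1000 // (i + 1) for i in range(25)])
N = cls_counts.sum()
beta = (N - 1) / N

# Effective number E_n = (1 - beta^n) / (1 - beta) for each class.
effective_num = (1.0 - np.power(beta, cls_counts)) / (1.0 - beta)
print(cls_counts[:3], cls_counts[-3:])        # raw counts: head vs. tail classes
print(effective_num[:3], effective_num[-3:])  # E_n is well below n for head classes
                                              # and almost equal to n for tail classes

For tail classes $E_n \approx n$, so the weighting behaves like inverse sampling there, while head classes keep somewhat more weight than $\frac{1}{n_j}$ would give them.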

References

1. Kang, Bingyi, et al. "Decoupling Representation and Classifier for Long-Tailed Recognition." International Conference on Learning Representations. 2020.
2. PyTorch documentation: torch.utils.data.WeightedRandomSampler.
3. Cui, Yin, et al. "Class-Balanced Loss Based on Effective Number of Samples." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.