简单的采样策略
首先介绍三种简单采样策略:
- Instance-balanced sampling, 实例平衡采样。
- Class-balanced sampling, 类平衡采样。
- Square-root sampling, 平方根采样。
它们可抽象为:
Instance-balanced sampling
最常见的数据采样方式,其中每个训练样本被选择的概率相等(
Class-balanced sampling
实例平衡采样在不平衡的数据集中往往表现不佳,类平衡采样让所有的类有相同的被采样概率:
Square-root sampling
平方根采样最常见的变体,
由于这三种采样策略都是调整类别的采样概率(权重),因此可用PyTorch提供的WeightedRandomSampler实现:
import numpy as np from torch.utils.data.sampler import WeightedRandomSampler def get_sampler(sampling_type, targets): cls_counts = np.bincount(targets) if sampling_type == 'instance-balanced': cls_weights = cls_counts / np.sum(cls_counts) elif sampling_type == 'class-balanced': cls_num = len(cls_counts) cls_weights = [1. / cls_num] * cls_num elif sampling_type == 'square-root': sqrt_and_sum = np.sum([num**0.5 for num in cls_counts]) cls_weights = [num**0.5 / sqrt_and_sum for num in cls_counts] else: raise ValueError('sampling_type should be instance-balanced, class-balanced or square-root') cls_weights = np.array(cls_weights) return WeightedRandomSampler(cls_weights[targets], len(targets), replacement=True)
WeightedRandomSampler,第一个参数表示每个样本的权重,第二个参数表示采样的样本数,第三个参数表示是否有放回采样。
在模拟的长尾数据集测试下:
import torch from torch.utils.data import Dataset, DataLoader torch.manual_seed(0) np.random.seed(0) class LongTailDataset(Dataset): def __init__(self, num_classes, max_samples_per_class): self.num_classes = num_classes self.max_samples_per_class = max_samples_per_class # Generate number of samples for each class inversely proportional to class index self.samples_per_class = [self.max_samples_per_class // (i + 1) for i in range(self.num_classes)] self.total_samples = sum(self.samples_per_class) # Generate targets for the dataset self.targets = torch.cat([torch.full((samples,), i, dtype=torch.long) for i, samples in enumerate(self.samples_per_class)]) def __len__(self): return self.total_samples def __getitem__(self, idx): # For simplicity, just return the index as the data return idx, self.targets[idx] # Parameters num_classes = 25 max_samples_per_class = 1000 # Create dataset dataset = LongTailDataset(num_classes, max_samples_per_class) # Create dataloader batch_size = 64 sampler1 = get_sampler('instance-balanced', dataset.targets.numpy()) sampler2 = get_sampler('class-balanced', dataset.targets.numpy()) sampler3 = get_sampler('square-root', dataset.targets.numpy()) dataloader1 = DataLoader(dataset, batch_size=64, sampler=sampler1) dataloader2 = DataLoader(dataset, batch_size=64, sampler=sampler2) dataloader3 = DataLoader(dataset, batch_size=64, sampler=sampler3) for (_, target1), (_, target2), (_, target3) in zip(dataloader1, dataloader2, dataloader3): print('Instance-balanced:') cls_idx, cls_counts = np.unique(target1.numpy(), return_counts=True) print(f'Class indices: {cls_idx}') print(f'Class counts: {cls_counts}') print('-'*20) print('Class-balanced:') cls_idx, cls_counts = np.unique(target2.numpy(), return_counts=True) print(f'Class indices: {cls_idx}') print(f'Class counts: {cls_counts}') print('-'*20) print('Square-root:') cls_idx, cls_counts = np.unique(target3.numpy(), return_counts=True) print(f'Class indices: {cls_idx}') print(f'Class counts: {cls_counts}') break # just show one batch
Output:
Instance-balanced: Class indices: [ 0 1 2 3 5 16 22 23] Class counts: [43 9 5 2 2 1 1 1] -------------------- Class-balanced: Class indices: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 20 21 23] Class counts: [21 8 6 4 2 1 2 2 3 3 1 2 1 1 1 1 2 1 1 1] -------------------- Square-root: Class indices: [ 0 1 2 3 4 5 6 9 10 21 22 23] Class counts: [37 8 3 6 3 1 1 1 1 1 1 1]
混合采样策略
最早的混合采样是在
t表示当前epoch,T表示总epoch数。
不平衡数据集下的采样策略
不平衡的数据集,特别是长尾数据集,为了照顾尾部类,通常设置每个类的采样概率(权重)为样本数的倒数,即
... elif sampling_type == 'inverse': cls_weights = 1. / cls_counts ...
在[3]中提出了有效数(effective number)的概念,分母的位置不是简单的样本数,而是经过一定计算得到的,这里直接给出结果,证明请详见原论文。关于effective number的计算方式:
这里N表示数据集样本总数。
相关代码:
... elif sampling_type == 'effective': beta = (len(targets) - 1) / len(targets) cls_weights = (1.0 - beta) / (1.0 - np.power(beta, cls_counts)) ...
Output
Effective number: Class indices: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 18 20 21 22 23 24] Class counts: [2 1 2 3 1 1 4 2 3 4 4 2 3 5 2 4 1 3 1 4 5 6 1]
在和上面一样的模拟长尾数据集上,采样的结果更加均衡。
参考文献
- Kang, Bingyi, et al. "Decoupling Representation and Classifier for Long-Tailed Recognition." International Conference on Learning Representations. 2019.
- torch.utils.data.WeightedRandomSampler
- Cui, Yin, et al. "Class-balanced loss based on effective number of samples." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.