Mixup Data Augmentation: Principle and Code Walkthrough


    Paper: mixup: Beyond Empirical Risk Minimization

    Problems with ERM

    • Empirical Risk Minimization (ERM) allows large neural networks to memorize the training data rather than learn and generalize from it. This persists even with strong regularization, and even in classification problems where the labels are assigned at random.
    • The predictions of networks trained with ERM change drastically when evaluated on samples just outside the training distribution, a phenomenon known as adversarial examples.

    One remedy is Vicinal Risk Minimization (VRM): use data augmentation to construct additional samples in the vicinity of the original ones. But describing the vicinity of each training sample (e.g. by flipping or scaling) requires human knowledge, so VRM has two shortcomings:

    • The augmentation procedure is dataset-dependent and therefore requires expert knowledge.
    • Augmentation only models vicinal relations between samples of the same class.

    Mixup

    To address these problems, the paper proposes mixup, a data-agnostic augmentation method:

    \(\tilde{x} = \lambda x_i + (1-\lambda)x_j\)
    \(\tilde{y} = \lambda y_i + (1-\lambda)y_j\)

    where \(x_i, x_j\) are two images drawn at random from the training set and \(y_i, y_j\) are their one-hot labels. The prior knowledge encoded here is that linear interpolations of feature vectors should lead to linear interpolations of the associated targets, which yields the new sample \((\tilde{x}, \tilde{y})\). The coefficient \(\lambda\) is sampled from a \(Beta(\alpha, \alpha)\) distribution, where \(\alpha\) is a hyperparameter.
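A minimal numeric sketch of this interpolation (using numpy; the shapes, class count, and α value are illustrative, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = 1.0                        # hyperparameter of Beta(alpha, alpha)
lam = rng.beta(alpha, alpha)      # mixing coefficient lambda in [0, 1]

# two "images" and their one-hot labels (10 classes)
x_i, x_j = rng.random((3, 32, 32)), rng.random((3, 32, 32))
y_i, y_j = np.eye(10)[2], np.eye(10)[7]

# linear interpolation of inputs and targets with the same lambda
x_tilde = lam * x_i + (1 - lam) * x_j
y_tilde = lam * y_i + (1 - lam) * y_j
```

Note that \(\tilde{y}\) is no longer one-hot: it carries probability mass lam on class 2 and 1 - lam on class 7, so it still sums to 1.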

    The authors also report several experimental findings:

    1. Mixing three or more samples brings no further accuracy gain and only increases computational cost.
    2. Their implementation draws one batch from a single data loader, randomly shuffles it, and applies mixup within the batch. They found this strategy works just as well while reducing I/O.
    3. Applying mixup only between samples of the same class does not improve accuracy.

    Implementation

    torchvision version

    Here roll shifts the images in the batch backward by one position, and the shifted batch is mixed with the original one, so every image in the batch is mixed with its neighbor. See torch.roll() for details.
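The pairing produced by rolling can be illustrated with numpy's equivalent np.roll (a batch of 4 dummy samples, purely for illustration):

```python
import numpy as np

batch = np.arange(4)              # pretend batch of 4 samples: [0, 1, 2, 3]
rolled = np.roll(batch, 1)        # shift by one along axis 0: [3, 0, 1, 2]

# sample i is paired with sample i-1 (wrapping around), so mixing
# `batch` with `rolled` mixes every image with its neighbor
pairs = list(zip(batch, rolled))  # [(0, 3), (1, 0), (2, 1), (3, 2)]
```

This avoids an extra shuffle: the batch order is already random, so the neighbor pairing is as good as a random pairing, and rolling is cheaper than generating a permutation.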

    from typing import Tuple

    import torch
    from torch import Tensor


    class RandomMixup(torch.nn.Module):
        """Randomly apply Mixup to the provided batch and targets.
        The class implements the data augmentations as described in the paper
        `"mixup: Beyond Empirical Risk Minimization" `_.

        Args:
            num_classes (int): number of classes used for one-hot encoding.
            p (float): probability of the batch being transformed. Default value is 0.5.
            alpha (float): hyperparameter of the Beta distribution used for mixup.
                Default value is 1.0.
            inplace (bool): boolean to make this transform inplace. Default set to False.
        """

        def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
            super().__init__()
            if num_classes < 1:
                raise ValueError(
                    f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
                )
            if alpha <= 0:
                raise ValueError("Alpha param can't be zero.")

            self.num_classes = num_classes
            self.p = p
            self.alpha = alpha
            self.inplace = inplace

        def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
            """
            Args:
                batch (Tensor): Float tensor of size (B, C, H, W)
                target (Tensor): Integer tensor of size (B, )

            Returns:
                Tensor: Randomly transformed batch.
            """
            if batch.ndim != 4:
                raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
            if target.ndim != 1:
                raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
            if not batch.is_floating_point():
                raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
            if target.dtype != torch.int64:
                raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

            if not self.inplace:
                batch = batch.clone()
                target = target.clone()

            if target.ndim == 1:
                target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

            if torch.rand(1).item() >= self.p:
                return batch, target

            # It's faster to roll the batch by one instead of shuffling it to create image pairs
            batch_rolled = batch.roll(1, 0)
            target_rolled = target.roll(1, 0)

            # Implemented as on mixup paper, page 3.
            lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
            batch_rolled.mul_(1.0 - lambda_param)
            batch.mul_(lambda_param).add_(batch_rolled)

            target_rolled.mul_(1.0 - lambda_param)
            target.mul_(lambda_param).add_(target_rolled)

            return batch, target

        def __repr__(self) -> str:
            s = (
                f"{self.__class__.__name__}("
                f"num_classes={self.num_classes}"
                f", p={self.p}"
                f", alpha={self.alpha}"
                f", inplace={self.inplace}"
                f")"
            )
            return s

    mmclassification version

    Here randperm shuffles the images within the batch, and the shuffled batch is mixed with the original one. The way \(\lambda\) is obtained also differs from torchvision: np.random.beta instead of torch._sample_dirichlet.
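The effect of \(\alpha\) on the sampled \(\lambda\) can be sketched with numpy (the α values below are illustrative): small α pushes λ toward 0 or 1 (weak mixing, one image dominates), α = 1 makes λ uniform on [0, 1], and large α concentrates λ around 0.5 (strong mixing).

```python
import numpy as np

rng = np.random.default_rng(0)
for alpha in (0.2, 1.0, 4.0):
    lam = rng.beta(alpha, alpha, size=10000)
    # small alpha: lambda clusters near 0/1; large alpha: near 0.5
    print(f"alpha={alpha}: mean={lam.mean():.2f}, std={lam.std():.2f}")
```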

    import numpy as np
    import torch

    # BaseMixupLayer and one_hot_encoding are defined elsewhere in
    # mmclassification (mmcls.models.utils.augment).


    class BatchMixupLayer(BaseMixupLayer):
        r"""Mixup layer for a batch of data.

        Mixup is a method to reduce the memorization of corrupt labels and
        increase the robustness to adversarial examples. It's proposed in
        `mixup: Beyond Empirical Risk Minimization
        `

        This method simply linearly mixes pairs of data and their labels.

        Args:
            alpha (float): Parameters for Beta distribution to generate the
                mixing ratio. It should be a positive number. More details
                are in the note.
            num_classes (int): The number of classes.
            prob (float): The probability to execute mixup. It should be in
                range [0, 1]. Defaults to 1.0.

        Note:
            The :math:`\alpha` (``alpha``) determines a random distribution
            :math:`Beta(\alpha, \alpha)`. For each batch of data, we sample
            a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random
            distribution.
        """

        def __init__(self, *args, **kwargs):
            super(BatchMixupLayer, self).__init__(*args, **kwargs)

        def mixup(self, img, gt_label):
            one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
            lam = np.random.beta(self.alpha, self.alpha)
            batch_size = img.size(0)
            index = torch.randperm(batch_size)

            mixed_img = lam * img + (1 - lam) * img[index, :]
            mixed_gt_label = lam * one_hot_gt_label + (
                1 - lam) * one_hot_gt_label[index, :]

            return mixed_img, mixed_gt_label

        def __call__(self, img, gt_label):
            return self.mixup(img, gt_label)

    Mixup in object detection

    In Bag of Freebies for Training Object Detection Neural Networks, mixing two images simply merges all the gt boxes from both; the class labels are not mixed. The paper does note that "weighted loss indicates the overall loss is the summation of multiple objects with ratio 0 to 1 according to image blending ratio they belong to in the original training images", i.e., at loss time each object's loss is weighted by the mixup coefficient of its source image before summation.
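A hedged sketch of this scheme, assuming both images already share the same size (real implementations typically pad the two images onto a common canvas, which is omitted here; function and argument names are illustrative, not from the paper's code):

```python
import numpy as np

def detection_mixup(img_a, boxes_a, img_b, boxes_b, alpha=1.5):
    """Image-level mixup for detection: blend the two images, concatenate
    all gt boxes (class labels are NOT mixed), and attach each box's
    blending ratio as a per-box loss weight."""
    lam = np.random.beta(alpha, alpha)
    mixed = lam * img_a + (1.0 - lam) * img_b
    boxes = np.concatenate([boxes_a, boxes_b], axis=0)
    # each object's loss is later weighted by the blending ratio of the
    # image it came from, then summed into the overall loss
    weights = np.concatenate([np.full(len(boxes_a), lam),
                              np.full(len(boxes_b), 1.0 - lam)])
    return mixed, boxes, weights
```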


    Original article: https://blog.csdn.net/ooooocj/article/details/126070745