• CutMix: How It Works and a Code Walkthrough


    Paper: CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features

    Introduction

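    Problems with earlier data augmentation methods:

    Mixup: the blended image looks locally ambiguous and unnatural. This confuses the model, especially for localization.

    Cutout: the removed region is usually filled with zeros or random noise, so the information in that region is wasted during training.

    CutMix improves on Cutout by filling the removed region with the corresponding patch cut from another image. It keeps Cutout's advantage: the model learns an object's features from partial views and pays more attention to its less discriminative parts. It is also more efficient than Cutout, because the removed region is filled with a patch from another image and the model learns features of both objects at the same time.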
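    To make the contrast concrete, here is a minimal pure-Python sketch on toy 4×4 single-channel "images" stored as nested lists. The helper names (`mixup`, `cutout`, `cutmix_paste`) and the fixed box are hypothetical and only illustrate the pixel side of each method. Mixup blends every pixel, Cutout zero-fills a box, and CutMix pastes the same box from a second image.

```python
from typing import List, Tuple

Image = List[List[float]]
Box = Tuple[int, int, int, int]  # (x1, y1, x2, y2), half-open: rows y1..y2-1, cols x1..x2-1


def mixup(a: Image, b: Image, lam: float) -> Image:
    # Every pixel becomes a blend of both images.
    return [[lam * pa + (1 - lam) * pb for pa, pb in zip(ra, rb)] for ra, rb in zip(a, b)]


def cutout(a: Image, box: Box) -> Image:
    # Pixels inside the box are replaced by 0, so their information is lost.
    x1, y1, x2, y2 = box
    return [[0.0 if (y1 <= i < y2 and x1 <= j < x2) else p for j, p in enumerate(row)]
            for i, row in enumerate(a)]


def cutmix_paste(a: Image, b: Image, box: Box) -> Image:
    # Pixels inside the box are copied from image b instead of being zeroed.
    x1, y1, x2, y2 = box
    return [[b[i][j] if (y1 <= i < y2 and x1 <= j < x2) else p for j, p in enumerate(row)]
            for i, row in enumerate(a)]


img_a = [[1.0] * 4 for _ in range(4)]
img_b = [[9.0] * 4 for _ in range(4)]
box = (1, 1, 3, 3)  # a 2x2 patch in the middle

print(mixup(img_a, img_b, 0.5)[0])         # [5.0, 5.0, 5.0, 5.0]
print(cutout(img_a, box)[1])               # [1.0, 0.0, 0.0, 1.0]
print(cutmix_paste(img_a, img_b, box)[1])  # [1.0, 9.0, 9.0, 1.0]
```

    Outside the box, CutMix leaves image A untouched. Inside it, the pixels are real content from image B rather than zeros or a ghosted blend.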

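    As the figure below shows, Mixup and Cutout both improve classification accuracy, but they lose performance to varying degrees on weakly supervised localization and object detection. CutMix, in contrast, brings significant gains on all of these tasks.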

    CutMix

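    The CutMix operation combines two training samples \((x_A, y_A)\) and \((x_B, y_B)\) into a new sample \((\tilde{x}, \tilde{y})\):

    \[ \tilde{x} = \mathbf{M} \odot x_A + (\mathbf{1} - \mathbf{M}) \odot x_B \]
    \[ \tilde{y} = \lambda y_A + (1 - \lambda) y_B \]

    where \(\mathbf{1}\) is a mask filled with ones and \(\odot\) denotes element-wise multiplication.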

    Here \(\mathbf{M} \in \{0,1\}^{W \times H}\) is a binary mask that marks where the patch is cut from one image and pasted into the other. As in Mixup, \(\lambda\) is drawn from a \(\mathrm{Beta}(\alpha, \alpha)\) distribution. The authors set \(\alpha = 1\), so \(\lambda\) is sampled from the uniform distribution on \((0, 1)\).
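    The paper's mixing rule \(\tilde{x} = \mathbf{M} \odot x_A + (\mathbf{1} - \mathbf{M}) \odot x_B\), \(\tilde{y} = \lambda y_A + (1 - \lambda) y_B\) can be checked on a toy example. This sketch is pure Python, and its helper names are hypothetical. It builds an explicit binary mask, combines two images element-wise, and mixes one-hot labels. It assumes, as the paper does, that \(\lambda\) equals the fraction of ones in \(\mathbf{M}\). It also shows the standard library's `random.betavariate(alpha, alpha)` as one way to draw \(\lambda \sim \mathrm{Beta}(\alpha, \alpha)\).

```python
import random
from typing import List

Matrix = List[List[float]]


def sample_lambda(alpha: float = 1.0) -> float:
    # lambda ~ Beta(alpha, alpha); with alpha = 1 this is Uniform(0, 1).
    return random.betavariate(alpha, alpha)


def apply_mask(mask: List[List[int]], x_a: Matrix, x_b: Matrix) -> Matrix:
    # x_tilde = M * x_A + (1 - M) * x_B, element-wise.
    return [[m * a + (1 - m) * b for m, a, b in zip(rm, ra, rb)]
            for rm, ra, rb in zip(mask, x_a, x_b)]


def mix_labels(y_a: List[float], y_b: List[float], lam: float) -> List[float]:
    # y_tilde = lambda * y_A + (1 - lambda) * y_B.
    return [lam * a + (1 - lam) * b for a, b in zip(y_a, y_b)]


# M is 1 where pixels are kept from x_A and 0 inside the pasted box.
mask = [[1, 1, 1, 1],
        [1, 0, 0, 1],
        [1, 0, 0, 1],
        [1, 1, 1, 1]]
x_a = [[1.0] * 4 for _ in range(4)]
x_b = [[9.0] * 4 for _ in range(4)]
lam = sum(map(sum, mask)) / 16  # fraction of x_A pixels kept: 12 / 16 = 0.75

x_mix = apply_mask(mask, x_a, x_b)
y_mix = mix_labels([1.0, 0.0], [0.0, 1.0], lam)
print(x_mix[1], y_mix)  # [1.0, 9.0, 9.0, 1.0] [0.75, 0.25]
```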

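    To build the mask, first sample the bounding-box coordinates \(\mathbf{B} = (r_x, r_y, r_w, r_h)\) of the CutMix region:

    \[ r_x \sim \mathrm{Unif}(0, W), \quad r_w = W\sqrt{1 - \lambda} \]
    \[ r_y \sim \mathrm{Unif}(0, H), \quad r_h = H\sqrt{1 - \lambda} \]

    The mask \(\mathbf{M}\) is then 0 inside \(\mathbf{B}\) and 1 elsewhere.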

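    In other words, \(\lambda\) fixes the area ratio between the patch and the whole image: the patch covers \(\frac{r_w r_h}{WH} = 1 - \lambda\) of the image. The larger the region cut out of image A, the smaller A's weight in the mixed label.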
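    The box sampling (\(r_x \sim \mathrm{Unif}(0, W)\), \(r_w = W\sqrt{1-\lambda}\), and likewise along the height) can be sketched as follows. The sketch assumes integer pixel coordinates and clipping at the image border, as the torchvision `RandomCutmix` code does. The function names are hypothetical. Clipping can shrink the box, so the sketch recomputes \(\lambda\) from the actual patch area. This is the same correction torchvision applies after pasting.

```python
import math
import random
from typing import Tuple

Box = Tuple[int, int, int, int]  # (x1, y1, x2, y2)


def box_from_center(cx: int, cy: int, W: int, H: int, lam: float) -> Box:
    # Full sides are W*sqrt(1-lam) and H*sqrt(1-lam); the box is centred at
    # (cx, cy) and clipped to the image borders.
    half_w = int(0.5 * W * math.sqrt(1.0 - lam))
    half_h = int(0.5 * H * math.sqrt(1.0 - lam))
    x1, x2 = max(cx - half_w, 0), min(cx + half_w, W)
    y1, y2 = max(cy - half_h, 0), min(cy + half_h, H)
    return x1, y1, x2, y2


def sample_box(W: int, H: int, lam: float, rng: random.Random) -> Tuple[Box, float]:
    # r_x, r_y drawn uniformly over pixel positions, then lambda is recomputed
    # from the clipped area so the label weight matches the pasted pixels.
    cx, cy = rng.randrange(W), rng.randrange(H)
    x1, y1, x2, y2 = box_from_center(cx, cy, W, H, lam)
    adjusted = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
    return (x1, y1, x2, y2), adjusted


# lam = 0.75: each side is sqrt(0.25) = 1/2 of the image, area ratio 1/4 = 1 - lam.
print(box_from_center(4, 4, 8, 8, 0.75))  # (2, 2, 6, 6)
# Near a corner the box is clipped, so the adjusted lambda is larger than lam.
print(box_from_center(0, 0, 8, 8, 0.75))  # (0, 0, 2, 2)
box, lam_adj = sample_box(32, 32, 0.5, random.Random(0))
print(box, lam_adj)
```

    Because the clipped box can only be smaller than the nominal one, the adjusted \(\lambda\) is never below the sampled one.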

    Code implementation

    Below is the official implementation from torchvision's classification reference scripts:

```python
import math
from typing import Tuple

import torch
from torch import Tensor
from torchvision.transforms import functional as F  # F.get_dimensions needs torchvision >= 0.13


class RandomCutmix(torch.nn.Module):
    """Randomly apply Cutmix to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    `_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError("Please provide a valid positive value for the num_classes.")
        if alpha <= 0:
            raise ValueError("Alpha param can't be zero.")
        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tensor: Randomly transformed batch.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        # Convert integer labels to one-hot float labels so they can be mixed.
        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        # With probability 1 - p, return the batch unchanged.
        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
        # The first component of a Dirichlet(alpha, alpha) sample is Beta(alpha, alpha)-distributed.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        _, H, W = F.get_dimensions(batch)

        # Box centre (r_x, r_y), drawn uniformly over pixel positions.
        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        # Half side lengths; the full sides are W * sqrt(1 - lambda) and H * sqrt(1 - lambda).
        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        # Clip the box to the image borders.
        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        # Paste the patch from the paired image, then recompute lambda from the
        # actual (possibly clipped) patch area so the label weights match the pixels.
        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        # target = lambda * target + (1 - lambda) * target_rolled
        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s
```
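    The label side of `forward` can be traced without PyTorch. The sketch below mirrors `batch.roll(1, 0)` with a plain-list rotation, so sample i is paired with sample i − 1 and sample 0 with the last one. It then applies `target * λ + target_rolled * (1 − λ)` to one-hot labels. The helper names are hypothetical; only the arithmetic follows the code above.

```python
from typing import List


def roll_by_one(items: list) -> list:
    # Same ordering as tensor.roll(1, 0): the last element moves to the front.
    return items[-1:] + items[:-1]


def one_hot(label: int, num_classes: int) -> List[float]:
    return [1.0 if c == label else 0.0 for c in range(num_classes)]


def mix_targets(labels: List[int], num_classes: int, lam: float) -> List[List[float]]:
    # Row i becomes lam * one_hot(y_i) + (1 - lam) * one_hot(y_{i-1}).
    target = [one_hot(y, num_classes) for y in labels]
    rolled = roll_by_one(target)
    return [[lam * t + (1.0 - lam) * r for t, r in zip(row, rrow)]
            for row, rrow in zip(target, rolled)]


labels = [0, 1, 2]
print(roll_by_one(labels))              # [2, 0, 1]
print(mix_targets(labels, 3, 0.75)[0])  # [0.75, 0.0, 0.25]
```

    As the comment in the code notes, rolling is cheaper than shuffling. When the data loader already shuffles the dataset, neighbouring samples in a batch are effectively random pairs anyway.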

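    Experimental results

    As the figure below shows, CutMix achieves higher accuracy on ImageNet than data augmentation methods such as Cutout and Mixup.

    On weakly supervised object localization, CutMix also outperforms Mixup and Cutout.

    When the CutMix-trained model is used as a pretrained backbone for downstream tasks such as object detection and image captioning, CutMix also gives the best results.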

  • Original article: https://blog.csdn.net/ooooocj/article/details/126072135