• 【Kaggle比赛常用trick】K折交叉验证、TTA


    一、什么是k折交叉验证

    在训练阶段,我们一般不会使用全部的数据进行训练,而是采用交叉验证的方式来训练。交叉验证(Cross Validation,CV)是机器学习模型的重要环节之一。它可以增强随机性,从有限的数据中获得更全面的信息,减少噪声干扰,从而缓解过拟合,增强模型的泛化能力。

    比赛一般会只给我们训练集,但是测试集我们是看不到的,所以我们一般会将训练集按照一定的方式划分为训练集和验证集。训练集用于模型的训练,验证集用于本地验证,选取最好的pt权重文件,再提交到比赛官网进行测试集的验证。所以如何划分训练集和验证集,让我们最大限度的利用训练集,学习有效的特征,是至关重要的。交叉验证就是做这个事的。

    交叉验证步骤:

    1. 将整个数据集划分为大小相等的K个部分;
    2. 每次选取其中一份作为验证集,其余K-1份作为训练集进行训练;
    3. 重复K次,直至每一份数据都被当作验证集验证了一遍;
    4. 模型的最终精度是通过K个子模型的平均精度来计算的;

    下面这个图可以比较好的诠释上面这个过程:
    在这里插入图片描述

    我们一般不会自己实现这个功能,一般都是调用SKLearn包直接使用,SKlearn帮我们实现了KFold、Stratified KFold、Group KFold和Stratified Group KFold四种方式,下面我一一介绍它们的区别和用法。

    二、常见的几种交叉验证方式

    2.1、KFold

    KFold是最简单的一种K折交叉验证,它的具体步骤如下图所示是一个4折交叉验证,橘色代表验证集(1份),蓝色代表训练集(3份),整个数据集有三个类别(对应图中三种颜色的分布情况);这些数据属于很多个不同的组;

    在这里插入图片描述
    可以看的很清楚,这种K折交叉验证,有两个缺点:

    1. 不适应于数据集样本不均衡的情况,因为很可能会把整个少数的类别划分为验证集或训练集;
    2. 不适应于时间序列问题;

    2.2、Stratified(分层) KFold

    上面讲到,KFold不适应于数据不平衡的问题,所以Stratified KFold(分层)交叉验证就是专门来解决这个问题的。如下图,在分层交叉验证中,数据集依然被划分为K组,但是验证组的目标类别是从各个类中分层抽取出来的,是均匀的,所以就不会存在少数类别被全部划分为验证集或训练集。
    在这里插入图片描述
    特点:可以解决数据不平衡问题,但是不适应于时间序列问题。

    2.3、Group (分组)KFold

    GroupKFold是KFold一个变体,目的在于将group严格分开,就是说同一个group的数据只能出现在训练集或者验证集,不能同时出现在训练集和验证集,如下图:在这里插入图片描述特点:可以将数据的group完全分开,避免高度相似的样本既出现在训练集又测试在验证集。

    2.4、Stratified Group KFold

    Group KFold和Stratified KFold的合体,如下图:

    在这里插入图片描述
    特点:可以将数据的group和标签的class完全分层划开,避免出现样本高度相似和标签分布不均的问题。

    2.5、Time Series Split

    可以解决时间序列相关的问题。对于时间序列数据集,根据时间将数据分为训练和验证,也称为前向链接方法或滚动交叉验证。

    在这里插入图片描述

    使用方式举例:

    skf = StratifiedGroupKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
    for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['empty'], groups = df["case"])):
    
    • 1
    • 2

    三、什么是TTA?

    TTA,即Test time augmention,测试时增强。数据增强一般是出现在训练阶段,使用数据增强一般都能提升性能。而测试时数据增强是指在测试的时候,将原图进行数据增强(比如水平翻转、垂直翻转、对角线翻转、旋转等,这里假设使用了3种数据增强),可以得到4张测试图片,对这四张测试图片分布进行推理,得到推理结果。再对三张增强后的推理结果再变换回来(比如我对原图进行水平翻转,得到的mask,再对mask进行水平翻转)。最后就得到了4张预测结果,对这四张预测结果mask对应位置相加取平均,就得到了最终的mask预测果。

    使用方式举例:

           model = build_model(CFG, test_flag=True)
            model.load_state_dict(torch.load(sub_ckpt_path))
               model.eval()
               y_preds = model(images) # [b, c, w, h]
               y_preds   = torch.nn.Sigmoid()(y_preds)
               masks += y_preds
    
               #x,y,xy flips as TTA
               if CFG.tta:
                   flips = [[-1]]  # 水平翻转
                   for f in flips:
                       images_f = torch.flip(images, f)
                       y_preds = model(images_f) # [b, c, w, h]
                       y_preds = torch.flip(y_preds, f)
                       y_preds   = torch.nn.Sigmoid()(y_preds)
                       masks += y_preds
    
            if CFG.tta:
                total_ckpt_paths = len(ckpt_paths_dict) * CFG.n_fold * 2
            else:
                total_ckpt_paths = len(ckpt_paths_dict) * CFG.n_fold
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    Reference

    知乎: 常见交叉验证方法汇总

  • 相关阅读:
    什么是网络编程(一)
    Git Commit Message 规范
    阿里巴巴最新总结「百亿级别并发设计手册」GitHub收获70K标星
    显示订单列表【项目 商城】
    LeetCode 2240. Number of Ways to Buy Pens and Pencils【数学,枚举;类欧几里得算法】1399
    防火墙的设置主要是为了防范什么
    linux安装mysql 5.7 完整步骤
    关于多传感器融合方法的总结与思考
    项目范围管理
    徒手撸设计模式-观察者模式
  • 原文地址:https://blog.csdn.net/qq_38253797/article/details/125262197