• 以K近邻算法为例,使用交叉验证优化模型最佳参数


    参数优化

    K近邻算法的介绍可以参考之前的文章:K近邻分类算法的Python代码实现
    在K近邻算法中,参数是K,表示K个最近的邻居。不同的K,可能会导致不同的预测结果。
    如下图所示,K=1时,红色方块被预测为蓝色圆点;K=3时,红色方块被预测为黄色上三角;当K增长至9时,红色方块又被预测为蓝色原点。

    那么在K近邻算法中,哪个K才是最佳的K?以及如何得到这个最佳的K?本文介绍的解决方案是交叉验证。

    交叉验证

    下图以5折交叉验证为例,绘制了交叉验证的计算过程。首先是把训练集等分为5份,构造出5组“子训练集和子测试集”的组合,其中子训练集占比4/5,子测试集占比1/5,不同组的组合中,子训练集和子测试集会有一些不同。
    然后对于每一组,选择一个k值后,可以分别计算得到K近邻算法的性能指标score1,score2,score3,score4和score5。
    最后把这5个值取平均值后,可以得到K_score。

    从以上的描述可知,交叉验证的价值包括:(1)每一个数据都作为子测试集被测试过,所以使用交叉验证后,可以有效防止过拟合;(2)对于每个K,都可以得到一个K_score,如果计算了不同K对应的K_score,通过对比K_score值的大小,就可以找到最佳的K值。

    代码实现

    交叉验证有3步:数据集划分、分别计算模型指标和统计平均值。
    以下基于Python代码实现了这个过程。
    为了方便后续与sklearn的结果进行细致比较,代码中并未直接统计平均值,而是分别输出模型的指标,此处使用的指标是f1_score,如果对指标含义不清楚,可以参考之前的文章:机器学习模型常用性能指标和Python代码实现
    需要注意的是,此处在划分数据集时,直接使用了自带的函数StratifiedKFold。
    对于分类问题来说,此处不能使用KFold,否则会和sklearn无法对齐。

    from sklearn import neighbors
    from sklearn.metrics import f1_score
    from sklearn.model_selection import StratifiedKFold
    
    
    def k_folder_by_self(X, y, cross_num, k):
        
        # 不能使用KFold,否则结果无法对齐
        kf = StratifiedKFold(cross_num)
        i = 0
        for train_index, test_index in kf.split(X, y):
            X_train = X[train_index]
            y_train = y[train_index]
    
            X_test = X[test_index]
            y_test = y[test_index]
    
            clf = neighbors.KNeighborsClassifier(k)
            clf.fit(X_train, y_train)
            print('i: {}, f1_score: {}'.format(i, f1_score(y_test, clf.predict(X_test))))
            i += 1
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    在sklearn中,已有交叉验证的函数cross_val_score,所以此处直接调用即可。

    from sklearn import neighbors
    from sklearn.model_selection import cross_val_score
    
    
    def k_folder_by_sklearn(X, y, cross_num, k):
        scores = cross_val_score(neighbors.KNeighborsClassifier(n_neighbors=k), X, y, cv=cross_num, scoring='f1')
        for i in range(0, cross_num):
            print('i: {}, f1_score: {}'.format(i, scores[i]))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    代码测试

    本节通过一个实例来测试自编代码和skearn代码的计算结果。
    此处使用了cancer数据集,并设置K=3(最近邻数量为3),cv=5(5折交叉验证)。

    if __name__ == '__main__':
        _, X, y = breast_cancer()
    
        k = 3
        cv = 5
    
        print('=======f1_score by self=======')
        k_folder.k_folder_by_self(X, y, cv, k)
        print('=======f1_score by sklearn=======')
        k_folder.k_folder_by_sklearn(X, y, cv, k)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    计算结果如下,显然对于5次计算的结果,两个程序都是完全一致的,由此验证了逻辑的正确性。

    实际应用

    在遇到实际问题时,使用交叉验证的方法如下图所示。右边是机器学习的常规流程,左边为交叉验证的流程,其插入在训练集的后面。

    从上图可以看出,一共包含4个步骤:数据集拆分为训练集和测试集;训练集使用交叉验证,得到最佳参数;使用最佳参数针对训练集进行重新训练;使用测试集计算模型的性能指标。以下为一个实例的具体代码。

    from sklearn import neighbors
    from sklearn.metrics import f1_score
    from sklearn.model_selection import validation_curve, train_test_split
    import numpy as np
    
    
    def cal_validation_curve(X, y, cv):
    
        # 数据集拆分
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
        
        # 交叉验证
        param_name = 'n_neighbors'
        param_range = range(1, 20)
        train_scores, test_scores = validation_curve(neighbors.KNeighborsClassifier(), X_train, y_train, cv=cv,
                                                     param_name=param_name, param_range=param_range, scoring='f1')
        test_scores_mean = np.mean(test_scores, axis=1)
        best_k = np.argmax(test_scores_mean) + 1
        
        # 使用最佳参数重新训练
        classifier = neighbors.KNeighborsClassifier(n_neighbors=best_k)
        classifier.fit(X_train, y_train)
        
        # 计算最终性能指标
        f1 = f1_score(y_test, classifier.predict(X_test))
    
        return best_k, f1
    
    
    if __name__ == '__main__':
        
        _, X, y = breast_cancer()
        cv = 5
        best_k, f1 = cal_validation_curve(X, y, cv)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
  • 相关阅读:
    力扣20-有效的括号——栈实现
    【数据结构】Java实现数据结构的前置知识,时间复杂度空间复杂度,泛型类的讲解
    一张图搞定英文星期、月份、季节总也搞不定的星期,月份,季节,一张图搞定,还有必用的常见搭配,再也不担心用错介词了~
    软件方法(下)第8章Part14:不要因为偷懒或炫耀而定义组合
    Docker: exec命令浅析
    详细讲解如何使用Java连接Kafka构建生产者和消费者(带测试样例)
    多线程和多进程的区别与联系
    Vue16 绑定css样式 style样式
    从任正非的内部信,看系统开发公司如何度过寒冬
    快慢指针思想(Hare & Tortoise 算法)
  • 原文地址:https://blog.csdn.net/taozibaby/article/details/126693917