• K近邻分类算法实战教程


    K近邻(K-Nearest Neighbor ,简称KNN ) 是有监督非线性、非参数分类算法,非参数表示对数据集及其分布没有任何假设。它是最简单、最常用的分类算法之一,广泛应用于金融、医疗等领域。

    K近邻算法

    KNN算法中的k表示邻近数据结点的数量,其算法过程如下:

    • 选择邻近结点数量K

    • 计算出测试数据结点和K个最近结点的距离

    • 在这个K个距离中,对每个分类进行计数

    • 依据少数服从多数原则,将测试数据结点归入在K个点中占比最高的那一类

    对于KNN分类算法,两点的距离计算采用欧式距离。请看下图:
    在这里插入图片描述

    假设数据集包括两类,分别为红色和蓝色表示。我们选择k为5,即基于欧式距离考虑5个最近结点,所以当测试新数据点时,5个结点,其中国三个蓝色、两个红色。则认为新数据点分类为蓝色。

    实战示例

    鸢尾花数据集(Iris)包括3种鸢尾(setosa, virginica, versicolor)各50个样本以及多个变量的数据集,是由英国统计学家和生物学家Ronald Fisher在其1936年的论文《The use of multiple measurements in taxonomic problems》中首次引用。Fisher从每个样本中测量了萼片和花瓣的长度和宽度等4个特征,并结合这4个特征建立了一个线性判别模型来区分不同的物种。

    • 加载数据集,并查看概要信息
    
    # Loading data
    data(iris)
       
    # Structure 
    str(iris)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 执行KNN分类
    # 加载依赖包
    library(e1071)
    library(caTools)
    library(class)
    
    # 加载数据
    # data(iris)
    # head(iris)
      
    # 把数据集分为训练集和测试集
    split <- sample.split(iris, SplitRatio = 0.7)
    train_cl <- subset(iris, split == "TRUE")
    test_cl <- subset(iris, split == "FALSE")
      
    # 标准化特征变量
    train_scale <- scale(train_cl[, 1:4])
    test_scale <- scale(test_cl[, 1:4])
    
    # 使用k=1 拟合 KNN 分类模型 
    classifier_knn <- knn(train = train_scale,
                          test = test_scale,
                          cl = train_cl$Species,
                          k = 1)
    # classifier_knn
      
    # 计算混淆矩阵
    cm <- table(test_cl$Species, classifier_knn)
    cm
    
    #           classifier_knn
    #            setosa versicolor virginica
    # setosa         20          0         0
    # versicolor      0         19         1
    # virginica       0          0        20
      
    # 模型评估 - 计算样本错误率
    misClassError <- mean(classifier_knn != test_cl$Species)
    # print(paste('Accuracy =', 1-misClassError))
    # [1] "Accuracy = 0.933333333333333"  
    
    # K = 3
    classifier_knn <- knn(train = train_scale,
                          test = test_scale,
                          cl = train_cl$Species,
                          k = 3)
    misClassError <- mean(classifier_knn != test_cl$Species)
    print(paste('Accuracy =', 1-misClassError))
    # [1] "Accuracy = 0.933333333333333"
    
    # K = 5
    classifier_knn <- knn(train = train_scale,
                          test = test_scale,
                          cl = train_cl$Species,
                          k = 5)
    misClassError <- mean(classifier_knn != test_cl$Species)
    print(paste('Accuracy =', 1-misClassError))
    # [1] "Accuracy = 0.95"
    
    # K = 7
    classifier_knn <- knn(train = train_scale,
                          test = test_scale,
                          cl = train_cl$Species,
                          k = 7)
    misClassError <- mean(classifier_knn != test_cl$Species)
    print(paste('Accuracy =', 1-misClassError))
    # [1] "Accuracy = 0.966666666666667"
    
    # K = 15
    classifier_knn <- knn(train = train_scale,
                          test = test_scale,
                          cl = train_cl$Species,
                          k = 15)
    misClassError <- mean(classifier_knn != test_cl$Species)
    print(paste('Accuracy =', 1-misClassError))
    # [1] "Accuracy = 0.983333333333333"
    
    # K = 19
    classifier_knn <- knn(train = train_scale,
                          test = test_scale,
                          cl = train_cl$Species,
                          k = 19)
    misClassError <- mean(classifier_knn != test_cl$Species)
    print(paste('Accuracy =', 1-misClassError))
    
    # [1] "Accuracy = 0.966666666666667"
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85

    当k为15时,模型的准确率达到98.3%,比k为1、3、5、7时的准确率更高。k为19时的精度为96.7%,这意味着增加k值不会增加精度,因此K为15更为合适。

    KNN优劣

    KNN方法思路简单,易于理解,易于实现,无需估计参数。
    该算法在分类时有两个主要不足。当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数 。该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点 。

  • 相关阅读:
    EXCEL根据某列的数字N,增加N-1行相同的数据
    全波形反演的深度学习方法: 第一章 基本概念
    数学基础之概率论1
    【Python】Python安装指定版本库
    单链表oj (上),详细的过程分析,每道题有多种解题思路,一定会有所收获
    【开源】SpringBoot框架开发新能源电池回收系统
    【Python】常量和变量类型
    web前端期末大作业:美食文化网页设计与实现——美食餐厅三级(HTML+CSS+JavaScript)
    C++开发学习笔记3
    abp中iquery类使用orderBy接口功能报错问题
  • 原文地址:https://blog.csdn.net/neweastsun/article/details/125474160