• 《统计学习方法》第三章习题


    KNN算法是一种基本分类与回归方法,其假设给定一个训练数据集,其中的实例类别已定。分类时,对新的实例,根据其k个最邻近的训练实例的类别,通过多数表决等方式进行预测

    knn三要素:
    距离度量算法:一般使用欧氏距离。也可以使用其他距离:曼哈顿距离、切比雪夫距离、闵可夫斯基距离等。
    k值的确定:k值越小,模型整体变得越复杂,越容易过拟合。通常使用交叉验证法来选取最优k值
    分类决策:一般使用多数表决,即在 k 个邻近的训练点中的多数类决定输入实例的类。可以证明,多数表决规则等价于经验风险最小化。

    3.2利用例题3.2构造的kd树求点x=(3,4.5)的最近邻点

    kd树构造步骤
    1.开始:构造根结点,根结点对应于包含T的l维空间的超矩形区域。
    选择x(0)为坐标轴,以T中所有实例在该坐标轴上的中位数为切分点,将根结点对应的超矩形区域切分为两个子区域。切分由通过切分点并于坐标轴x(0)垂直的超平面实现。由根结点生成深度为1的左、右子结点,分别对应坐标轴x(0)小于和大于切分点的子区域,将落在切分超平面上的实例点保存在根结点。(也可以先对各个维度计算方差,选取最大方差的维度作为候选划分维度(方差越大,表示此维度上数据越分散);对切分维度上的值进行排序,选取中间的点为切分节点;按照切分节点的切分维度对空间进行一次划分;对上述子空间递归以上操作,直到空间只包含一个数据点。分而治之,且循环选取坐标轴。从方差大的维度来逐步切分,可以取得更好的切分效果及树的平衡性。)
    2.对深度为j的结点,选择x(j mod k)为切分坐标轴(也可选择方差最大的坐标轴),重复1所述的切分过程。
    3.直到两个子区域没有实例存在时停止,从而形成kd树的区域划分

    搜索kd树最近邻点步骤
    1.在kd树中从根结点出发,递归向下查找包含目标点x的叶结点;
    2.以此叶结点为当前最近点;
    3.递归向上回退,在每个结点进行以下操作:
          -如果该结点保存到实例点比当前最近点距离目标点更近,则以该实例点位当前最近点。
          -检查该结点的另一子结点是否与以目标点为球心,以目标点与当前最近点间的距离为半径的超球体相交。如果相交,则移动到另一子结点进行递归搜索,否则向上回退。
          -当回退到根结点时,搜索结束,当前最近点即为x的最近邻点。

    KNN python实现

    这里只实现了最近邻,如果是实现k近邻搜索,可以参考knn

    import pandas as pd
    import numpy as np
    from sklearn.datasets import load_iris
    import matplotlib.pyplot as plt
    import heapq
    class TreeNode():
        def __init__(self, x, left=None, right=None, parent=None):
            self.left = left
            self.right = right
            self.parent = parent
            self.val = x  # x是个k维向量
            self.vis = False
    
    class KNN():
        def __init__(self, k):
            self.k = k
            # 选择x(1)为坐标轴,将所有实例的x(1)坐标中的中位数作为切分点。np.median(nums)
            # self.root = TreeNode(np.zeros(self.k))
    
        def create_kd_tree(self, root, parent, j, data):
            """
            :param root: 当前节点
            :param parent: 当前节点的父节点
            :param j: 树的深度
            :param data: 该区域数据集
            :return: kd树根节点
            """
            if len(data)==0:
                return None
            else:
                # 对深度为j的节点,选择x(j%k+1)为切分坐标轴, 以该区域的x(j%k+1)坐标轴上的中位数为切分点
                x = sorted(data, key=lambda x : x[j%self.k])
                # 中位数坐标就是排序后的中间位置
                target = x[len(x)//2]
                root.val = target
                root.left = self.create_kd_tree(TreeNode(np.zeros(self.k)), root, j + 1, x[:len(x) // 2])
                root.right = self.create_kd_tree(TreeNode(np.zeros(self.k)), root,  j + 1, x[len(x)//2+1:])
                root.parent = parent
                return root
    
        def distance(self, a, b):
            dist = np.sqrt(np.sum([np.square(a[i] - b[i]) for i in range(len(a))]))
            return dist
    
        def search_kd_tree(self, target, j, root):
            """
            kd树搜索最近邻
            :param target: 目标点
            :param j: 树的深度
            :param root: 当前节点
            :return: 目标点的最近邻点
            """
            # 获取当前是用哪个坐标轴进行划分
            axis = j%self.k
            #########################递归寻找叶子节点作为最近点#################
            if target[axis]<root.val[axis]:
                if root.left:
                    nearest = self.search_kd_tree(target, j+1, root.left)
                else:
                    nearest = root.val # 找到叶子节点,作为当前最近点
            else:
                if root.right:
                    nearest = self.search_kd_tree(target, j+1, root.right)
                else:
                    nearest = root.val # 找到叶子节点,作为当前最近点
            #####################回溯######################################
            root.vis = True # 标记当前节点已经访问过,避免搜索时重复访问
            now_node = root.val
            # 计算当前节点与目标节点的欧氏距离
            dist1 = self.distance(now_node, target)
            # 计算当前最近节点与目标节点的欧氏距离
            dist2 = self.distance(nearest, target)
            near = dist2 # 用near记录时与哪个节点最近,后面判断超球体是否与其兄弟节点相交需要用到
            # 如果当前节点比保存的最近点距离更小,则更新当前最近点
            if dist1<dist2:
                nearest = now_node
                near = dist1
            # 检查超球体是否与其兄弟节点相交
            if root.parent:
                # 判断目标值是否与其父节点划分的另一区域相交,如果相交,则要重新递归寻早最近邻
                dist = abs(root.parent.val[axis]-target[axis]) # 某向量到某超平面距离就是对应坐标轴上的值相减
                if dist<near:
                    # 判断哪个是兄弟节点,然后从其兄弟节点开始递归搜索kd树
                    if root.parent.left and root.parent.left != root and not root.parent.left.vis:
                        nearest = self.search_kd_tree(target, j+1, root.parent.left)
                    if root.parent.right and root.parent.right != root and not root.parent.right.vis:
                        nearest = self.search_kd_tree(target, j+1, root.parent.right)
            return nearest
    
        def liner_test(self, target, data):
            """
            线性搜索最近邻
            :param target: 目标点
            :param data: 数据集
            :return: 目标点的最近邻点
            """
            mindist=np.inf
            nearest=np.zeros(self.k)
            for x in data:
                dis = self.distance(target,x)
                if dis<mindist:
                    mindist=dis
                    nearest=x
            return nearest
    
    
    data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
    k=len(data[0])
    knn = KNN(k)
    root = knn.create_kd_tree(TreeNode(np.zeros(k)),None, 0, data)
    print(knn.search_kd_tree([3,4.5], 0, root))
    print(knn.liner_test([3,4.5],data))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112

    输出结果
    在这里插入图片描述

  • 相关阅读:
    C语言之网络编程(一)域名解析
    【Spring】Spring源码中占位符解析器PropertyPlaceholderHelper的使用
    1.线性表
    App 软件开发《判断5》试卷及答案
    Python的pytest框架(6)--测试钩子(hooks)
    计算机毕业设计SSM 校园疫情防控系统【附源码数据库】
    图像处理之图像复原[逆滤波、维纳滤波、约束最小二乘法、Lucy-Richardson和盲解卷积复原]
    Educational Codeforces Round 138 (Rated for Div. 2) B. Death‘s Blessing
    AWS SAA-C03 #146
    PMP_第12章章节试题
  • 原文地址:https://blog.csdn.net/qq_42714262/article/details/127835492