• 图解kd树+Python实现



    开篇

    在讲解k-近邻算法的时候,我们提供的思路是:对于新到来的样本,计算该样本与训练集中所有样本之间的距离,选取训练集中距离新样本最近的k个样本中大多数样本的类别作为新的样本的类别。

    也就是说,每次都要计算新的样本与训练集中全部样本的距离。但是,在实际应用中,训练集的样本量和特征维度都是比较庞大的,这就导致该算法不得不在计算距离上花费大量的时间,那有没有什么方法可以在时间开销上对之前的k-近邻算法进行优化呢?

    采用以空间来换时间的思想,就引出了今天的主角:kd树

    构造kd树

    kd树是一种二叉树,它可以将k维特征空间中的样本进行划分存储,以便实现快速搜索。

    一头雾水?没关系,来看一个经典的构造kd树的例子。

    现给定一个二维的训练集:
    T={(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}
    要求构造一个平衡kd树

    复制
      • 第一步,选取第0个维度作为被划分的坐标轴,并按照第0个维度从小到大排列全部样本,得到:
        {(2,3),(4,7),(5,4),(7,2),(8,1),(9,6)}

      复制
        • 第二步,找到第0个维度的中位数对应的样本。注意,这里的中位数与我们之前认知的中位数有些不同,具体表现在:对于本例,第0个维度排序后分别为:2,4,5,7,8,9。按道理中位数应该是(5+7)/2=6,但是,训练集的第0个维度中并没有6,所以,我们需要选取距离6最近的出现在训练集的第0个维度中的数字作为中位数,这里,5和7都是可以的。为了便于编程,我们就统一使用下标较大的位置的数字了: 6//2=3,所以最终选择下标为3位置的数字,即数字7作为第0个维度的中位数。

        • 第三步,以第二步中选取的中位数为基准,并作为当前划分的父节点。将从小到大排序好的样本序列进行划分:第0个维度小于基准的样本被划分到当前划分的父节点的左子树,第0个维度大于基准的样本被划分当前划分的父节点的右子树。此时得到如下的树:

        • 第四步,选取新的维度,按照公式 “新的划分维度=(上一次使用的维度+1)mod  特征总维数” ,得到新的维度为:(0+1) mod 2 = 1 。

          于是以维度1替换维度0,重复第一步到第三步:

          • 对左子树{(5,4),(2,3),(4,7)}按照特征的第一个维度从小到大排序:{(2,3),(5,4),(4,7)},确定中位数下标为3//2=1,所以数字4为中位数;将(5,4)作为当前划分的父节点,第一维度大于4的作为其左子树,第一维度小于4的作为其右子树;

          • 对右子树{(9,6),(8,1)}按照特征的第一个维度从小到大排序:{(8,1),(9,6)},确定中位数下标为2//2=1,所以数字6为中位数;将(9,6)作为当前划分的父节点,第一维度大于6的作为其左子树,第一维度小于6的作为其右子树;

        此时得到的树如下:

        由于此时训练集中所有子区域都已划分完毕(任一子区域中不含样本点),因此kd树就构造完成了。在上面的过程中,每分一次岔,就对应特征空间的一次划分(叶子节点的左右孩子都为空,但这里仍可以看成是一种特殊的分叉【左右分支都为空】) 最终整个特征空间被划分如下:

        现在来用Python实现上述过程。首先定义每个节点的数据结构:

        class Node():
            def __init__(self,lchild,rchild,value):
                self.lchild=lchild#节点的左子树
                self.rchild=rchild#节点的右子树
                self.value=value#节点的数值

        复制

          然后初始化一个KD树的类:

          class KDTree():
              def __init__(self,data):
                  self.dims=len(data[0])#训练集总特征数

          复制

            接下来到了构建kd树的核心步骤,从之前的例子中,可以总结出我们的思路:

            创建kd树的过程是递归的,所以我们可以递归地构造之:
            (1) 递归地构造左子树;
            (2) 递归地构造右子树;
            (3) 构造父节点,将其lchild与构造好的左子树连接,将其rchild与构造好的右子树连接。
            除此之外,还有一些辅助的方法,比如求指定维度的中位数,计算下一个划分维度,将会写成单独的方法以使得创建树的代码更加具有可读性。
            最后,不要忘了递归出口:被划分的子区域没有样本存在时,就退出。

                def create_kdtree(self,current_data,split_dim):
                    #设置递归出口:当全部样本划分完毕时就退出
                    if len(current_data)==0:
                        return None
                    
                    mid=self.cal_current_medium(current_data)#计算中位数所在下标
                    data_sorted=sorted(current_data,key=lambda x:x[split_dim])#按照切分维度从小到大排序

                    #下面三句代码本质上就是二叉树的后序遍历
                    lchild=self.create_kdtree(data_sorted[0:mid],self.cal_split_dim(split_dim))#递归地构造左子树
                    rchild=self.create_kdtree(data_sorted[mid+1:],self.cal_split_dim(split_dim))#递归地构造右子树
                    return Node(lchild,rchild,data_sorted[mid])#连接从根节点出发的左右子树,并返回
                
                #计算下一个划分维度
                def cal_split_dim(self,split_dim):
                    return (split_dim+1) % self.dims
                
                #计算当前维度中位数所在下标
                def cal_current_medium(self,current_data):
                    return len(current_data)//2
              

            复制

              完整的kd树构造代码如下:

              class KDTree():
                  def __init__(self,data):
                      self.dims=len(data[0])#训练集总特征数
                 def create_kdtree(self,current_data,split_dim):
                      #设置递归出口:当全部样本划分完毕时就退出
                      if len(current_data)==0:
                          return None
                      
                      mid=self.cal_current_medium(current_data)#计算中位数所在下标
                      data_sorted=sorted(current_data,key=lambda x:x[split_dim])#按照切分维度从小到大排序

                      #下面三句代码本质上就是二叉树的后序遍历
                      lchild=self.create_kdtree(data_sorted[0:mid],self.cal_split_dim(split_dim))#递归地构造左子树
                      rchild=self.create_kdtree(data_sorted[mid+1:],self.cal_split_dim(split_dim))#递归地构造右子树
                      return Node(lchild,rchild,data_sorted[mid])#连接从根节点出发的左右子树,并返回
                  
                  #计算下一个划分维度
                  def cal_split_dim(self,split_dim):
                      return (split_dim+1) % self.dims
                  
                  #计算当前维度中位数所在下标
                  def cal_current_medium(self,current_data):
                      return len(current_data)//2

              复制

                运行下面的代码,就构造好了一棵kd树:

                dataset = np.array([[2,3],[4,7],[5,4],[7,2],[8,1],[9,6]])#构建训练数据集
                kdtree = KDTree(dataset).create_kdtree(dataset,0)#创建KD树,以特征的第0个维度开始做划分

                复制

                  搜索kd树

                  这里仅实现最近邻搜索。所谓最近邻,就是k-近邻中k取1时的特殊情况。我们还是以具体的例子进行说明。基于上面构造好的kd树,现在来搜索样本点(2, 4.5)的最近邻点。先把之前的图搬过来,对照该图阅读以下步骤会更容易理解:

                  从根节点开始:

                  1. 首先来到第一层:在构造kd树时,由于(7,2)是根据维度0进行划分的,因此需要比较(2,4.5)与(7,2)的第0个维度的大小。由于2<7,因此接下来将搜索(7,2)的左子树(也就是(5,4)节点),反映到划分图上,就是去"过点(7,2)的垂直于横轴的划分线"的左侧进行接下来的搜索;

                  2. 然后来到第二层:在构造kd树时,由于(5,4)是根据维度1进行划分的,因此需要比较(2,4.5)与(5,4)的第1个维度的大小。由于4.5>4,因此接下来将搜索(5,4)的右子树(也就是(4,7)节点),反映到划分图上,就是去"过点(5,4)的垂直于纵轴的划分线"的上侧进行接下来的搜索;

                  3. 接着来到第三层:由于(4,7)已经是叶子节点,无左右孩子,所以从根节点(7,2)到叶子节点的搜索就完成了,当前的最近邻节点就是最后到达的叶子节点,也就是(4,7)。

                  4. 现在,开始从叶子节点(4,7)向上往根节点进行搜索(这也称之为回溯):

                  (1)
                  以(2,4.5)为中心,以(2,4.5)到当前最近邻点(4,7)的距离为半径,画一个圆(这里特征是二维的,所以是圆。一般的,对于高维特征的情况,画出来的是一个超球面),真正的最近邻点一定包含在这个圆的内部。于是当前最近邻点是(4,7),最近距离为半径长度=3.2015;

                  (2)
                  从叶子节点(4,7)返回其父节点(5,4),计算(5,4)与(2,4.5)的距离为3.0413,而3.0413<3.2015,因此当前最近邻点被更新为(5,4),最近距离被更新为3.0413;

                  (3)
                  返回计算父节点(5,4)的另一子节点(这里也就是(2,3)),计算其与目标点(2,4.5)的距离为1.5,而1.5<3.0413,因此当前最近邻点被更新为(2,3),最近距离被更新为1.5;

                  (4)
                  此时父节点(5,4)的另一子节点已经搜索完毕,继续向上回溯搜索那些没有被回溯过的节点,于是来到根节点(7,2),计算(7,2)与(2,4.5)的距离为5.5901,而5.5901>1.5,因此当前最近邻点不变,最近距离也不变。由于已经回溯到了根节点,整个搜索就完毕了,当前最近邻点就是我们最终要找的最近邻点,即(2,3)。

                  现在,让我们用Python程序来实现以上的搜索过程。基于构造kd树的代码,需要增加搜索的方法以及一些小的变动,具体如下:

                  • 由于在前向搜索的过程中,需要知道每个节点是根据哪个维度进行划分的,因此给每个节点增加一个维度属性:split_dim
                  class Node():
                     def __init__(self,lchild,rchild,value,split_dim):
                         self.lchild=lchild#节点的左子树
                         self.rchild=rchild#节点的右子树
                         self.value=value#节点的数值
                         self.split_dim=split_dim#用来做划分的维度

                  复制
                    • 为了便于返回最近邻点和最近距离,将这两个属性添加到kd树的属性中:
                    class KDTree():
                       def __init__(self,data):
                           self.dims=len(data[0])#总特征数
                           self.nearest_point=None
                           self.nearest_distance=np.inf#初始化为无穷大

                    复制
                      • 由于涉及到了距离的比较,因此增加计算两点之间距离的方法:
                         #计算两点之间的欧氏距离
                         def cal_dist(sample1,sample2):
                             return np.sqrt(np.sum((sample1-sample2)**2))

                      复制
                        • 算法将从根节点开始搜索,由于是递归的,所以这里可以先写一个辅助的递归入口函数,真正实现递归的算法写在另一个方法中:
                           #element:目标节点;root:kd树的根节点
                           def get_nearest(self,root,element):
                               search(root,element)#递归地搜索
                               return self.nearest_point,self.nearest.dist

                        复制
                          • 现在来实现递归搜索的过程:
                             def search(self,node,element):
                                if node is  None:
                                  return
                             #计算当前划分维度上目标节点与当前节点的单一维度上的距离
                                dist = node.value[node.split_dim] - element[node.split_dim]
                                #前向搜索
                                if dist>0:#当前节点在目标节点的上侧或左侧(在二维空间中)
                                    self.search(node.lchild,element)#递归地搜索左子树
                                else:#否则,当前节点在目标节点的下侧或右侧(在二维空间中)
                                    self.search(node.rchild,element)#递归地搜索右子树
                                #计算目标节点与当前节点的欧氏距离
                                curr_dist = self.cal_dist(node.value,element)
                                #更新最近邻节点
                                if curr_dist < self.nearest_dist:
                                    self.nearest_dist = curr_dist
                                    self.nearest_point = node
                                    #print(self.nearest_point.value)
                                #回溯
                                #比较“最近距离”是否超过“目标节点与当前节点在当前划分维度上的距离”,超过了就说明可能在当前节点的另一侧子树中存在更近的点,所以需要到当前节点的另一侧子树中去搜索
                                if self.nearest_dist > abs(dist):
                                    #由于是去当前节点的另一侧子树中进行搜索,因此正好与之前的前向搜索相反
                                    if dist>0:
                                        self.search(node.rchild,element)
                                    else:
                                        self.search(node.lchild,element)

                          复制

                            完整代码如下:

                            import numpy as np
                            class Node():
                                def __init__(self,lchild,rchild,value,split_dim):
                                    self.lchild=lchild#节点的左子树
                                    self.rchild=rchild#节点的右子树
                                    self.value=value#节点的数值
                                    self.split_dim=split_dim#用来做划分的维度

                            class KDTree():
                                def __init__(self,data):
                                    self.dims=len(data[0])#总特征数
                                    self.nearest_point=None
                                    self.nearest_dist=np.inf#初始化为无穷大
                                    
                                def create_kdtree(self,current_data,split_dim):
                                    #设置递归出口:当全部样本划分完毕时就退出
                                    if len(current_data)==0:
                                        return None
                                    
                                    mid=self.cal_current_medium(current_data)#计算中位数所在下标
                                    data_sorted=sorted(current_data,key=lambda x:x[split_dim])#按照切分维度从小到大排序

                                    #下面三句代码本质上就是二叉树的后序遍历
                                    lchild=self.create_kdtree(data_sorted[0:mid],self.cal_split_dim(split_dim))#递归地构造左子树
                                    rchild=self.create_kdtree(data_sorted[mid+1:],self.cal_split_dim(split_dim))#递归地构造右子树
                                    return Node(lchild,rchild,data_sorted[mid],split_dim)#连接从根节点出发的左右子树,并返回
                                
                                #计算下一个划分维度
                                def cal_split_dim(self,split_dim):
                                    return (split_dim+1) % self.dims
                                
                                #计算当前维度中位数所在下标
                                def cal_current_medium(self,current_data):
                                    return len(current_data)//2
                                
                                #计算两点之间的欧氏距离
                                def cal_dist(self,sample1,sample2):
                                    return np.sqrt(np.sum((sample1-sample2)**2))
                                    
                                #传入kd树的根节点root和待搜索的点element,搜索element的最近邻点
                                def search(self,node,element):
                                    if node is  None:
                                        return
                              #计算当前划分维度上目标节点与当前节点的单一维度上的距离
                                    dist = node.value[node.split_dim] - element[node.split_dim]
                                    #前向搜索
                                    if dist>0:#当前节点在目标节点的上侧或左侧(在二维空间中)
                                        self.search(node.lchild,element)#递归地搜索左子树
                                    else:#否则,当前节点在目标节点的下侧或右侧(在二维空间中)
                                        self.search(node.rchild,element)#递归地搜索右子树
                                    #计算目标节点与当前节点的欧氏距离
                                    curr_dist = self.cal_dist(node.value,element)
                                    #更新最近邻节点
                                    if curr_dist < self.nearest_dist:
                                        self.nearest_dist = curr_dist
                                        self.nearest_point = node
                                        #print(self.nearest_point.value)
                                    #回溯
                                    #比较“最近距离”是否超过“目标节点与当前节点在当前划分维度上的距离”,超过了就说明可能在当前节点的另一侧子树中存在更近的点,所以需要到当前节点的另一侧子树中去搜索
                                    if self.nearest_dist > abs(dist):
                                        #由于是去当前节点的另一侧子树中进行搜索,因此正好与之前的前向搜索相反
                                        if dist>0:
                                            self.search(node.rchild,element)
                                        else:
                                            self.search(node.lchild,element)
                                
                                 def get_nearest(self,root,element):
                                    self.search(root,element)
                                    return self.nearest_point.value,self.nearest_dist

                            复制

                              现在来测试一下:

                              dataset = np.array([[2,3],[4,7],[5,4],[7,2],[8,1],[9,6]])#构建训练数据集
                              kdtree = KDTree(dataset)#实例化一个kd树对象
                              root=kdtree.create_kdtree(dataset,0)#创建KD树,且以特征的第0个维度开始做划分,最终返回的是根节点
                              nearest_point,nearest_dist=kdtree.get_nearest(root,[2,4.5])#搜索[2,4.5]的最近邻点
                              print('最近邻点:{}\n最近距离:{}'.format(nearest_point,nearest_dist))

                              复制

                                运行结果:

                                最近邻点:[2 3]
                                最近距离:1.5

                                复制

                                  这和之前我们推导的结果是一致的。

                                  最后,感谢互联网上的优秀资源,给本文提供了许多参考。

                                  参考资料:

                                  link

                                • 相关阅读:
                                  工程与建设杂志工程与建设杂志社工程与建设编辑部2022年第3期目录
                                  【Vue】搭建vuex环境
                                  中国首款电音音频类“山野电音”数藏发售来了!
                                  C++ 设计模式:工厂模式
                                  【自然语言处理】深度学习基础
                                  LAB 信号量实现细节
                                  Linux中间件之redis存储原理和字典
                                  metinfo __ 6.0.0 __ file-read
                                  [附源码]计算机毕业设计基于Springboot的中点游戏分享网站
                                  java计算机毕业设计社区健康信息管理系统源码+系统+mysql数据库+lw文档
                                • 原文地址:https://blog.csdn.net/luoganttcc/article/details/134087812