• 决策树(Decision tree)基本原理与基于scikit-learn的实现


    决策树(Decision tree,DT)是一类常见的机器学习方法,属于监督学习的一种。它通过给定的训练数据,计算各种情况发生的概率,在此基础上选择合适的划分并构造决策树,用于数据类别的预测判断,同时也可以进行数据的拟合回归。

    (一)理论基础

    (1)基本模型

    设一个数据集有p个指标(属性)[1,2,..,p],指标i中的值种类个数pi(各个指标内的值个数不一定相同),种类数为n。决策树的形式大致为:

    上图为了表示简便,决策树从上到下依次按指标1,2...划分。利用这样的一个决策树,我们可以判断任意给定的数据所属类别。例如,一个数据指标1的值为value12,指标2的值为value22......我们在决策树从上到下依次寻找,直到最后的叶节点,所得类别即为该数据所属类别。

    然而,在实际问题中,很多时候我们所遇到的数据并不都是离散的值,而是连续的数值。对于连续属性,可采用离散属性离散化技术——例如将数从小到大排列,用若干划分点把数据分成若干部分,其中最简单的策略是采用二分法(bi-partition)进行处理(C4.5决策树算法采用该方法)。

    所以,构造决策树的关键在于如何选择最优划分属性,即应该按照怎样的指标顺序依次划分数据,对于连续属性,我们还要确定划分点。

    (2)构造方法

    对于给定的数据集,我们首先需要确定数据集划分,划分数据集的原则是:将无序的数据变得更加有序。实现该方法可以采用信息论度量信息。在这里兔兔先介绍所用到的相关概念。

    信息熵(information entropy):度量样本集合纯度的一种常用指标,也成为香农熵(shannon entorpy)。值越小,信息纯度越高。

    H(X)=-\sum_{x}p(x)log_{2}(p(x))

    其中X表示数据样本,x为X中的可能取值,p(x)表示x发生(出现)的概率。在本文所研究的问题中,我们可以认为值在样本中出现的频率为概率,所以可以用x出现次数除以样本总数作为p(x)。例如,一个样本中某一指标值为[1,1,1,1,1,0,0,0,1,1,2,2],则信息熵为:

    -\[ \frac{7}{12}log(\frac{7}{12})+\frac{3}{12}log(\frac{7}{12})+\frac{2}{12}log(\frac{2}{12}) ]\\ =1.384

    信息增益(information gain):描述一个属性区分数据样本的能力。

    Gain(X,a)=H(X)-\sum_{v=1}^V\frac{|X_v|}{|X|}H(X_v)

    其中X表示数据样本,a表示划分的属性,该属性有V个可能取值。决策树中信息熵与信息增益的计算,信息熵部分都是计算数据中关于种类的信息熵,H(X)表示树的上一层所有数据中种类的信息熵,H(Xv)是计算属性a划分下值为v的数据中种类的信息熵。

    属性111221122
    属性2aabbbbbbccccaabb
    种类aaabbbaa

    对于上面的数据,我们可以先按照属性1进行划分。计算H(X):

    H(X)=-[\frac{4}{8}log(\frac{4}{8})+\frac{4}{8}log(\frac{4}{8})]=1

    在属性1中,有两个可能取值1、2,取值为1的个数为4,该取值下类别a个数为2,b个数为2;取值为2的个数为4,该取值下类别a个数为3,b个数为1。所以信息增益中第二个部分为:

    -\sum_{v=1}^{V}\frac{|X_v|}{|X|}H(X_{v})=-\{ \frac{4}{8}\times[ -(\frac{2}{4}log\frac{2}{4}+\frac{2}{4}log\frac{2}{4})]+\frac{4}{8}[-(\frac{3}{4}log\frac{3}{4}+\frac{1}{4}log\frac{1}{4})] \} \\ =-0.906

    所以,信息增益为:

    Gain=1-0.906=0.094

    (注:这里的信息熵、信息增益的知识借用了信息论中的知识,但是与之又有所区别。在信息论中熵常用H表示,这里常用Ent;信息增益在信息论中也称为相对熵、KL散度、信息散度)。

    所以,利用信息增益的方法,我们可以确定属性划分的方法,即每次计算各个属性的信息增益,选择使得信息增益最大的属性划分来划分属性,构建决策树。

    对于连续属性,若采取二分法,可以将数从小到大依次排列[a1,a2,......,an],取所有相邻的两个数的中点作为候选划分点,这样的划分点一共有n-1个。之后用各个待划分点把数据分成两个部分,计算信息增益,信息增益值最大的对应划分点即为我们所找的划分点。

    (3)其它方法

    除了采用信息增益(ID3决策树采用此法)寻找划分数据集的属性,也可以采用增益率(gain ratio)(C4.5决策树算法采用此法)、基尼指数(gini index)(CART决策树采用此法)等方法。

    决策树通常会出现过拟合的情况,此时可以采用剪枝(pruning)的方法来处理,其基本策略有:预剪枝(prepruning)、后剪枝(postpruning)。预剪枝是在决策树生成的过程中,对每个节点在划分前先进行估计,若当前节点的划分不能带来决策树泛化能力的提升,则停止划分,标记当前节点为叶节点。后剪枝是先从训练集生成决策树,然后从下往上对非叶节点进行考察,若将该节点对应的子树替换成叶节点可以带来决策树泛化性能的提升。则将该子树替换为叶节点。

    (4)算法实现

    1.信息熵的计算

    1. import numpy as np
    2. def entropy(dataset):
    3. n=len(dataset)#样本数
    4. label={} #统计各类的个数
    5. for data in dataset:
    6. if data[-1] not in label.keys(): #数据类别在每行最后一列
    7. label[data[-1]]=0
    8. label[data[-1]]+=1
    9. ent=0
    10. for key in label:
    11. p=label[key]/n
    12. ent+=-p*np.log2(p)
    13. return ent
    14. a=[[1,2,'a'],
    15. [2,3,'a'],
    16. [3,4,'b']]
    17. print(entropy(a)) #以a为例计算香农熵

    2.由某一属性a划分数据

    1. def splitDataset(dataset,a,value):
    2. '''dataset:数据样本
    3. a:属性,这里是数据中的第a列
    4. value:属性a中的某一取值value
    5. '''
    6. newdataset=[]#划分之后的数据
    7. for data in dataset:
    8. if data[a]==value: #判断数据中该属性第data个值是否为value
    9. newdata=data[0:a]
    10. newdata.extend(data[a+1:]) #抽取除属性a的数据
    11. newdataset.append(newdata)
    12. return newdataset
    1. print(splitDataset(a,a=0,value=1))
    2. '''-----------------------------'''
    3. >>>[[2, 'a']]

    3.由信息增益确定数据集划分

    1. def choosefeature(dataset):
    2. nf=len(dataset[0])-1 #样本属性个数
    3. baseEnt=entropy(dataset) #信息增益中的第一部分
    4. bestinfogain=0
    5. for i in range(nf):
    6. featueList=[data[i] for data in dataset] #属性i的所有数据
    7. value=set(featueList) #属性i中所有可能取值集合
    8. newEnt=0 #信息增益的第二部分
    9. for v in value:
    10. subDataset=splitDataset(dataset,i,v)
    11. p=len(subDataset)/len(dataset)
    12. newEnt+=p*entropy(subDataset)
    13. infogain=baseEnt-newEnt #计算利用属性i划分的信息增益
    14. if infogain>bestinfogain:
    15. bestinfogain=infogain #选择最大的信息增益
    16. bestfeature=i #选择最大信息增益下的属性
    17. return bestfeature

    4.构建决策树

    1. def majority(classList):
    2. '''选择classList中个数最多的那个值'''
    3. classdict={}
    4. for c in classList:
    5. if c not in classList:
    6. classdict[c]=0
    7. classdict[c]+=1
    8. a=max(classdict.values())
    9. for key,item in classdict.items():
    10. if item==a:
    11. return key
    12. def creatTree(dataset,featureLabel):
    13. '''dataset:数据集
    14. featureLabel:各属性的标签'''
    15. classList=[data[-1] for data in dataset] #样本类别
    16. if classList.count(classList[0])==len(classList):
    17. return classList[0] #若划分完全后叶节点所有值相同,停止划分返回该值
    18. if len(dataset[0])==1:
    19. return majority(classList) #若已经划分完,返回该叶节点中值个数最多的那个值
    20. bestfeature=choosefeature(dataset) #选择划分
    21. bestfeatueLabel=featureLabel[bestfeature] #该划分属性的标签
    22. Tree={bestfeatueLabel:{}}
    23. del (featureLabel[bestfeature]) #去掉已划分的属性标签
    24. value=[data[bestfeature] for data in dataset] #选取该属性的值
    25. valueSet=set(value)
    26. for v in valueSet:
    27. subLabel=featureLabel
    28. Tree[bestfeatueLabel][v]=creatTree(splitDataset(dataset,bestfeature,v),subLabel) #递归创建树
    29. return Tree
    1. print(creatTree(a,featureLabel=['feature1','feature2']))
    2. '''---------------------------------------------------'''
    3. >>>{'feature1': {1: 'a', 2: 'a', 3: 'b'}}

    (二)决策树分类

    (1)基本方法

    在sklearn中,可以利用DecisionTreeClassifier()来实现。

    1. from sklearn import tree
    2. Tree=tree.DecisionTreeClassifier()
    3. traindata=[[11,23],
    4. [34,22],
    5. [55,66],
    6. [3,1],
    7. [444,24]]
    8. trainlabel=[1,1,2,2,3]
    9. Tree.fit(traindata ,trainlabel)
    10. yp=Tree.predict([[11,22]])
    11. print(yp)

    在DecisionTreeClassifer()中,有以下参数:

    criterion:衡量属性划分质量。'gini'基尼指数;'log_loss'或'entropy'香农信息增益;

    splitter:用于在每个节点选择拆分策略。'best'最佳拆分;'random'随机拆分。

    max_depth:树的最大深度。默认为None,树一直扩展到叶节点全为同一类别数据或包含少于min_samples_split样本。

    min_samples_split:拆分内部节点所需的最小样本数。

    min_samples_leaf:叶节点所需最小样本数。

    min_weight_fraction_leaf:需要位于叶节点的权重总和,即所有输入样本的最小加权分数,默认样本具有相同的权重。

    max_features:寻找最佳分割时要考虑的特征数量。

    max_leaf_nodes:以最佳优先方式使用‘max_leaf_nodes’生成决策树。

    random_state:控制估计器estimator随机性。

    min_impurity_decrease:若此拆分的不纯度减少大于等于此值,则该节点被拆分。

    ccp_alpha:最小成本复杂度修剪,决策树剪枝,默认不剪枝。

    (2)决策树图绘制

    在训练好决策树后,可以采用plot_tree()画决策树图。

    1. from sklearn import tree
    2. import pandas as pd
    3. data=pd.DataFrame(pd.read_csv('Dry_Bean_Dataset.csv'))
    4. x=data.loc[:,'Area':'ShapeFactor4']
    5. y=data.loc[:,'Class']
    6. Tree=tree.DecisionTreeClassifier()
    7. Tree.fit(x,y)
    8. tree.plot_tree(Tree)
    9. plt.show()

     当然,我们也可以设置树的深度。例如,我们让树的最大深度max_depth=10。则树图为:

    在plot_tree()中,有以下参数:

    max_depth:树的最大深度,若为None,生成完全树,默认为None。

    feature_names:每个特征的名,参数值为各个特征的名组成的列表。

    class_names:每个目标类的名称,True为显示名称,默认False。

    labels:是否显示一些信息标签。'all'显示每个节点;'root'只显示根节点;'none'不显示任何节点,默认为'all'。

    impurity:是否显示不纯度,默认为True。

    node_ids:是否显示每个节点的ID号,True表示显示,默认False。

    filled:图形是否填充颜色,默认False。

    rounded:若为True,绘制圆角节点并使用Helvetica。默认False。

    fontsize:字体大小。

    ax:要绘制到的轴(与matplotlib结合,例如绘制子图时可以使用,将ax轴传给ax)。

    1. tree.plot_tree(Tree,max_depth=2,feature_names=data.columns,class_names=True,label='all',rounded=True,fontsize=10,filled=True,
    2. proportion=True,impurity=True,node_ids=True)
    3. plt.show()

    (三)决策树回归

    (1)基本方法

    在sklearn中,决策树回归用DecisionTreeRegressor()来实现。

    1. from sklearn import tree
    2. Tree=tree.DecisionTreeRegressor()
    3. x=np.linspace(0,10,90)
    4. y=np.sin(x)
    5. x=np.reshape(x,(-1,1))
    6. Tree.fit(x,y)
    7. testx=np.linspace(0,10,100)
    8. testx=np.reshape(testx,(-1,1))
    9. yp=Tree.predict(testx)
    10. '''score=Tree.score(x,y)
    11. print(score)'''
    12. plt.scatter(x,y,color='red')
    13. plt.plot(testx,yp,color='green')
    14. plt.show()

    运行结果如下:

    在DecisionTreeRegressor()中的参数与前面DecisionTreeRegession()中的参数基本相同,但是也有不同之处。

    criterion:默认'squared_error'均方误差;'absolute_error'平均绝对误差;'friedman_mse'弗里德曼;'poisson'泊松偏差。

    (四)总结

    本文通过决策树的基本概念,以信息增益的划分方法为例阐述分类决策树的原理,并采用sklearn来实现决策树分类与树回归。实际上决策树的内容较为广泛,方法众多,并且针对过拟合的情况有各种剪枝处理,如预剪枝、后剪枝等。

  • 相关阅读:
    MySQL数据库管理基本操作(一)
    多端用户APP
    ansible - Role
    矩阵分析与应用
    Flink中的Window计算-增量计算&全量计算
    【微机原理笔记】第 1 章 - 微型计算机基础概论
    Consul(注册中心)部署
    周赛补题(力扣、acwing)
    外汇天眼:虚假宣传 FCA发出最新警告 远离该平台!
    嵌入式基础知识-DMA
  • 原文地址:https://blog.csdn.net/weixin_60737527/article/details/126310756