• 机器学习---决策树分类代码


    1. 计算数据集的香农熵

    1. from numpy import *
    2. import numpy as np
    3. import pandas as pd
    4. from math import log
    5. import operator
    6. #计算数据集的香农熵
    7. def calcShannonEnt(dataSet):
    8. numEntries=len(dataSet)
    9. labelCounts={}
    10. #给所有可能分类创建字典
    11. for featVec in dataSet:
    12. currentLabel=featVec[-1]
    13. if currentLabel not in labelCounts.keys():
    14. labelCounts[currentLabel]=0
    15. labelCounts[currentLabel]+=1
    16. shannonEnt=0.0
    17. #以2为底数计算香农熵
    18. for key in labelCounts:
    19. prob = float(labelCounts[key])/numEntries
    20. shannonEnt-=prob*log(prob,2)
    21. return shannonEnt

    香农熵公式: 

    数据集: 

    2. 对离散变量划分数据集 

    1. #对离散变量划分数据集,取出该特征取值为value的所有样本
    2. def splitDataSet(dataSet,axis,value):
    3. retDataSet=[]
    4. for featVec in dataSet:
    5. if featVec[axis]==value:
    6. reducedFeatVec=featVec[:axis]
    7. reducedFeatVec.extend(featVec[axis+1:])
    8. retDataSet.append(reducedFeatVec)
    9. return retDataSet

    这个函数用于划分数据集。它的作用是从给定的数据集中,根据指定的特征和取值,提取出符合条

    件的样本集合。函数的输入参数包括数据集(dataSet)、特征的索引(axis)和特征取值

    (value)。在函数内部,通过遍历数据集中的每个样本(featVec),判断该样本在指定特征上的

    取值是否与给定的取值相等。如果相等,则将该样本添加到结果集合(retDataSet)中。为了将样

    本添加到结果集合中,需要先创建一个新的样本(reducedFeatVec),它是将原样本中指定特征

    的取值去除后的结果。具体做法是通过切片操作将特征索引之前和之后的部分合并起来,形成新的

    样本。最后,将新样本添加到结果集合中。最后,函数返回结果集合(retDataSet),其中包含了

    所有符合条件的样本。

    3. 对连续变量划分数据集

    1. #对连续变量划分数据集,direction规定划分的方向,
    2. #决定是划分出小于value的数据样本还是大于value的数据样本集
    3. def splitContinuousDataSet(dataSet,axis,value,direction):
    4. retDataSet=[]
    5. for featVec in dataSet:
    6. if direction==0:
    7. if featVec[axis]>value:
    8. reducedFeatVec=featVec[:axis]
    9. reducedFeatVec.extend(featVec[axis+1:])
    10. retDataSet.append(reducedFeatVec)
    11. else:
    12. if featVec[axis]<=value:
    13. reducedFeatVec=featVec[:axis]
    14. reducedFeatVec.extend(featVec[axis+1:])
    15. retDataSet.append(reducedFeatVec)
    16. return retDataSet

    这是一个用于划分连续变量数据集的函数。它接受四个参数:dataSet(数据集),axis(要划分

    的特征的索引),value(划分的阈值),direction(划分的方向)。函数的作用是根据给定的方

    向和阈值,将数据集划分为两个子集。如果direction为0,则将大于阈值的样本划分到一个子集

    中;如果direction不为0,则将小于等于阈值的样本划分到一个子集中。

    在函数的实现中,通过遍历数据集中的每个样本,根据给定的方向和阈值进行划分。如果样本的特

    征值大于阈值且方向为0,将该样本的特征值从划分特征的位置上移除,并将剩余的特征值组成一

    个新的样本,添加到划分后的子集中。如果样本的特征值小于等于阈值且方向不为0,同样进行相

    同的操作。最后,返回划分后的子集。

    4. 选择划分方式

    1. #选择最好的数据集划分方式
    2. def chooseBestFeatureToSplit(dataSet,labels):
    3. numFeatures=len(dataSet[0])-1
    4. baseEntropy=calcShannonEnt(dataSet)
    5. bestInfoGain=0.0
    6. bestFeature=-1
    7. bestSplitDict={}
    8. for i in range(numFeatures):
    9. featList=[example[i] for example in dataSet]
    10. # print(featList)
    11. #对连续型特征进行处理
    12. if type(featList[0]).__name__=='float' or type(featList[0]).__name__=='int':
    13. #产生n-1个候选划分点
    14. sortfeatList=sorted(featList)
    15. splitList=[]
    16. for j in range(len(sortfeatList)-1):
    17. splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0)
    18. bestSplitEntropy=10000
    19. slen=len(splitList)
    20. #求用第j个候选划分点划分时,得到的信息熵,并记录最佳划分点
    21. for j in range(slen):
    22. value=splitList[j]
    23. newEntropy=0.0
    24. subDataSet0=splitContinuousDataSet(dataSet,i,value,0)
    25. subDataSet1=splitContinuousDataSet(dataSet,i,value,1)
    26. prob0=len(subDataSet0)/float(len(dataSet))
    27. newEntropy+=prob0*calcShannonEnt(subDataSet0)
    28. prob1=len(subDataSet1)/float(len(dataSet))
    29. newEntropy+=prob1*calcShannonEnt(subDataSet1)
    30. if newEntropy<bestSplitEntropy:
    31. bestSplitEntropy=newEntropy
    32. bestSplit=j
    33. #用字典记录当前特征的最佳划分点
    34. bestSplitDict[labels[i]]=splitList[bestSplit]
    35. infoGain=baseEntropy-bestSplitEntropy
    36. #对离散型特征进行处理
    37. else:
    38. uniqueVals=set(featList)
    39. newEntropy=0.0
    40. #计算该特征下每种划分的信息熵
    41. for value in uniqueVals:
    42. subDataSet=splitDataSet(dataSet,i,value)
    43. prob=len(subDataSet)/float(len(dataSet))
    44. print(prob)
    45. newEntropy+=prob*calcShannonEnt(subDataSet)
    46. infoGain=baseEntropy-newEntropy
    47. if infoGain>bestInfoGain:
    48. bestInfoGain=infoGain
    49. bestFeature=i
    50. #若当前节点的最佳划分特征为连续特征,则将其以之前记录的划分点为界进行二值化处理
    51. #即是否小于等于bestSplitValue
    52. if type(dataSet[0][bestFeature]).__name__=='float' or type(dataSet[0][bestFeature]).__name__=='int':
    53. bestSplitValue=bestSplitDict[labels[bestFeature]]
    54. labels[bestFeature]=labels[bestFeature]+'<='+str(bestSplitValue)
    55. for i in range(shape(dataSet)[0]):
    56. if dataSet[i][bestFeature]<=bestSplitValue:
    57. dataSet[i][bestFeature]=1
    58. else:
    59. dataSet[i][bestFeature]=0
    60. return bestFeature

    numFeatures=len(dataSet[0])-1:计算数据集中特征数量,减去1是因为最后一列通常是标签列。

    baseEntropy=calcShannonEnt(dataSet):计算整个数据集的基本熵。

    bestInfoGain=0.0:初始化最佳信息增益为0。bestFeature=-1:初始化最佳划分特征的索引为-1。

    bestSplitDict={}:创建一个空字典,用于记录连续特征的最佳划分点。

    遍历每个特征,featList=[example[i] for example in dataSet]:获取数据集中第i个特征所有取值。

    if type(featList[0]).__name__=='float' or ... :判断特征是否为连续型特征。

    sortfeatList=sorted(featList):对连续型特征的取值进行排序。

    splitList=[]:创建一个空列表,用于存储候选划分点。

    for j in range(len(sortfeatList)-1):遍历排序后的特征取值列表,生成n-1个候选划分点。

    splitList.append((sortfeatList[j]+sortfeatList[j+1])/2.0):将相邻特征值的平均值作为候选划分点。

    bestSplitEntropy=10000:初始化最佳划分点的信息熵为一个较大的值。

    slen=len(splitList):获取候选划分点的数量。for j in range(slen):遍历每个候选划分点。

    value=splitList[j]:获取当前候选划分点的值。newEntropy=0.0:初始化划分后的信息熵为0。

             subDataSet0=splitContinuousDataSet(dataSet,i,value,0):根据当前候选划分点将数据集划

    分为小于等于该值的子集。subDataSet1=splitContinuousDataSet(dataSet,i,value,1):根据当前候

    选划分点将数据集划分为大于该值的子集。

            prob0=len(subDataSet0)/float(len(dataSet)):计算小于等于划分点的子集在整个数据集中的

    概率。newEntropy+=prob0*calcShannonEnt(subDataSet0):计算小于等于划分点的子集的信息

    熵,并加权求和。prob1=len(subDataSet1)/float(len(dataSet)):计算大于划分点的子集在整个数

    据集中的概率。newEntropy+=prob1*calcShannonEnt(subDataSet1):计算大于划分点的子集的

    信息熵,并加权求和。

    if newEntropy

    bestSplitEntropy=newEntropy:更新最佳划分点的信息熵。

    bestSplit=j:记录当前最佳划分点的索引。

    bestSplitDict[labels[i]]=splitList[bestSplit]:用字典记录当前特征的最佳划分点。

    infoGain=baseEntropy-bestSplitEntropy:计算当前特征的信息增益。

           如果特征是离散型特征,uniqueVals=set(featList):获取特征的唯一取值。newEntropy=0.0:

    初始化划分后的信息熵为0。遍历每个离散特征取值。subDataSet=splitDataSet(dataSet,i,value):

    根据当前特征取值将数据集划分为子集。prob=len(subDataSet)/float(len(dataSet)):计算当前特征

    取值的概率。newEntropy+=prob*calcShannonEnt(subDataSet):计算当前特征取值的信息熵,并

    加权求和。infoGain=baseEntropy-newEntropy:计算当前特征的信息增益if infoGain >

    bestInfoGain:如果当前特征的信息增益大于当前最佳信息增益。bestInfoGain=infoGain:更新最

    佳信息增益。bestFeature=i:记录当前最佳划分特征的索引。

           如果当前最佳划分特征是连续型特征。bestSplitValue=bestSplitDict[labels[bestFeature]]:获

    取当前最佳划分特征的最佳划分点labels[bestFeature] = labels[bestFeature] + '<=' + str

    (bestSplitValue):将当前最佳划分特征的标签更新为带有最佳划分点的条件。遍历数据集中的每个

    样本。if dataSet[i][bestFeature]<=bestSplitValue:如果当前样本的最佳划分特征的取值小于等于

    最佳划分点。dataSet[i][bestFeature]=1:将当前样本的最佳划分特征的取值设置为1。如果当前样

    本的最佳划分特征的取值大于最佳划分点。dataSet[i][bestFeature]=0:将当前样本的最佳划分特

    征的取值设置为0。返回最佳划分特征的索引。

    5. 递归构造决策树

    1. #特征若已经划分完,节点下的样本还没有统一取值,则需要进行投票
    2. def majorityCnt(classList):
    3. classCount={}
    4. for vote in classList:
    5. if vote not in classCount.keys():
    6. classCount[vote]=0
    7. classCount[vote]+=1
    8. return max(classCount)
    9. #主程序,递归产生决策树
    10. def createTree(dataSet,labels,data_full,labels_full):
    11. classList=[example[-1] for example in dataSet]
    12. if classList.count(classList[0])==len(classList):
    13. return classList[0]
    14. if len(dataSet[0])==1:
    15. return majorityCnt(classList)
    16. bestFeat=chooseBestFeatureToSplit(dataSet,labels)
    17. bestFeatLabel=labels[bestFeat]
    18. myTree={bestFeatLabel:{}}
    19. featValues=[example[bestFeat] for example in dataSet]
    20. uniqueVals=set(featValues)
    21. if type(dataSet[0][bestFeat]).__name__=='str':
    22. currentlabel=labels_full.index(labels[bestFeat])
    23. featValuesFull=[example[currentlabel] for example in data_full]
    24. uniqueValsFull=set(featValuesFull)
    25. del(labels[bestFeat])
    26. #针对bestFeat的每个取值,划分出一个子树。
    27. for value in uniqueVals:
    28. subLabels=labels[:]
    29. if type(dataSet[0][bestFeat]).__name__=='str':
    30. uniqueValsFull.remove(value)
    31. myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels,data_full,labels_full)
    32. if type(dataSet[0][bestFeat]).__name__=='str':
    33. for value in uniqueValsFull:
    34. myTree[bestFeatLabel][value]=majorityCnt(classList)
    35. return myTree

    classList=[example[-1] for example in dataSet]:创建一个列表classList,其中包含数据集dataSet

    中每个样本的类别标签。

    if classList.count(classList[0])==len(classList):检查classList中的类别标签是否都相同。如果是,

    则返回该类别标签作为叶子节点的类别。

    if len(dataSet[0])==1:检查数据集dataSet是否只剩下一个特征。如果是,则返回classList中出现

    次数最多的类别标签作为叶子节点的类别。

    bestFeat=chooseBestFeatureToSplit(dataSet,labels):调用函数chooseBestFeatureToSplit,选择

    最佳的特征进行划分,并将其索引保存在bestFeat中。

    bestFeatLabel=labels[bestFeat]:根据bestFeat的索引,获取特征标签labels中对应的特征名称。

    myTree={bestFeatLabel:{}}:创建一个字典myTree,以bestFeatLabel作为键,空字典作为值。这

    个字典将用于构建决策树。

    featValues=[example[bestFeat] for example in dataSet]:创建一个列表featValues,其中包含数据

    集dataSet中每个样本在bestFeat特征上的取值。

    uniqueVals=set(featValues):将featValues转换为集合uniqueVals,以获取bestFeat特征的唯一取

    值。

    if type(dataSet[0][bestFeat]).__name__=='str':检查bestFeat特征的数据类型是否为字符串。

    如果是,则执行以下操作:

          currentlabel=labels_full.index(labels[bestFeat]):获取完整特征标签列表labels_full中labels

    [bestFeat]的索引,并将其保存在currentlabel中;          

          featValuesFull=[example[currentlabel] for example in data_full]:创建一个列表

    featValuesFull,其中包含完整数据集data_full中每个样本在currentlabel特征上的取值;         

          uniqueValsFull=set(featValuesFull):将featValuesFull转换为集合uniqueValsFull,以获取

    currentlabel特征的唯一取值。

    del(labels[bestFeat]):删除labels中索引为bestFeat的特征标签,因为该特征已经被用于划分。

    for value in uniqueVals:对于uniqueVals中的每个取值,执行以下操作:

           subLabels=labels[:]:创建一个新的特征标签列表subLabels,并将labels的值复制给它。

           if type(dataSet[0][bestFeat]).__name__=='str':如果bestFeat特征的数据类型为字符串,执行

    以下操作:uniqueValsFull.remove(value):从uniqueValsFull中移除当前取值value。 

            myTree[bestFeatLabel[value] =createTree(splitDataSet(dataSet,bestFeat,value),subLabels,

    data_ full,labels_full):递归调用createTree函数,传入划分后的子数据集、子特征标签列表以及完

    整数据集和特征标签列表,并将返回的子树存储在myTree中。

            if type(dataSet[0][bestFeat]).__name__=='str':如果bestFeat特征的数据类型为字符串,执行

    以下操作:for value in uniqueValsFull::对于uniqueValsFull中的每个取值,执行以下操作:

    myTree[bestFeatLabel][value]=majorityCnt(classList):将叶子节点的类别标签设置为classList中

    出现次数最多的类别标签。

    最后,返回构建好的决策树。

    1. df=pd.read_csv('watermelon_3a.csv')
    2. data=df.values[:,1:].tolist()
    3. data_full=data[:]
    4. labels=df.columns.values[1:-1].tolist()
    5. labels_full=labels[:]
    6. myTree=createTree(data,labels,data_full,labels_full)

    6. 画树 

    1. import matplotlib.pyplot as plt
    2. decisionNode=dict(boxstyle="sawtooth",fc="0.8")
    3. leafNode=dict(boxstyle="round4",fc="0.8")
    4. arrow_args=dict(arrowstyle="<-")
    5. #计算树的叶子节点数量
    6. def getNumLeafs(myTree):
    7. numLeafs=0
    8. firstStr=list(myTree.keys())[0]
    9. secondDict=myTree[firstStr]
    10. for key in secondDict.keys():
    11. if type(secondDict[key]).__name__=='dict':
    12. numLeafs+=getNumLeafs(secondDict[key])
    13. else: numLeafs+=1
    14. return numLeafs
    15. #计算树的最大深度
    16. def getTreeDepth(myTree):
    17. maxDepth=0
    18. firstStr=list(myTree.keys())[0]
    19. secondDict=myTree[firstStr]
    20. for key in secondDict.keys():
    21. if type(secondDict[key]).__name__=='dict':
    22. thisDepth=1+getTreeDepth(secondDict[key])
    23. else: thisDepth=1
    24. if thisDepth>maxDepth:
    25. maxDepth=thisDepth
    26. return maxDepth
    27. #画节点
    28. def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    29. createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
    30. xytext=centerPt,textcoords='axes fraction',va="center", ha="center",\
    31. bbox=nodeType,arrowprops=arrow_args)
    32. #画箭头上的文字
    33. def plotMidText(cntrPt,parentPt,txtString):
    34. lens=len(txtString)
    35. xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002
    36. yMid=(parentPt[1]+cntrPt[1])/2.0
    37. createPlot.ax1.text(xMid,yMid,txtString)
    38. def plotTree(myTree,parentPt,nodeTxt):
    39. numLeafs=getNumLeafs(myTree)
    40. depth=getTreeDepth(myTree)
    41. firstStr=list(myTree.keys())[0]
    42. cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
    43. plotMidText(cntrPt,parentPt,nodeTxt)
    44. plotNode(firstStr,cntrPt,parentPt,decisionNode)
    45. secondDict=myTree[firstStr]
    46. plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD
    47. for key in secondDict.keys():
    48. if type(secondDict[key]).__name__=='dict':
    49. plotTree(secondDict[key],cntrPt,str(key))
    50. else:
    51. plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW
    52. plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
    53. plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
    54. plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD
    55. def createPlot(inTree):
    56. fig=plt.figure(1,facecolor='white')
    57. fig.clf()
    58. axprops=dict(xticks=[],yticks=[])
    59. createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    60. plotTree.totalW=float(getNumLeafs(inTree))
    61. plotTree.totalD=float(getTreeDepth(inTree))
    62. plotTree.x0ff=-0.5/plotTree.totalW
    63. plotTree.y0ff=1.0
    64. plotTree(inTree,(0.5,1.0),'')
    65. plt.show()

    plotNode函数用于绘制节点。它接受节点文本(nodeTxt)、中心点(centerPt)、父节点(parentPt)和节

    点类型(nodeType)作为参数。在函数内部,它使用createPlot.ax1.annotate()函数来绘制节点文

    本。

    createPlot函数用于创建并显示一个图形。它接受一个树对象(inTree)作为参数。在函数内部,它创

    建了一个图形对象(fig),清除了图形对象中的内容,然后创建了一个子图对象(createPlot.ax1)。接

    下来,它调用了plotTree函数来绘制树的节点,并使用plt.show()显示图形。

    plotMidText函数用于在箭头上绘制文字。它接受三个参数:cntrPt表示箭头的中心点坐标,

    parentPt表示箭头的起始点坐标,txtString表示要绘制的文字。在函数内部,它计算了文字的位置

    坐标,并使用createPlot.ax1.text()函数在图形上绘制文字。

    plotTree函数用于绘制树的节点和箭头。它接受三个参数:myTree表示树的字典表示,parentPt表

    示父节点的坐标,nodeTxt表示节点的文本。在函数内部,它首先获取树的叶子节点数和深度,然

    后计算当前节点的位置坐标。接下来,它调用plotMidText函数在箭头上绘制文字,调用plotNode函

    数绘制节点。然后,它遍历树的子节点,如果子节点是字典类型,则递归调用plotTree函数绘制子

    树;如果子节点是叶子节点,则调用plotNode函数绘制叶子节点,并使用plotMidText函数在箭头上

    绘制文字。最后,它更新plotTree.y0ff的值,以便绘制下一层的节点。

    遇到的问题:createPlot.ax1 是什么意思?

    在这句代码中,createPlot是函数类型(function),而createPlot.ax1是一个

    matplotlib.axes._axes.Axes。createPlot.ax1是一个有效的变量名,而将其替换为

    createPlot_ax1会导致报错。在代码中,createPlot.ax1是一个全局变量,用于引用子图对象。

    功能有点类似于类的成员变量,为了共享createPlot.ax1。函数也是对象,给一个对象绑定一个属

    性就是这样的:函数对象本身就有很多属性,__name____doc__等等。自己绑定的要有意义,没

    意义的就不需要。

    1. def f():
    2. pass
    3. f.a = 1
    4. print(f.a)
    5. # 1
    createPlot(myTree)

  • 相关阅读:
    进程和线程
    代码随想录算法训练营刷题复习4 :单调栈
    Vue-报错No “exports“ main defined in xx
    Windows用户、组的管理
    一个简单的HTML网页 个人网站设计与实现 HTML+CSS+JavaScript自适应个人相册展示留言博客模板
    c# listbox
    setState到底是异步还是同步?
    函数式编程基本语法
    uni-app实现点击复制按钮 复制内容
    Bootstrap元素的边框样式和设置
  • 原文地址:https://blog.csdn.net/weixin_43961909/article/details/132787698