• 第3章 决策树


    决策树经常处理分类问题,近来的调查表明决策树也是经常使用的数据挖掘算法。
    决策树的流程图:
    长方形代表判断模块(decision block),椭圆形代表中止模块(terminating block),表示已经得出结论,可以中止运行。
    从判断模块引出左右箭头称作分支(branch),它可以到底另一个判断模块或者中止模块。
    决策树算法能够读取数据集合,构建决策树
    决策树的一个重要任务是为了数据中蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,在这些机器根据数据集创建规则时,就是机器学习的过程。

    3.1 决策树的构造

    1. 优点:计算复杂度不高,输出结果易于理解,对中间值确实不敏感,可以处理不相关特征数据。
    2. 缺点:可能会产生过度匹配问题。
    3. 适用数据类型:数值型和标称型。

    创建分支的伪代码函数createBranch()如下所示:

    检测数据集中的每个子项是否属于同一分类:
        If so return 类标签;
        Else
            寻找划分数据集的最好特征
            划分数据集
            创建分支节点
                for 每个划分的子集
                    ## 递归调用createBranch
                    调用createBranch并增加返回结果到分支节点中
            return 分支节点
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    决策树的一般流程:

    1. 收集数据:可以使用任何方法。
    2. 准备数据:树构造只是用于标称型数据,因此数值型数据必须离散化。
    3. 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
    4. 训练算法:构造树的数据结构。
    5. 测试算法:使用经验树计算错误率。
    6. 使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。

    3.1.1 信息增益

    在划分数据集之前之后信息发生的变换称为信息增益。
    获得信息增益最高的特征就是最好的选择。

    熵:定义为信息的期望值
    如果待分类的事务划分在多个分类之中,在符号 x i x_i xi的信息定义为:
    l ( x i ) = − log ⁡ 2 p ( x i ) l(x_i) = - \log _{2} p(x_i) l(xi)=log2p(xi)
    为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:
    H = − ∑ i = 1 n p ( x i ) log ⁡ 2 p ( x i ) H = - \sum _{i = 1}^{n} p(x_i) \log _{2} p(x_i) H=i=1np(xi)log2p(xi)

    计算给定数据集的香农熵

        from math import log
    
        def calcShannonEnt(dataSet):
            numEntries = len(dataSet)
            labelCount = {}
    
            ##为所有可能分类创建字典
            for featVec in dataSet:
                currentLabel = featVec[-1]
                if currentLabel not in labelCount.keys():
                    labelCount[currentLabel] = 0
                labelCount[currentLabel] += 1
            shannonEnt = 0.0
    
    
            for key in labelCount:
                prob = float(labelCount[key]) / numEntries
                shannonEnt -= prob * log(prob, 2)       ##以2为底数求对数
            return shannonEnt
        ##简单鱼鉴定数据集
        def creatDataSet():
            dateSet = [[1, 1, 'yes'],
                    [1, 1, 'yes'],
                    [1, 0, 'no'],
                    [0, 1, 'no'],
                    [0, 1, 'no']]
    
            labels = ['no surfacing', 'flippers']
            return dateSet, labels
    
        if __name__ == '__main__':
            dataSet, labels = creatDataSet()
            shannonEnt = calcShannonEnt(dataSet)
            print(shannonEnt)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    3.1.2 划分数据集

    按照给定特征划分数据集

        def splitDataSet(dataSet, axis, value):
            '''
    
            :param dataSet: 待划分的数据集
            :param axis: 划分数据集的特征
            :param value: 需要返回的特征的值
            :return:  结果数据集
            '''
            retDataSet = []     ##创建一个新的list对象
            for featVec in dataSet:
                if featVec[axis] == value:
                    ##抽取
                    reducedFeatVec = featVec[:axis] 
                    reducedFeatVec.extend(featVec[axis + 1:])
                    retDataSet.append(reducedFeatVec)
            return retDataSet
        
        if __name__ == '__main__':
            dataSet, labels = creatDataSet()
            retDataSet = splitDataSet(dataSet, 0, 1)
            print(retDataSet)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    Python列表对象的方法,append()和extend()用于在列表末尾添加元素,但它们的用法和效果有所不同。

    1. append()方法:
      用法:list.append(element)
      参数:element,要添加的单个元素。
      功能:将指定的元素添加到列表的末尾。
      结果:在列表的末尾添加一个新的元素,并扩展列表的长度。
      特点:可以添加任意类型的元素,包括可迭代对象(如列表)。当添加可迭代对象时,整个对象作为单个元素添加到列表中。
    2. extend()方法:
      用法:list.extend(iterable)
      参数:iterable,一个可迭代对象,如列表、元组、字符串等。
      功能:将可迭代对象中的每个元素添加到列表的末尾。
      结果:不会创建一个新的列表对象,而是在原列表的末尾追加元素。
      特点:只能添加可迭代对象的元素值,而不是整个可迭代对象。当添加的是另一个列表时,extend()会将那个列表中的每个元素逐个添加到原列表中。
    3. 总结:
      append()适用于添加单个元素或可迭代对象,但整个可迭代对象被视为单个元素添加。
      extend()适用于添加可迭代对象的每个元素,而不是整个可迭代对象。
      这些方法都不返回新列表,而是直接修改原列表。

    选择最好的数据集划分方式

        def chooseBestFeatureToSplit(dataSet):
            numFeatures = len(dataSet[0]) - 1
            baseEntropy = calcShannonEnt(dataSet)
            bestInfoGain = 0.0
            bestFeature = -1
    
            for i in range(numFeatures):
                ## 创建唯一的分类标签列表
                featList = [example[i] for example in dataSet]
                uniqueVals = set(featList)
                newEntropy = 0.0
                ## 计算每种划分方式的信息熵
                for value in uniqueVals:
                    subDataSet = splitDataSet(dataSet, i, value)
                    prob = len(subDataSet) / float(len(dataSet))
                    newEntropy += prob * calcShannonEnt(subDataSet)
                infoGain = baseEntropy - newEntropy
                ## 计算最好的信息增益
                if (infoGain > bestInfoGain) :
                    bestInfoGain = infoGain
                    bestFeature = i
            return bestFeature
    
        if __name__ == '__main__':
            dataSet, labels = creatDataSet()
            feature = chooseBestFeatureToSplit(dataSet)
            print(feature)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27

    3.1.3 递归构建决策树

    其工作原理如下:
    创建原始数据集,然后基于最好的属性值划分数据集合,由于特征值可能多余两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将向下递归到树分支的下一个节点,在这个节点上我们再次划分数据。因此我们可以采用递归的原则处理数据集。
    递归结束的条件:
    程序遍历完所有划分数据集的性质,或者每个分支下的所有实例都有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶子节点的分类。

    	import operator
    	
    	def majorityCnt(classList):
    	    '''
    	
    	    :param classList: 分类名称列表
    	    :return: 出现次数最多的分类标签
    	    '''
    	    classCount = {}
    	    for vote in classList:
    	        if vote not in classCount.keys(): classCount[vote] = 0
    	        classCount[vote] += 1
    	    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    	    return sortedClassCount[0][0]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    创建树的函数代码:

    	def createTree(dataSet, labels):
    	    '''
    	
    	    :param dataSet: 数据集
    	    :param labels: 标签列
    	    :return:
    	    '''
    	    classList = [example[-1] for example in dataSet]
    	    ## 类别完全相同则停止继续划分
    	    if classList.count(classList[0]) == len(classList):
    	        return classList[0]
    	    ## 遍历完所有特征值时返回出现次数最多的
    	    if len(dataSet[0]) == 1 :
    	        return majorityCnt(classList)
    	    bestFeat = chooseBestFeatureToSplit(dataSet)
    	    bestFeatLabel = labels[bestFeat]
    	    myTree = {bestFeatLabel: {}}
    	    ##得到列表包含的所有属性值
    	    del(labels[bestFeat])
    	    featValues = [example[bestFeat] for example in dataSet]
    	    uniqueVals = set(featValues)
    	    for value in uniqueVals:
    	        subLabels = labels[:]
    	        myTree[bestFeatLabel][value]= createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    	    return myTree
    	
    	if __name__ == '__main__':
    	    myDat, labels = creatDataSet()
    	    myTree = createTree(myDat, labels)
    	    print(myTree)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    3.2 在Python中使用Matplotlib注解绘制树形图

    3.2.1 Matplotlib注解

    Matplotlib提供了一个注解工具annotations,非常有用,它可以在数据图形上添加文本注释。注解通常用于解释数据的内容。

    使用文本注解绘制树节点

        import matplotlib
        matplotlib.use('TkAgg')
        import matplotlib.pyplot as plt
        ## 定义文本框的箭头格式
        decisionNode = dict(boxstyle="sawtooth", fc="0.8")
        leafNode = dict(boxstyle="round4", fc="0.8")
        arrow_args = dict(arrowstyle="<-")
    
        def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    
            '''
            绘制带箭头的注解
            :param nodeTxt:
            :param centerPt:
            :param parentPt:
            :param nodeType:
            :return:
            '''
            createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                                    xytext=centerPt, textcoords='axes fraction',
                                    va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
    
        def createPlot():
            fig = plt.figure(1, facecolor='white')
            fig.clf()
            createPlot.ax1 = plt.subplot(111, frameon= False)
            plotNode(U'Decision Node', (0.5, 0.1), (0.1, 0.5), decisionNode)
            plotNode(U'Leaf Node', (0.8, 0.1), (0.3, 0.8), leafNode)
    
            plt.show()
    
    
        if __name__ == '__main__':
            createPlot();
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    3.2.2 构造注解树

    获得叶节点的数目和树的层数

        def getNumLeafs(myTree):
            numLeafs = 0;
            firstStr = list(myTree)[0]
            secondDict = myTree[firstStr]
            for key in secondDict.keys():
                ## 测试节点的数据类型是否为字典
                if type(secondDict[key]).__name__=='dict':
                    numLeafs += getNumLeafs(secondDict[key])
                else:
                    numLeafs += 1
            return numLeafs
    
        def getTreeDepth(myTree):
            maxDepth = 0
            firstStr = list(myTree)[0]
            secondDict = myTree[firstStr]
            for key in secondDict.keys():
                if type(secondDict[key]).__name__ == 'dict':
                    thisDepth = 1 + getTreeDepth(secondDict[key])
                else:
                    thisDepth = 1
                if thisDepth > maxDepth:
                    maxDepth = thisDepth
            return maxDepth
    
        def retrieveTree(i):
            listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                        ]
            return listOfTrees[i]
    
    
        if __name__ == '__main__':
            listOfTrees = retrieveTree(1)
            print(listOfTrees)
            myTree = retrieveTree(0)
            print(getNumLeafs(myTree))
            print(getTreeDepth(myTree))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38

    plotTree函数

        def plotTree(myTree, parentPt, nodeTxt):
            ## 计算树的宽与高
            numLeafs = getNumLeafs(myTree)
            depth = getTreeDepth(myTree)
            firstStr = list(myTree)[0]
            cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
            ## 标记子节点属性值
            plotMidText(cntrPt, parentPt, nodeTxt)
            plotNode(firstStr, cntrPt, parentPt, decisionNode)
            secondDict = myTree[firstStr]
            ## 减少y偏移值
            plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
            for key in secondDict.keys():
                if type(secondDict[key]).__name__ == 'dict':
                    plotTree(secondDict[key], cntrPt, str(key))
                else:
                    plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
                    plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
                    plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
            plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
        
        ##更新后的createPlot
        def createPlot(inTree):
            fig = plt.figure(1, facecolor='white')
            fig.clf()
            # createPlot.ax1 = plt.subplot(111, frameon= False)
            # plotNode(U'Decision Node', (0.5, 0.1), (0.1, 0.5), decisionNode)
            # plotNode(U'Leaf Node', (0.8, 0.1), (0.3, 0.8), leafNode)
            axprops = dict(xticks=[], yticks=[])
            createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
            plotTree.totalW = float(getNumLeafs(inTree))
            plotTree.totalD = float(getTreeDepth(inTree))
            plotTree.xOff = -0.5 / plotTree.totalW
            plotTree.yOff = 1.0
            plotTree(inTree, (0.5, 1.0), '')
            plt.show()
    
        if __name__ == '__main__':
            myTree = retrieveTree(0)
            createPlot(myTree)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40

    3.3 测试和存储分类器

    3.3.1 测试算法:使用决策树执行分类

    使用决策树的分类函数

        def classify(inputTree, featLabels, testVec):
            firstStr = list(inputTree)[0]
            secondDict = inputTree[firstStr]
            ## 将标签字符串转换为索引
            featIndex = featLabels.index(firstStr)
            for key in secondDict.keys():
                if testVec[featIndex] == key:
                    if type(secondDict[key]).__name__ == 'dict':
                        classLabel = classify(secondDict[key], featLabels, testVec)
                    else:
                        classLabel = secondDict[key]
            return classLabel
    
        if __name__ == '__main__':
            myDat, labels = creatDataSet()
            print(labels)
            myTree = treePlotter.retrieveTree(0)
            print(myTree)
            print(classify(myTree, labels, [1, 0]))
            print(classify(myTree, labels, [1, 1]))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    3.3.2 使用算法:决策树的存储

    使用pickle模块存储决策树:

        def storeTree(inputTree, filename):
            import pickle
            fw = open(filename, 'wb')
            pickle.dump(inputTree, fw)
            fw.close()
    
        def grabTree(filename):
            import pickle
            fr = open(filename, 'rb')
            return pickle.load(fr)
    
        if __name__ == '__main__':
            myDat, labels = creatDataSet()
            print(labels)
            myTree = treePlotter.retrieveTree(0)
            storeTree(myTree, './resource/classifierStorage.txt')
            classifyTree = grabTree('./resource/classifierStorage.txt')
            print(classifyTree)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    3.4 示例:使用决策树预测隐形眼镜类型

        def use():
            fr = open('./resource/lenses.txt', 'r')
            lenses = [inst.strip().split('\t') for inst in fr.readlines()]
            lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
            lensesTree = createTree(lenses, lensesLabels)
            print(lensesTree)
            treePlotter.createPlot(lensesTree)
    
        if __name__ == '__main__':
            use()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    决策树非常好地匹配了实验数据,然而这些匹配选项可能太多了。我们将这种问题称之为过度匹配(overfitting)。为了减少过度匹配问题,我们可以裁剪决策树,去掉一些不必要地叶子节点。如果叶子节点只能增加少许信息,则可以删除该节点,将它并入到其他叶子节点中。

  • 相关阅读:
    win | wireshark | 在win上跑lua脚本 解析数据包
    Diffusion Models & CLIP
    linux进阶(脚本编程/软件安装/进程进阶/系统相关)
    TCP/IP、DTN网络通信协议族
    稳定性实践:限流降级
    博客更新计划的说明
    百趣代谢组学资讯:机制探索不发愁,浅看真菌防治靶点代谢组学研究思路
    PHP调试工具 - FirePHP安装与使用方法
    Fe3+-多巴胺修饰Pluronic的多功能性水凝胶/多巴胺修饰聚丙烯微孔膜表面的研究
    JAVA多线程
  • 原文地址:https://blog.csdn.net/zhangsandidi/article/details/138012884