• 机器学习——决策树


    前言

            跟着b站补基础,视频链接:第一章:决策树原理 1-决策树算法概述_哔哩哔哩_bilibili


    一、原理篇

    1、树模型

             · 决策树:从根节点开始一步步走到叶子节点(决策 )。

            · 所有的数据最终都会落到叶子节点,既可以做分类也可以做回归。

            如上图根据不同的特征:年龄以及性别进行决策 ,5个人(数据)落到了三个叶子节点。

            分类问题是从不同类型的数据中学习到这些数据间的边界,比如通过鱼的体长、重量、鱼鳞色泽等维度来分类鲶鱼和鲤鱼,这是一个定性问题

            回归问题则是从同一类型的数据中学习到这种数据中不同维度间的规律,去拟合真实规律,比如通过数据学习到面积、房间数、房价几个维度的关系,用于根据面积和房间数预测房价,这是一个定量问题。

    [以上分类与回归问题的文章:ML科普系列(二)分类与回归 - 北岛知寒 - 博客园 (cnblogs.com)]

            在决策树当中,选择特征进行决策的顺序是很重要的,不同的决策顺序出来的结果可能会受影响而不同,所以把握决策的顺序是很重要的。 

    2、树的组成

            · 根节点:第一个选择点(没有前驱的节点,上图中的根节点就是“age<15”)

            · 非叶子节点与分支:中间过程

            · 叶子节点:最终的决策结果(没有后继的节点,上图的叶子节点就是最后的三个⚪)

    3、决策树的训练与测试

            · 训练阶段:从给定的训练集构造出来一棵树(从根节点开始选择特征,如何进行特征切分)

            · 测试阶段:根据构造出来的树模型从上到下走一遍就好了

            这里的难点在于如何进行训练,特征的选择顺序应该怎么排。

            对于根节点的选择,我们的目标应该是根节点就像一个老大似的能更好的切分数据(分类的效果更好),根节点下面的节点自然就是二当家了,以此类推,数据能很快就完成了分类。

            所以这就相当于,通过一种衡量标准,来计算通过不同特征进行分支选择后的分类情况,找出来最好的那个当成根节点,以此类推。

    4、衡量标准-熵

            熵:表示随机变量不确定的度量(说白了就是混乱程度,商场买的类别越多就越混乱,专卖店只买一种类别就越稳定)

            公式:

            pi表示概率,我们都知道属于某一类的概率只会在[0,1],使用log的时候就在(-∞,0]区间内,前面有-号,也就是[0,∞),再乘以pi,最终结果还是正值,然后将所有的情况进行累加。

            举个例子:A集合[1,1,1,1,1,1,1,2,2]、B集合[1,2,3,4,5,6,7,8,9]

    显然A的类别少熵值小,B的类别多熵值大。

    1. import math
    2. def calculate_entropy(data):
    3. frequency = {} # 计算每个元素出现的次数
    4. for item in data:
    5. if item in frequency:
    6. frequency[item] += 1
    7. else:
    8. frequency[item] = 1
    9. total_elements = len(data) # 计算总元素数量
    10. entropy = 0 # 计算熵
    11. for item, count in frequency.items():
    12. p = count / total_elements # 计算当前元素出现的概率
    13. entropy -= p * math.log2(p) # 累加熵
    14. return entropy
    15. A = [1, 1, 1, 1, 1, 1, 1, 2, 2]
    16. B = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    17. entropy_A = calculate_entropy(A)
    18. print(f"集合A的熵为: {entropy_A}")
    19. entropy_B = calculate_entropy(B)
    20. print(f"集合B的熵为: {entropy_B}")

            因此可以知道,我们要想分类效果越好,就要熵值越小越好。也就是说,接下来要做的是让分类更好,就是让熵值越来越小,也就是要让熵值下降,且选择熵值下降越多的。

            计算可知,当p=0.5的时候,H(p)=1,此时的随机变量不确定值最大;p=0或p=1时,H(p)=0,此时的随机变量不确定值最小。

            在选择决策节点的时候,要考虑信息增益:表示特征X使得类Y的不确定性减少的程度。 

    5、决策树构造实例

            有以下数据,一共14条数据、4个特征。

            在进行决策选择的时候,根据不同的特征有不同的结果:

             这时候就可以计算信息增益,来选择决策节点。在历史数据中,14天有9天打球,据此计算此时的熵值:

            4个特征逐一分析:

            根据数据统计,outlook取值分别为sunny、overcast、rainy的概率分别为:5/14、4/14、5/14
            熵值计算:5/14*0.971 + 4/14*0 + 5/14*0.971 = 0.693
            信息增益:系统的熵值从原始的0.940下降到了0.693,增益为0.247
            同样的方式可以计算出其他特征的信息增益,选择增益最多的就好了。

            计算得到:gain(temperature)=0.029         gain(humidity)=0.152         gain(windy)=0.048 

            所以最终选的的是outlook该特征进行决策。接下来就继续再现在分出来的3个类种,继续分别进行决策划分,选择二当家,以此类推。

    6、ID3、C4.5、CART

    1)特征选择准则

    • ID3:使用信息增益作为选择特征的准则。信息增益是类别信息熵与某个属性状态下不同特征的信息熵(条件概率)的差值,它衡量了一个特征对于分类结果的影响程度。然而,ID3算法倾向于选择取值较多的特征,这可能会导致过拟合。
    • C4.5:在ID3的基础上进行了改进,使用信息增益比作为选择特征的准则。信息增益比通过引入一个惩罚项(特征的固有值),来克服ID3算法中信息增益偏向选择取值较多特征的不足。
    • CART:对于分类树,CART使用基尼指数作为选择特征的准则。基尼指数反映了从数据集D中随机抽取两个样本,其类别标记不一致的概率。基尼指数越小,说明数据集D的纯度越高。CART算法总是将当前样本集分割为两个子样本集,使得生成的决策树的每个非叶结点都只有两个分枝。

    2)树的结构

    • ID3和C4.5:生成的决策树可能包含多叉树结构,即每个内部节点可能对应多个分支。
    • CART:生成的决策树是二叉树结构,即每次分裂只产生两个子节点。这使得CART算法生成的决策树结构更为简洁。

    3)剪枝策略

    • ID3:原始的ID3算法并没有明确的剪枝策略,这可能导致生成的决策树过拟合。但在实际应用中,通常会结合剪枝策略来提高模型的泛化能力。
    • C4.5:在树构造过程中进行剪枝,通过预剪枝或后剪枝来减少模型的复杂度,防止过拟合。C4.5还提供了对连续属性的离散化处理,以及对不完整数据的处理能力。
    • CART:同样需要进行剪枝来防止过拟合。CART剪枝分为两部分:生成子树序列和交叉验证。通过选择最优的子树来平衡模型的复杂度和预测性能。

    4)其他差异

    • 处理数据类型:ID3和C4.5可以同时处理标称型和数值型数据,而CART在分类任务中主要针对标称型数据进行处理。
    • 应用场景:ID3和C4.5主要用于分类任务,而CART既可以用于分类任务也可以用于回归任务(当CART用作回归树时,使用平方误差作为划分准则)。

    注:这部分是AI生成的,参考的文章:决策树三种算法比较(ID3、C4.5、CART)_三种决策树算法的区别-CSDN博客

    ID3、C4.5、CART三种决策树的区别_id3决策树和c4.5决策树的区别-CSDN博客

    7、离散值与连续值

            对于离散值,也就是类似于前面这种,可以根据特征进行选择,对于连续值也是类似的,选择一个值将数据进行离散化就好了。

    8、剪枝策略

             决策树随着划分的越来越细,在训练的时候效果会越来越好,但是在别的数据集上的泛化能力可能比较差效果不好,这就是过拟合情况,为了减少过拟合程度,因此要进行剪枝操作。

    (1)预剪枝:

            预剪枝是边建立决策树边进行剪枝的操作。比如限制深度、叶子节点个数、叶子节点样本数、信息增益量等。

    (2)后剪枝:

            后剪枝是当建立完决策树后来进行剪枝操作。需要通过一定的衡量标准。C(T)是基尼系数,α是自己设置的,T_leaf是叶子节点数

            如下是一个决策树图,X[]表示对应数据集种的哪个特征,gini是基尼系数,samples是当前节点所有样本数量,value是不同类别的数量,例如第一个框中,[49,50,50]表示的是a类有49个b类有50个c类有50个。

             基尼系数表示在全部居民收入中,用于进行不平均分配的那部分收入占总收入的百分比。社会中每个人的收入都一样、收入分配绝对平均时,基尼系数是 0; 全社会的收入都集中于一个人、收入分配绝对不平均时,基尼系数是 1。现实生活中,两种情况都不可能发生,基尼系数的实际数值只能介于 0 ~ 1 之间。【参考:什么是基尼系数 - 国家统计局 (stats.gov.cn)

            例如上图第一个节点的左孩子节点,gini为0,因为它所有数据都属于a类,此时达到了纯化。

             某个节点要不要分裂,需要按照公式计算,如果新的C值比较大,就说明损失严重,需要分裂。

    9、回归问题

            对于分类问题可以根据熵值进行划分,而回归问题则根据方差结果(离散程度),预测结果是取该节点中所有数的平均数。

    二、代码篇

    1、基于sklearn实现

            这里使用的是iris数据集

    1. from sklearn.datasets import load_iris
    2. from sklearn.model_selection import train_test_split
    3. from sklearn.tree import DecisionTreeClassifier
    4. from sklearn import tree
    5. import matplotlib.pyplot as plt
    6. iris = load_iris() # 加载示例数据集
    7. X = iris.data # 特征
    8. y = iris.target # 标签
    9. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
    10. # 创建决策树分类器
    11. clf = DecisionTreeClassifier(criterion='entropy', max_depth=3, random_state=42)
    12. clf.fit(X_train, y_train) # 训练模型
    13. y_pred = clf.predict(X_test) # 预测
    14. accuracy = clf.score(X_test, y_test) # 模型准确率
    15. print(f"模型准确率: {accuracy:.2f}")
    16. plt.figure(figsize=(6, 4)) # 可视化决策树
    17. tree.plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
    18. plt.show()

             生成的决策树如下所示:

            输出结果:

    模型准确率: 0.98 

    2、基于python实现

    1. import matplotlib.pyplot as plt
    2. from math import log
    3. import operator
    4. # part1:定义数据集
    5. def creatDataSet():
    6. DataSet = [
    7. [0, 0, 0, 0, 'no'],
    8. [0, 0, 0, 1, 'no'],
    9. [0, 1, 0, 1, 'yes'],
    10. [0, 1, 1, 0, 'yes'],
    11. [0, 0, 0, 0, 'no'],
    12. [1, 0, 0, 0, 'no'],
    13. [1, 1, 1, 1, 'yes'],
    14. [1, 0, 1, 2, 'yes'],
    15. [1, 0, 1, 2, 'yes'],
    16. [2, 0, 1, 2, 'yes'],
    17. [2, 0, 1, 1, 'yes'],
    18. [2, 1, 0, 1, 'yes'],
    19. [2, 1, 0, 2, 'yes'],
    20. [2, 0, 0, 0, 'no']]
    21. labels = ['F1-AGE', 'F2-WORK', 'F3-HOME', 'F4-LOAN']
    22. return DataSet, labels
    23. # part2:创建树
    24. def createTree(DataSet, labels, featLabels):
    25. # 其实就是labels-classList:['no', 'no', 'yes', 'yes', 'no', 'no', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'no']
    26. classList = [example[-1] for example in DataSet]
    27. # 判断是否只属于一类,所有数据均属于同一类则返回该类别
    28. if classList.count(classList[0]) == len(classList):
    29. return classList[0]
    30. # 判断特征是否删除完毕,返回最多的类
    31. if len(DataSet[0]) == 1:
    32. return majorityCnt(classList)
    33. # 得到最佳收益的特征
    34. bestFeat = chooseBestFeatureTosplit(DataSet)
    35. # 得到最佳收益的label
    36. bestFeatLabel = labels[bestFeat]
    37. # 添加最佳label
    38. featLabels.append(bestFeatLabel)
    39. myTree = {bestFeatLabel: {}}
    40. # 删除特征
    41. del(labels[bestFeat])
    42. # 得到最佳特征的所有值
    43. featValue = [example[bestFeat] for example in DataSet]
    44. uniqueVals = set(featValue)
    45. # 自调用创建树
    46. for value in uniqueVals:
    47. sublabels = labels[:]
    48. myTree[bestFeatLabel][value] = createTree(splitDataSet(DataSet, bestFeat, value), sublabels, featLabels)
    49. return myTree
    50. # 决策完成之后,统计类别最多的类
    51. def majorityCnt(classList):
    52. classCount = {}
    53. for vote in classList:
    54. if vote not in classCount.keys():
    55. classCount[vote] = 0
    56. classCount[vote] += 1
    57. sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    58. return sortedClassCount[0][0]
    59. # 得到使得增益最高的特征
    60. def chooseBestFeatureTosplit(DataSet):
    61. numFeatures = len(DataSet[0]) - 1 # 统计特征个数
    62. baseEntropy = calcshannonEnt(DataSet) # 计算最右边特征的熵值,也就是还没进行决策的初始熵
    63. bestInfoGain = 0.0
    64. bestFeature = -1
    65. for i in range(numFeatures):
    66. featList = [example[i] for example in DataSet] # 存储第i个特征的所有数据
    67. uniqueVals = set(featList) # 只保留唯一数据
    68. newEntropy = 0.0
    69. for value in uniqueVals: # 计算根据某标签进行划分后的熵
    70. subDataSet = splitDataSet(DataSet, i, value)
    71. prob = len(subDataSet) / float(len(DataSet))
    72. newEntropy += prob * calcshannonEnt(subDataSet)
    73. infoGain = baseEntropy - newEntropy # 计算增益
    74. if infoGain > bestInfoGain: # 判断是否是最佳特征及最佳增益
    75. bestInfoGain = infoGain
    76. bestFeature = i
    77. return bestFeature
    78. # 按照第i个特征划分不同值的数量(value在前面是循环的,所以会统计到每一类),并删掉该特征
    79. def splitDataSet(DataSet, axis, value):
    80. retDataSet = []
    81. for featVec in DataSet:
    82. if featVec[axis] == value:
    83. reducedFeatVec = featVec[:axis]
    84. reducedFeatVec.extend(featVec[axis + 1:])
    85. retDataSet.append(reducedFeatVec)
    86. return retDataSet
    87. # 计算某个特征的熵值
    88. def calcshannonEnt(DataSet):
    89. numExamples = len(DataSet) # 统计数据量
    90. labelCounts = {}
    91. for featVec in DataSet: # labelCounts:存储最后一个特征的类别及数量
    92. currentLabel = featVec[-1]
    93. if currentLabel not in labelCounts.keys():
    94. labelCounts[currentLabel] = 0
    95. labelCounts[currentLabel] += 1
    96. shannonEnt = 0.0
    97. for key in labelCounts: # 计算某个特征的熵值
    98. prob = float(labelCounts[key]) / numExamples
    99. shannonEnt -= prob * log(prob, 2)
    100. return shannonEnt
    101. # 定义节点的格式
    102. decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    103. leafNode = dict(boxstyle="round4", fc="0.8")
    104. arrow_args = dict(arrowstyle="<-")
    105. def plotTree(myTree, parentPt, nodeTxt):
    106. # 如果 myTree 是一个叶节点,直接返回
    107. if type(myTree).__name__ != 'dict':
    108. plotNode(myTree, parentPt, parentPt, leafNode)
    109. return
    110. numLeafs = getNumLeafs(myTree)
    111. depth = getTreeDepth(myTree)
    112. firstStr = list(myTree.keys())[0]
    113. cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    114. plotMidText(cntrPt, parentPt, nodeTxt)
    115. plotNode(firstStr, cntrPt, parentPt, decisionNode)
    116. secondDict = myTree[firstStr]
    117. plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    118. for key in secondDict.keys():
    119. if type(secondDict[key]).__name__ == 'dict':
    120. plotTree(secondDict[key], cntrPt, str(key))
    121. else:
    122. plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
    123. plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
    124. plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    125. plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
    126. def plotMidText(cntrPt, parentPt, txtString):
    127. xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    128. yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    129. plt.text(xMid, yMid, txtString)
    130. def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    131. plt.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
    132. va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
    133. def createPlot(inTree):
    134. fig = plt.figure(1, facecolor='white')
    135. fig.clf()
    136. axprops = dict(xticks=[], yticks=[])
    137. createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    138. plotTree.totalW = float(getNumLeafs(inTree))
    139. plotTree.totalD = float(getTreeDepth(inTree))
    140. plotTree.xOff = -0.5 / plotTree.totalW
    141. plotTree.yOff = 1.0
    142. plotTree(inTree, (0.5, 1.0), '')
    143. plt.show()
    144. def getNumLeafs(myTree):
    145. numLeafs = 0
    146. # 如果 myTree 是一个叶子节点,直接返回1
    147. if type(myTree).__name__ != 'dict':
    148. return 1
    149. firstStr = list(myTree.keys())[0]
    150. secondDict = myTree[firstStr]
    151. for key in secondDict.keys():
    152. if type(secondDict[key]).__name__ == 'dict':
    153. numLeafs += getNumLeafs(secondDict[key])
    154. else:
    155. numLeafs += 1
    156. return numLeafs
    157. def getTreeDepth(myTree):
    158. maxDepth = 0
    159. # 如果 myTree 是一个叶子节点,深度为1
    160. if type(myTree).__name__ != 'dict':
    161. return 1
    162. firstStr = list(myTree.keys())[0]
    163. secondDict = myTree[firstStr]
    164. for key in secondDict.keys():
    165. if type(secondDict[key]).__name__ == 'dict':
    166. thisDepth = 1 + getTreeDepth(secondDict[key])
    167. else:
    168. thisDepth = 1
    169. if thisDepth > maxDepth:
    170. maxDepth = thisDepth
    171. return maxDepth
    172. if __name__ == '__main__':
    173. DataSet, labels = creatDataSet()
    174. featLabels = []
    175. myTree = createTree(DataSet, labels, featLabels)
    176. createPlot(myTree)

             输出结果:

  • 相关阅读:
    1700*D. Flowers(DP&&前缀和&&预处理打表)
    【R语言简介】
    连花清瘟卖断货?近一个月解热药价格暴涨33%,销额超206万元
    ARM---day02
    第三章 SpringBoot构造流程源码分析
    代码随想录算法训练营Day55 (Day 54休息) | 动态规划(15/17) LeetCode 392.判断子序列 115.不同的子序列
    【2021届】数据结构期末实验考试
    java计算机毕业设计火车订票管理系统MyBatis+系统+LW文档+源码+调试部署
    PostgreSQL修炼之道笔记之基础篇(十一)
    【过程记录】ArcGIS Pro打开.osgb文件
  • 原文地址:https://blog.csdn.net/weixin_62472350/article/details/141057803