• 决策树实验分析(分类和回归任务,剪枝,数据对决策树影响)


    目录

    1. 前言

    2. 实验分析

            2.1 导入包

            2.2 决策树模型构建及树模型的可视化展示

            2.3 概率估计

            2.4 绘制决策边界

            2.5 决策树的正则化(剪枝)

            2.6 对数据敏感

            2.7 回归任务

            2.8 对比树的深度对结果的影响

            2.9 剪枝


    1. 前言

            本文主要分析了决策树的分类和回归任务,对比一系列的剪枝的策略对结果的影响,数据对于决策树结果的影响。

            介绍使用graphaviz这个决策树可视化工具

    2. 实验分析

            2.1 导入包

    1. #1.导入包
    2. import os
    3. import numpy as np
    4. import matplotlib
    5. %matplotlib inline
    6. import matplotlib.pyplot as plt
    7. plt.rcParams['axes.labelsize'] = 14
    8. plt.rcParams['xtick.labelsize'] = 12
    9. plt.rcParams['ytick.labelsize'] = 12
    10. import warnings
    11. warnings.filterwarnings('ignore')

            2.2 决策树模型构建及树模型的可视化展示

            下载安装包:https://graphviz.gitlab.io/_pages/Download/Download_windows.html

             选择一款安装,注意安装时要配置环境变量

            注意这里使用的是鸢尾花数据集,选择花瓣长和宽两个特征

    1. #2.建立树模型
    2. from sklearn.datasets import load_iris
    3. from sklearn.tree import DecisionTreeClassifier
    4. iris = load_iris()
    5. X = iris.data[:,2:] # petal legth and width
    6. y = iris.target
    7. tree_clf = DecisionTreeClassifier(max_depth=2)
    8. tree_clf.fit(X,y)
    1. #3.树模型的可视化展示
    2. from sklearn.tree import export_graphviz
    3. export_graphviz(
    4. tree_clf,
    5. out_file='iris_tree.dot',
    6. feature_names=iris.feature_names[2:],
    7. class_names=iris.target_names,
    8. rounded=True,
    9. filled=True
    10. )

            然后就可以使用graphviz包中的dot.命令工具将此文件转换为各种格式的如pdf,png,如 dot -Tpng iris_tree.png -o iris_tree.png

            可以去文件系统查看,也可以用python展示

    1. from IPython.display import Image
    2. Image(filename='iris_tree.png',width=400,height=400)

            分析:value表示每个节点所有样本中各个类别的样本数,用花瓣宽<=0.8和<=1.75 作为根节点划分,叶子节点表示分类结果,结果执行少数服从多数策略,gini指数随着分类进行在减小。

            2.3 概率估计

            估计类概率 输入数据为:花瓣长5厘米,宽1.5厘米的花。相应节点是深度为2的左节点,因此决策树因输出以下概率:

            iris-Setosa为0%(0/54)

            iris-Versicolor为90.7%(49/54)

            iris-Virginica为9.3%(5/54)

            

    1. #4.概率估计
    2. print(tree_clf.predict_proba([[5,1.5]]))
    3. print(tree_clf.predict([[5,1.5]]))

            2.4 绘制决策边界

            

    1. #5.绘制决策边界
    2. from matplotlib.colors import ListedColormap
    3. def plot_decision_boundary(clf,X,y,axes=[0,7.5,0,3],iris=True,legend=False,plot_training=True):
    4. #找两个特征 x1 x2
    5. x1s = np.linspace(axes[0],axes[1],100)
    6. x2s = np.linspace(axes[2],axes[3],100)
    7. #构建棋盘
    8. x1,x2 = np.meshgrid(x1s,x2s)
    9. #在棋盘中构建待测试数据
    10. X_new = np.c_[x1.ravel(),x2.ravel()]
    11. #将预测值算出来
    12. y_pred = clf.predict(X_new).reshape(x1.shape)
    13. #选择颜色
    14. custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])
    15. #绘制并填充不同的区域
    16. plt.contourf(x1,x2,y_pred,alpha=0.3,cmap=custom_cmap)
    17. if not iris:
    18. custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])
    19. plt.contourf(x1,x2,y_pred,alpha=0.8,cmap=custom_cmap2)
    20. #可以把训练数据展示出来
    21. if plot_training:
    22. plt.plot(X[:,0][y==0],X[:,1][y==0],'yo',label='Iris-Setosa')
    23. plt.plot(X[:,0][y==1],X[:,1][y==1],'bs',label='Iris-Versicolor')
    24. plt.plot(X[:,0][y==2],X[:,1][y==2],'g^',label='Iris-Virginica')
    25. if iris:
    26. plt.xlabel('Petal length',fontsize = 14)
    27. plt.ylabel('Petal width',fontsize = 14)
    28. else:
    29. plt.xlabel(r'$x_1$',fontsize=18)
    30. plt.ylabel(r'$x_2$',fontsize=18)
    31. if legend:
    32. plt.legend(loc='lower right',fontsize=14)
    33. plt.figure(figsize=(8,4))
    34. plot_decision_boundary(tree_clf,X,y)
    35. plt.plot([2.45,2.45],[0,3],'k-',linewidth=2)
    36. plt.plot([2.45,7.5],[1.75,1.75],'k--',linewidth=2)
    37. plt.plot([4.95,4.95],[0,1.75],'k:',linewidth=2)
    38. plt.plot([4.85,4.85],[1.75,3],'k:',linewidth=2)
    39. plt.text(1.40,1.0,'Depth=0',fontsize=15)
    40. plt.text(3.2,1.80,'Depth=1',fontsize=13)
    41. plt.text(4.05,0.5,'(Depth=2)',fontsize=11)
    42. plt.title('Decision Tree decision boundareies')
    43. plt.show()

            可以看出三种不同颜色的代表分类结果,Depth=0可看作第一刀切分,Depth=1,2 看作第二刀,三刀,把数据集切分。

            2.5 决策树的正则化(剪枝)

            决策树的正则化

            DecisionTreeClassifier类还具有一些其他的参数类似地限制了决策树的形状

            min-samples_split(节点在分割之前必须具有的样本数)

            min-samples_leaf(叶子节点必须具有的最小样本数)

            max-leaf_nodes(叶子节点的最大数量)

            max_features(在每个节点处评估用于拆分的最大特征数)

            max_depth(树的最大深度)

    1. #6.决策树正则化
    2. from sklearn.datasets import make_moons
    3. X,y = make_moons(n_samples=100,noise=0.25,random_state=53)
    4. plt.plot(X[:,0],X[:,1],"b.")
    5. tree_clf1 = DecisionTreeClassifier(random_state=42)
    6. tree_clf2 = DecisionTreeClassifier(random_state=42,min_samples_leaf=4)
    7. tree_clf1.fit(X,y)
    8. tree_clf2.fit(X,y)
    9. plt.figure(figsize=(12,4))
    10. plt.subplot(121)
    11. plot_decision_boundary(tree_clf1,X,y,axes=[-1.5,2.5,-1,1.5],iris=False)
    12. plt.title('no restriction')
    13. plt.subplot(122)
    14. plot_decision_boundary(tree_clf2,X,y,axes=[-1.5,2.5,-1,1.5],iris=False)
    15. plt.title('min_samples_leaf={}'.format(tree_clf2.min_samples_leaf))

            可以看出在没有加限制条件之前,分类器要考虑每个点,模型变得复杂,容易过拟合。其他的一些参数读者可以自行尝试。

            2.6 对数据敏感

            决策树对于数据是很敏感的

            

    1. #6.对数据敏感
    2. np.random.seed(6)
    3. Xs = np.random.rand(100,2) - 0.5
    4. ys = (Xs[:,0] > 0).astype(np.float32) * 2
    5. angle = np.pi /4
    6. rotation_matrix = np.array([[np.cos(angle),-np.sin(angle)],[np.sin(angle),np.cos(angle)]])
    7. Xsr = Xs.dot(rotation_matrix)
    8. tree_clf_s = DecisionTreeClassifier(random_state=42)
    9. tree_clf_sr = DecisionTreeClassifier(random_state=42)
    10. tree_clf_s.fit(Xs,ys)
    11. tree_clf_sr.fit(Xsr,ys)
    12. plt.figure(figsize=(11,4))
    13. plt.subplot(121)
    14. plot_decision_boundary(tree_clf_s,Xs,ys,axes=[-0.7,0.7,-0.7,0.7],iris=False)
    15. plt.title('Sensitivity to training set rotation')
    16. plt.subplot(122)
    17. plot_decision_boundary(tree_clf_sr,Xsr,ys,axes=[-0.7,0.7,-0.7,0.7],iris=False)
    18. plt.title('Sensitivity to training set rotation')
    19. plt.show()

             这里是把数据又旋转了45度,然而决策边界并没有也旋转45度,却是变复杂了。可以看出,对于复杂的数据,决策树是很敏感的。

            2.7 回归任务

    1. #7.回归任务
    2. np.random.seed(42)
    3. m = 200
    4. X = np.random.rand(m,1)
    5. y = 4 * (X-0.5)**2
    6. y = y + np.random.randn(m,1) /10
    7. plt.plot(X,y,'b.')
    8. from sklearn.tree import DecisionTreeRegressor
    9. tree_reg = DecisionTreeRegressor(max_depth=2)
    10. tree_reg.fit(X,y)
    11. from sklearn.tree import export_graphviz
    12. export_graphviz(
    13. tree_reg,
    14. out_file='regression_tree.dot',
    15. feature_names=['X1'],
    16. rounded=True,
    17. filled=True
    18. )
    19. from IPython.display import Image
    20. Image(filename='regression_tree.png',width=400,height=400)

     

             回归任务,这里的衡量标准就变成了均方误差。

            2.8 对比树的深度对结果的影响

    1. #8.对比树的深度对结果的影响
    2. from sklearn.tree import DecisionTreeRegressor
    3. tree_reg1 = DecisionTreeRegressor(random_state=42,max_depth=2)
    4. tree_reg2 = DecisionTreeRegressor(random_state=42,max_depth=3)
    5. tree_reg1.fit(X,y)
    6. tree_reg2.fit(X,y)
    7. def plot_regression_predictions(tree_reg,X,y,axes=[0,1,-0.2,1],ylabel='$y$'):
    8. x1 = np.linspace(axes[0],axes[1],500).reshape(-1,1)
    9. y_pred = tree_reg.predict(x1)
    10. plt.axis(axes)
    11. plt.xlabel('$X_1$',fontsize =18)
    12. if ylabel:
    13. plt.ylabel(ylabel,fontsize = 18,rotation=0)
    14. plt.plot(X,y,'b.')
    15. plt.plot(x1,y_pred,'r.-',linewidth=2,label=r'$\hat{y}$')
    16. plt.figure(figsize=(11,4))
    17. plt.subplot(121)
    18. plot_regression_predictions(tree_reg1,X,y)
    19. for split,style in ((0.1973,'k-'),(0.0917,'k--'),(0.7718,'k--')):
    20. plt.plot([split,split],[-0.2,1],style,linewidth = 2)
    21. plt.text(0.21,0.65,'Depth=0',fontsize= 15)
    22. plt.text(0.01,0.2,'Depth=1',fontsize= 13)
    23. plt.text(0.65,0.8,'Depth=0',fontsize= 13)
    24. plt.legend(loc='upper center',fontsize = 18)
    25. plt.title('max_depth=2',fontsize=14)
    26. plt.subplot(122)
    27. plot_regression_predictions(tree_reg2,X,y)
    28. for split,style in ((0.1973,'k-'),(0.0917,'k--'),(0.7718,'k--')):
    29. plt.plot([split,split],[-0.2,1],style,linewidth = 2)
    30. for split in (0.0458,0.1298,0.2873,0.9040):
    31. plt.plot([split,split],[-0.2,1],linewidth = 1)
    32. plt.text(0.3,0.5,'Depth=2',fontsize= 13)
    33. plt.title('max_depth=3',fontsize=14)
    34. plt.show()

            不同的树的深度,对于结果产生极大的影响

            2.9 剪枝

            

    1. #9.加一些限制
    2. tree_reg1 = DecisionTreeRegressor(random_state=42)
    3. tree_reg2 = DecisionTreeRegressor(random_state=42,min_samples_leaf=10)
    4. tree_reg1.fit(X,y)
    5. tree_reg2.fit(X,y)
    6. x1 = np.linspace(0,1,500).reshape(-1,1)
    7. y_pred1 = tree_reg1.predict(x1)
    8. y_pred2 = tree_reg2.predict(x1)
    9. plt.figure(figsize=(11,4))
    10. plt.subplot(121)
    11. plt.plot(X,y,'b.')
    12. plt.plot(x1,y_pred1,'r.-',linewidth=2,label=r'$\hat{y}$')
    13. plt.axis([0,1,-0.2,1.1])
    14. plt.xlabel('$x_1$',fontsize=18)
    15. plt.ylabel('$y$',fontsize=18,rotation=0)
    16. plt.legend(loc='upper center',fontsize =18)
    17. plt.title('No restrctions',fontsize =14)
    18. plt.subplot(122)
    19. plt.plot(X,y,'b.')
    20. plt.plot(x1,y_pred2,'r.-',linewidth=2,label=r'$\hat{y}$')
    21. plt.axis([0,1,-0.2,1.1])
    22. plt.xlabel('$x_1$',fontsize=18)
    23. plt.ylabel('$y$',fontsize=18,rotation=0)
    24. plt.legend(loc='upper center',fontsize =18)
    25. plt.title('min_samples_leaf={}'.format(tree_reg2.min_samples_leaf),fontsize =14)
    26. plt.show()

            一目了然。 

  • 相关阅读:
    一个基于.Net高性能跨平台内网穿透工具
    嵌入式Linux驱动开发(I2C专题)(五)
    护眼灯真的可以保护眼睛吗?2022买什么护眼灯不伤孩子眼睛
    pytest学习和使用6-fixture如何使用?
    Java 查漏补缺
    Linux 常用运维使用指令
    Docker搭建Redis集群
    Java(笔试面试准备)
    Flutter基础 -- Dart 语言 -- 基础类型
    iOS UITableView获取到的contentSize不正确
  • 原文地址:https://blog.csdn.net/JamesSwifte/article/details/136403433