• sklearn决策树(Decision Trees)模型


    决策树(DT)是一种用于分类和回归的非参数化监督学习方法。其目的是创建一个模型,通过学习从数据特征推断出的简单决策规则来预测目标变量的值。一棵树可以被看作是一个分片的常数近似。

    决策树的一些优点:

    • 易于理解和解释。树可以被视觉化。
    • 需要很少的数据准备。其他技术通常需要数据规范化,需要创建虚拟变量并删除空白值。但是请注意,这个模块不支持缺失值。
    • 使用树的成本(即预测数据)与用于训练树的数据点的数量成对数关系。
    • 能够处理数字和分类数据。但是scikit-learn的实现暂时不支持分类变量。其他技术通常专门用于分析只有一种类型变量的数据集。更多信息请参见算法。
    • 能够处理多输出问题。
    • 使用白盒模型。如果一个给定的情况在模型中是可以观察到的,那么对该情况的解释就很容易用布尔逻辑来解释。相比之下,在一个黑箱模型中(如在人工神经网络中),结果可能更难解释。
    • 有可能使用统计测试来验证一个模型。这使得核算模型的可靠性成为可能。
    • 即使其假设在某种程度上违反了数据产生的真实模型,也能表现良好。

    决策树的缺点:

    • 决策树学习者可以创建过于复杂的树,不能很好地概括数据。这就是所谓的过度拟合。诸如修剪、设置叶子节点所需的最小样本数或设置树的最大深度等机制对于避免这一问题是必要的。
    • 决策树可能是不稳定的,因为数据的微小变化可能会导致生成一个完全不同的树。这个问题可以通过在一个集合体中使用决策树而得到缓解。
    • 决策树的预测既不是平滑的,也不是连续的,而是如上图所示的片状常数近似值。因此,它们不善于推断。
    • 众所周知,学习最优决策树的问题在几个方面的最优性下是NP-complete的,甚至对于简单的概念也是如此。因此,实用的决策树学习算法是基于启发式算法,如贪婪算法,在每个节点上做出局部最优决策。这种算法不能保证返回全局最优的决策树。这一点可以通过在集合学习器中训练多棵树来缓解,在集合学习器中,特征和样本都是随机抽样的,并有替换。
    • 有一些概念很难学习,因为决策树不容易表达,比如XOR、奇偶性或多路复用器问题。
    • 如果某些类占主导地位,决策树学习者会创建有偏见的树。因此,建议在用决策树拟合之前,平衡数据集。

    1. 分类问题

    1. ### 1. classification
    2. from sklearn.datasets import load_iris
    3. from sklearn import tree
    4. from sklearn.tree import export_text
    5. from sklearn.model_selection import train_test_split
    6. iris = load_iris()
    7. X, y = iris.data, iris.target
    8. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
    9. clf = tree.DecisionTreeClassifier(random_state=0, max_depth=2)
    10. clf = clf.fit(X_train, y_train)
    11. r = export_text(clf, feature_names=iris['feature_names'])
    12. print(r)
    13. print(clf.predict(X_test))
    14. #print(clf.predict_proba(X_test)) # probability of each class
    15. print(clf.score(X_test,y_test))
    16. # plot
    17. import matplotlib.pyplot as plt
    18. from sklearn.tree import plot_tree
    19. plt.figure()
    20. clf = tree.DecisionTreeClassifier().fit(iris.data, iris.target)
    21. plot_tree(clf, filled=True)
    22. plt.title("Decision tree trained on all the iris features")
    23. plt.show()

    2. 回归问题

    1. ### 2. regression
    2. # Import the necessary modules and libraries
    3. import numpy as np
    4. from sklearn.tree import DecisionTreeRegressor
    5. import matplotlib.pyplot as plt
    6. # Create a random dataset
    7. rng = np.random.RandomState(1)
    8. X = np.sort(5 * rng.rand(80, 1), axis=0)
    9. y = np.sin(X).ravel()
    10. y[::5] += 3 * (0.5 - rng.rand(16))
    11. # Fit regression model
    12. regr_1 = DecisionTreeRegressor(max_depth=2)
    13. regr_2 = DecisionTreeRegressor(max_depth=5)
    14. regr_1.fit(X, y)
    15. regr_2.fit(X, y)
    16. # Predict
    17. X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
    18. y_1 = regr_1.predict(X_test)
    19. y_2 = regr_2.predict(X_test)
    20. # Plot the results
    21. plt.figure()
    22. plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
    23. plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2)
    24. plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
    25. plt.xlabel("data")
    26. plt.ylabel("target")
    27. plt.title("Decision Tree Regression")
    28. plt.legend()
    29. plt.show()

    参考:

    https://scikit-learn.org/stable/modules/tree.html

  • 相关阅读:
    2023山东学生眼部健康展会/中国国际视力防控发展论坛
    作为程序员听过《元宇宙》,那你听过《元编程》吗?
    Leetcode 23.旋转排序数组
    Docker基础语法学习笔记
    双十二购买护眼台灯亮度多少合适?灯光亮度多少对眼睛比较好呢
    河南萌新联赛2024第(四)场:河南理工大学
    10 特征向量与特征值
    【C语言】深度剖析数据在内存中的存储
    Android App开发超实用实例 | ​Broadcast
    将GC编程语言引入WebAssembly的新方法
  • 原文地址:https://blog.csdn.net/qq_27390023/article/details/125877372