• SparkMlib 之决策树及其案例


    什么是决策树

    决策树及其集成是分类和回归机器学习任务的流行方法。决策树被广泛使用,因为它们易于解释,处理分类特征,扩展到多类分类设置,不需要特征缩放,并且能够捕获非线性和特征相互作用。随机森林和增强算法等树集成算法在分类和回归任务中表现最佳。

    常应用于以下类型的场景:

    1. 预测用户贷款是否能够按时还款;
    2. 预测邮件是否是垃圾邮件;
    3. 预测用户是否会购买某件商品等等

    官网:分类和回归

    决策树的优缺点

    优点:

    1. 决策树算法易理解,机理解释起来简单。

    2. 决策树算法可以用于小数据集。

    3. 决策树算法的时间复杂度较小,为用于训练决策树的数据点的对数。

    4. 相比于其他算法智能分析一种类型变量,决策树算法可处理数字和数据的类别。

    5. 能够处理多输出的问题。

    6. 对缺失值不敏感。

    7. 可以处理不相关特征数据。

    8. 效率高,决策树只需要一次构建,反复使用,每一次预测的最大计算次数不超过决策树的深度。

    缺点:

    1. 对连续性的字段比较难预测。

    2. 容易出现过拟合。

    3. 当类别太多时,错误可能就会增加的比较快。

    4. 在处理特征关联性比较强的数据时表现得不是太好。

    5. 对于各类别样本数量不一致的数据,在决策树当中,信息增益的结果偏向于那些具有更多数值的特征。

    参考博客:决策树算法优缺点

    决策树示例——鸢尾花分类

    数据集下载:

    链接:
    https://pan.baidu.com/s/1AshgNxx1wOWhLgKxgjrZww?pwd=lz3l 
    
    提取码:
    lz3l
    
    • 1
    • 2
    • 3
    • 4
    • 5

    数据集介绍:

    iris.data 数据集中共有五个字段,逗号分隔,前四个为特征字段,最后一个为标签字段。

    标签字段列一共有三种值,分别是:Iris-setosaIris-versicolorIris-virginica

    将数据集中的随机百分之70作为训练集,剩余的作为测试集。

    需求实现:

    import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
    import org.apache.spark.ml.feature.LabeledPoint
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{DataFrame, Row, SparkSession}
    
    object Iris {
    
        // TODO 鸢尾花种类判断
    
        def main(args: Array[String]): Unit = {
    
            val sc: SparkSession = SparkSession
                    .builder()
                    .appName("Iris")
                    .master("local[*]")
                    .getOrCreate()
    
            // 1.加载鸢尾花数据
            val train_data: RDD[String] = sc
                    .read
                    .textFile("iris.data")
                    .rdd
    
            // 2.将随机百分之70的数据设置为训练集,其余为测试集
            val data: Array[RDD[String]] = train_data.randomSplit(Array(0.7, 0.3))
    
            // 3.向量转换
            import sc.implicits._
    
            val trainDF: DataFrame = data(0).map(lines => {
                val arr: Array[String] = lines.split(",")
                LabeledPoint(
                    if (arr(4).equals("Iris-setosa")) {
                        1D
                    } else if (arr(4).equals("Iris-versicolor")) {
                        2D
                    } else {
                        3D
                    },
                    Vectors.dense(arr.take(4).map(_.toDouble))
                )
            }).toDF("label", "features")
    
            // 4.创建决策树对象
            val classifier = new DecisionTreeClassifier()
    
            // 设置最大深度、分支、质量、特征列
            classifier.setMaxDepth(5).setMaxBins(32).setImpurity("gini").setFeaturesCol("features")
    
            // 5.训练模型
            val model: DecisionTreeClassificationModel = classifier.fit(trainDF)
    
            // 打印模型
            println(model.toDebugString)
    
            // 6.将测试集转换成向量
            val testDF: DataFrame = data(1).map(lines => {
                val arr: Array[String] = lines.split(",")
                LabeledPoint(
                    if (arr(4).equals("Iris-setosa")) {
                        1D
                    } else if (arr(4).equals("Iris-versicolor")) {
                        2D
                    } else {
                        3D
                    },
                    Vectors.dense(arr.take(4).map(_.toDouble))
                )
            }).toDF("label", "features")
    
            // 7.模型预测
            val result: DataFrame = model.transform(testDF.select("label", "features"))
    
            // 8.模型预测评估
            result.select("label", "features","prediction").show(100)
    
            // 9.计算错误率
            val error: Double = result.where("label = prediction").count.toDouble/result.count
            println("错误率为:"+(1-error))
    
        }
    
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
  • 相关阅读:
    2022年最新西藏水利水电施工安全员考试题库及答案
    爆肝Python自学学习路线
    前端JS必用工具【js-tool-big-box】学习,打开全屏和关闭全屏
    C专家编程 第11章 你懂得C,所以C++不再话下 11.2 抽象---取事物的本质特性
    制作游戏拼图游戏
    Android recycleview瀑布流中间穿插一行占满一屏
    vm的生命周期钩子
    AI硬件:显卡 vs. 处理器 vs. 量子计算机
    shiro的配置详解
    Cesium之Web Workers
  • 原文地址:https://blog.csdn.net/weixin_46389691/article/details/128103174