• mlr3实现二分类资料多个模型评价和比较


    本文首发于公众号:医学和生信笔记

    医学和生信笔记,专注R语言在临床医学中的使用,R语言数据分析和可视化。主要分享R语言做医学统计学、meta分析、网络药理学、临床预测模型、机器学习、生物信息学等。

    前面介绍了使用tidymodels进行二分类资料的模型评价和比较,不知道大家学会了没?

    我之前详细介绍过mlr3这个包,也是目前R语言机器学习领域比较火的R包了,今天说下这么用mlr3进行二分类资料的模型评价和比较。

    本期目录:

    加载R包

    首先还是加载数据和R包,和之前的数据一样的。

    library(mlr3verse)
    ## Loading required package: mlr3
    library(mlr3pipelines)
    library(mlr3filters)
    • 1

    建立任务

    然后是对数据进行划分训练集和测试集,对数据进行预处理,为了和之前的tidymodels进行比较,这里使用的数据和预处理步骤都是和之前一样的。

    # 读取数据
    all_plays <- readRDS("../000files/all_plays.rds")

    # 建立任务
    pbp_task <- as_task_classif(all_plays, target="play_type")

    # 数据划分
    split_task <- partition(pbp_task, ratio=0.75)

    task_train <- pbp_task$clone()$filter(split_task$train)
    task_test <- pbp_task$clone()$filter(split_task$test)
    • 1

    数据预处理

    建立任务后就是建立数据预处理步骤,这里采用和上篇推文tidymodels中一样的预处理步骤:

    # 数据预处理
    pbp_prep <- po("select"# 去掉3列
                   selector = selector_invert(
                     selector_name(c("half_seconds_remaining","yards_gained","game_id")))
                   ) %>>%
      po("colapply"# 把这两列变成因子类型
         affect_columns = selector_name(c("posteam","defteam")),
         applicator = as.factor) %>>% 
      po("filter"# 去除高度相关的列
         filter = mlr3filters::flt("find_correlation"), filter.cutoff=0.3) %>>%
      po("scale", scale = F) %>>% # 中心化
      po("removeconstants"# 去掉零方差变量
    • 1

    可以看到mlr3的数据预处理与tidymodels相比,在语法上确实是有些复杂了,而且由于使用的R6,很多语法看起来很别扭,文档也说的不清楚,对于新手来说还是tidymodels更好些。目前来说最大的优势可能就是速度了吧。。。

    如果你想把预处理步骤应用于数据,得到预处理之后的数据,可以用以下代码:

    task_prep <- pbp_prep$clone()$train(pbp_task)[[1]]
    dim(task_train$data())
    ##  68982    26

    task_prep$feature_types
    ##                             id    type
    ##  1:                    defteam  factor
    ##  2:              defteam_score numeric
    ##  3: defteam_timeouts_remaining  factor
    ##  4:                       down ordered
    ##  5:                 goal_to_go  factor
    ##  6:                in_fg_range  factor
    ##  7:                in_red_zone  factor
    ##  8:                  no_huddle  factor
    ##  9:                    posteam  factor
    ## 10:              posteam_score numeric
    ## 11: posteam_timeouts_remaining  factor
    ## 12:              previous_play  factor
    ## 13:                        qtr ordered
    ## 14:         score_differential numeric
    ## 15:                    shotgun  factor
    ## 16:                 total_pass numeric
    ## 17:              two_min_drill  factor
    ## 18:               yardline_100 numeric
    ## 19:                    ydstogo numeric
    • 1

    这样就得到了处理好的数据,但是对于mlr3pipelines来说,这一步做不做都可以。

    选择多个模型

    还是选择和之前一样的4个模型:逻辑回归、随机森林、决策树、k最近邻:

    # 随机森林
    rf_glr <- as_learner(pbp_prep %>>% lrn("classif.ranger", predict_type="prob")) 
    rf_glr$id <- "randomForest"
    
    # 逻辑回归
    log_glr <-as_learner(pbp_prep %>>% lrn("classif.log_reg", predict_type="prob")) 
    log_glr$id <- "logistic"
    
    # 决策树
    tree_glr <- as_learner(pbp_prep %>>% lrn("classif.rpart", predict_type="prob")) 
    tree_glr$id <- "decisionTree"
    
    # k近邻
    kknn_glr <- as_learner(pbp_prep %>>% lrn("classif.kknn", predict_type="prob")) 
    kknn_glr$id <- "kknn"
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    建立benchmark_grid

    类似于tidymodels中的workflow_set

    接下来就是选择10折交叉验证,建立多个模型,语法也是很简单了。

    set.seed(0520)

    # 10折交叉验证
    cv <- rsmp("cv",folds=10)

    set.seed(0520)

    # 建立多个模型
    design <- benchmark_grid(
      tasks = task_train,
      learners = list(rf_glr,log_glr,tree_glr,kknn_glr),
      resampling = cv
    )
    • 1

    在训练集中,使用10折交叉验证,运行4个模型,看这语法是不是也很简单清晰?

    开始计算

    下面就是开始计算,和tidymodels相比,这一块语法更加简单一点,就是建立benchmark_grid,然后使用benchmark()函数即可。

    # 加速
    library(future)
    plan("multisession",workers=12)

    # 减少屏幕输出
    lgr::get_logger("mlr3")$set_threshold("warn")
    lgr::get_logger("bbotk")$set_threshold("warn")

    # 开始运行
    bmr <- benchmark(design,store_models = T)

    Growing trees.. Progress: 29%. Estimated remaining time: 1 minute, 14 seconds.
    Growing trees.. Progress: 61%. Estimated remaining time: 39 seconds.
    Growing trees.. Progress: 92%. Estimated remaining time: 8 seconds.
    Growing trees.. Progress: 29%. Estimated remaining time: 1 minute, 16 seconds.
    Growing trees.. Progress: 60%. Estimated remaining time: 40 seconds.
    Growing trees.. Progress: 91%. Estimated remaining time: 8 seconds.
    Growing trees.. Progress: 43%. Estimated remaining time: 40 seconds.
    Growing trees.. Progress: 83%. Estimated remaining time: 12 seconds.
    Growing trees.. Progress: 42%. Estimated remaining time: 42 seconds.
    Growing trees.. Progress: 90%. Estimated remaining time: 7 seconds.
    Growing trees.. Progress: 30%. Estimated remaining time: 1 minute, 10 seconds.
    Growing trees.. Progress: 62%. Estimated remaining time: 38 seconds.
    Growing trees.. Progress: 93%. Estimated remaining time: 7 seconds.
    Growing trees.. Progress: 30%. Estimated remaining time: 1 minute, 10 seconds.
    Growing trees.. Progress: 61%. Estimated remaining time: 38 seconds.
    Growing trees.. Progress: 92%. Estimated remaining time: 7 seconds.
    Growing trees.. Progress: 29%. Estimated remaining time: 1 minute, 15 seconds.
    Growing trees.. Progress: 60%. Estimated remaining time: 41 seconds.
    Growing trees.. Progress: 91%. Estimated remaining time: 9 seconds.
    Growing trees.. Progress: 32%. Estimated remaining time: 1 minute, 7 seconds.
    Growing trees.. Progress: 73%. Estimated remaining time: 22 seconds.
    Growing trees.. Progress: 42%. Estimated remaining time: 42 seconds.
    Growing trees.. Progress: 84%. Estimated remaining time: 11 seconds.
    Growing trees.. Progress: 32%. Estimated remaining time: 1 minute, 7 seconds.
    Growing trees.. Progress: 63%. Estimated remaining time: 36 seconds.
    Growing trees.. Progress: 94%. Estimated remaining time: 6 seconds.

    # 结果
    bmr

     of 40 rows with 4 resampling runs
     nr   task_id   learner_id resampling_id iters warnings errors
      1 all_plays randomForest            cv    10        0      0
      2 all_plays     logistic            cv    10        0      0
      3 all_plays decisionTree            cv    10        0      0
      4 all_plays         kknn            cv    10        0      0
    • 1

    查看模型表现

    查看结果:

    # 默认结果
    bmr$aggregate()

    nr      resample_result   task_id   learner_id resampling_id iters classif.ce
    1:  1 22]> all_plays randomForest            cv    10  0.2695630
    2:  2 22]> all_plays     logistic            cv    10  0.2770287
    3:  3 22]> all_plays decisionTree            cv    10  0.2799570
    4:  4 22]> all_plays         kknn            cv    10  0.3220549
    • 1

    也是支持同时查看多个结果的:

    measures <- msrs(c("classif.auc","classif.acc","classif.bbrier"))

    bmr_res <- bmr$aggregate(measures)
    bmr_res[,c(4,7:9)]

       learner_id classif.auc classif.acc classif.bbrier
    1: randomForest   0.7978436   0.7304370      0.1790968
    2:     logistic   0.7798504   0.7229713      0.1866577
    3: decisionTree   0.7034790   0.7200430      0.2003303
    4:         kknn   0.7322762   0.6779451      0.2210171
    • 1

    结果可视化

    支持ggplot2语法,使用起来和tidymodels差不多,也是对结果直接autoplot()即可。

    library(ggplot2)
    autoplot(bmr)+theme(axis.text.x = element_text(angle = 45))
    • 1
    alt

    喜闻乐见的ROC曲线:

    autoplot(bmr,type = "roc")
    • 1
    alt

    选择最好的模型

    通过比较结果可以发现还是随机森林效果最好~,下面选择随机森林,在训练集上训练,在测试集上测试结果。

    这一步并没有使用10折交叉验证,如果你想用,也是可以的~

    # 训练
    rf_glr$train(task_train)
    
    • 1
    • 2
    • 3

    训练好之后就是在测试集上测试并查看结果:

    # 测试
    prediction <- rf_glr$predict(task_test)
    head(as.data.table(prediction))
    
    row_ids truth response prob.pass   prob.run
    1:       4   run     pass 0.7649998 0.23500021
    2:       6   run      run 0.4168520 0.58314804
    3:      11  pass     pass 0.7199717 0.28002834
    4:      13   run     pass 0.9406333 0.05936668
    5:      17   run      run 0.4073665 0.59263354
    6:      24  pass     pass 0.6243693 0.37563072
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    混淆矩阵:

    prediction$confusion

            truth
    response  pass   run
        pass 10629  3175
        run   2955  6235
    • 1

    可视化混淆矩阵:

    autoplot(prediction)
    • 1
    alt

    当然也是支持多个指标的:

    prediction$score(msrs(c("classif.auc","classif.acc","classif.bbrier")))

    classif.auc    classif.acc classif.bbrier 
    0.8011720      0.7334087      0.1775684 
    • 1

    喜闻乐见ROC曲线:

    autoplot(prediction,type = "roc")
    • 1
    image-20220704162604466
    image-20220704162604466

    总体来看mlr3tidymodels相比有优势也有劣势,基本步骤大同小异,除了预处理步骤比较复杂外,其他地方都比较简单~

    初学者还是推荐使用tidymodels,熟悉了可以试一下mlr3,集成化程度更高,目前也更加稳定,tidymodels目前还处于快速开发中,经常出现各种小问题,但是说明文档比较详细。

    mlr3相比之下更稳定一些,速度明显更快!尤其是数据量比较大的时候!但是mlr3的说明文档并不是很详细,只有mlr3 book,而且很多用法并没有介绍!经常得自己琢磨。

    mlr3 book中文翻译版 可以翻看我之前的推文!

    本文首发于公众号:医学和生信笔记

    医学和生信笔记,专注R语言在临床医学中的使用,R语言数据分析和可视化。主要分享R语言做医学统计学、meta分析、网络药理学、临床预测模型、机器学习、生物信息学等。

    本文由 mdnice 多平台发布

  • 相关阅读:
    window系统安装 NodeJS
    什么时候用 C 而不用 C++?
    字符串变形
    全域智慧采摘无人机系统探索
    Open3D 进阶(17)间接平差拟合二维直线
    【信号处理】基于优化算法的 SAR 信号处理(Matlab代码实现)
    性能测试 —— Jmeter分布式测试的注意事项和常见问题
    【微信小程序】image组件的4种缩放模式与9种裁剪模式
    Handler同步屏障学习
    go 指针
  • 原文地址:https://blog.csdn.net/Ayue0616/article/details/126869645