• workflow一次完成多个模型评价和比较


    本文首发于公众号:医学和生信笔记

    医学和生信笔记,专注R语言在临床医学中的使用,R语言数据分析和可视化。主要分享R语言做医学统计学、meta分析、网络药理学、临床预测模型、机器学习、生物信息学等。

    前面给大家介绍了使用tidymodels搞定二分类资料的模型评价和比较。

    简介的语法、统一的格式、优雅的操作,让人欲罢不能!

    但是太费事儿了,同样的流程来了4遍,那要是选择10个模型,就得来10遍!无聊,非常的无聊。

    所以个大家介绍简便方法,不用重复写代码,一次搞定多个模型!

    本期目录:

    加载数据和R包

    首先还是加载数据和R包,和前面的一模一样的操作,数据也没变。

    suppressPackageStartupMessages(library(tidyverse))
    suppressPackageStartupMessages(library(tidymodels))
    library(kknn)
    tidymodels_prefer()

    all_plays <- read_rds("../000files/all_plays.rds")

    set.seed(20220520)

    split_pbp <- initial_split(all_plays, 0.75, strata = play_type)

    train_data <- training(split_pbp)
    test_data <- testing(split_pbp)
    • 1

    数据预处理

    pbp_rec <- recipe(play_type ~ ., data = train_data)  %>%
      step_rm(half_seconds_remaining,yards_gained, game_id) %>% 
      step_string2factor(posteam, defteam) %>%  
      step_corr(all_numeric(), threshold = 0.7) %>% 
      step_center(all_numeric()) %>%  
      step_zv(all_predictors())  
    • 1

    选择模型

    直接选择4个模型,你想选几个都是可以的。

    lm_mod <- logistic_reg(mode = "classification",engine = "glm")
    knn_mod <- nearest_neighbor(mode = "classification", engine = "kknn")
    rf_mod <- rand_forest(mode = "classification", engine = "ranger")
    tree_mod <- decision_tree(mode = "classification",engine = "rpart")
    • 1

    选择重抽样方法

    set.seed(20220520)

    folds <- vfold_cv(train_data, v = 10)
    folds
    ## #  10-fold cross-validation 
    ## # A tibble: 10 × 2
    ##    splits               id    
    ##                    
    ##  1  Fold01
    ##  2  Fold02
    ##  3  Fold03
    ##  4  Fold04
    ##  5  Fold05
    ##  6  Fold06
    ##  7  Fold07
    ##  8  Fold08
    ##  9  Fold09
    ## 10  Fold10
    • 1

    构建workflow

    这一步就是不用重复写代码的关键,把所有模型和数据预处理步骤自动连接起来。

    library(workflowsets)

    four_mods <- workflow_set(list(rec = pbp_rec), 
                              list(lm = lm_mod,
                                   knn = knn_mod,
                                   rf = rf_mod,
                                   tree = tree_mod
                                   ),
                              cross = T
                              )
    four_mods
    ## # A workflow set/tibble: 4 × 4
    ##   wflow_id info             option    result    
    ##                          
    ## 1 rec_lm     
    ## 2 rec_knn    
    ## 3 rec_rf     
    ## 4 rec_tree   
    • 1

    运行模型

    首先是一些运行过程中的参数设置:

    keep_pred <- control_resamples(save_pred = T, verbose = T)
    • 1

    然后就是运行4个模型(目前一直是在训练集中),我们给它加速一下:

    library(doParallel) 
    ## Loading required package: foreach
    ## 
    ## Attaching package: 'foreach'
    ## The following objects are masked from 'package:purrr':
    ## 
    ##     accumulate, when
    ## Loading required package: iterators
    ## Loading required package: parallel

    cl <- makePSOCKcluster(12# 加速,用12个线程
    registerDoParallel(cl)

    four_fits <- four_mods %>% 
      workflow_map("fit_resamples",
                   seed = 0520,
                   verbose = T,
                   resamples = folds,
                   control = keep_pred
                   )
    ## i 1 of 4 resampling: rec_lm
    ## ✔ 1 of 4 resampling: rec_lm (18.4s)
    ## i 2 of 4 resampling: rec_knn
    ## ✔ 2 of 4 resampling: rec_knn (3m 51.9s)
    ## i 3 of 4 resampling: rec_rf
    ## ✔ 3 of 4 resampling: rec_rf (1m 15.6s)
    ## i 4 of 4 resampling: rec_tree
    ## ✔ 4 of 4 resampling: rec_tree (6.1s)

    four_fits
    ## # A workflow set/tibble: 4 × 4
    ##   wflow_id info             option    result   
    ##                         
    ## 1 rec_lm     
    ## 2 rec_knn    
    ## 3 rec_rf     
    ## 4 rec_tree   

    stopCluster(cl)
    • 1

    需要很长时间!大家笔记本如果内存不够可能会失败哦~

    查看结果

    查看模型在训练集中的表现:

    collect_metrics(four_fits)
    ## # A tibble: 8 × 9
    ##   wflow_id .config          preproc model .metric .estimator  mean     n std_err
    ##                                    
    ## 1 rec_lm   Preprocessor1_M… recipe  logi… accura… binary     0.724    10 1.91e-3
    ## 2 rec_lm   Preprocessor1_M… recipe  logi… roc_auc binary     0.781    10 1.88e-3
    ## 3 rec_knn  Preprocessor1_M… recipe  near… accura… binary     0.671    10 7.31e-4
    ## 4 rec_knn  Preprocessor1_M… recipe  near… roc_auc binary     0.716    10 1.28e-3
    ## 5 rec_rf   Preprocessor1_M… recipe  rand… accura… binary     0.732    10 1.48e-3
    ## 6 rec_rf   Preprocessor1_M… recipe  rand… roc_auc binary     0.799    10 1.90e-3
    ## 7 rec_tree Preprocessor1_M… recipe  deci… accura… binary     0.720    10 1.97e-3
    ## 8 rec_tree Preprocessor1_M… recipe  deci… roc_auc binary     0.704    10 2.01e-3
    • 1

    查看每一个预测结果,这个就不运行了,毕竟好几万行,太多了。。。

    collect_predictions(four_fits)
    • 1

    可视化结果

    直接可视化4个模型的结果,感觉比ROC曲线更好看,还给出了可信区间。

    这个图可以自己用ggplot2语法修改。

    four_fits %>% autoplot(metric = "roc_auc")+theme_bw()
    • 1
    image-20220704145235120
    image-20220704145235120

    选择最好的模型用于测试集

    选择表现最好的应用于测试集:

    rand_res <- last_fit(rf_mod,pbp_rec,split_pbp)
    • 1

    查看在测试集的模型表现:

    collect_metrics(rand_res) # test 中的模型表现
    • 1
    image-20220704144956748
    image-20220704144956748

    使用其他指标查看模型表现:

    metricsets <- metric_set(accuracy, mcc, f_meas, j_index)

    collect_predictions(rand_res) %>% 
      metricsets(truth = play_type, estimate = .pred_class)
    • 1
    image-20220704145017664
    image-20220704145017664

    可视化结果,喜闻乐见的混淆矩阵:

    collect_predictions(rand_res) %>% 
      conf_mat(play_type,.pred_class) %>% 
      autoplot()
    • 1
    image-20220704145028522
    image-20220704145028522

    喜闻乐见的ROC曲线:

    collect_predictions(rand_res) %>% 
      roc_curve(play_type,.pred_pass) %>% 
      autoplot()
    • 1
    image-20220704145041578
    image-20220704145041578

    还有非常多曲线和评价指标可选,大家可以看我之前的介绍推文~

    是不是很神奇呢,完美符合一次挑选多个模型的要求,且步骤清稀,代码美观,非常适合进行多个模型的比较。

    本文首发于公众号:医学和生信笔记

    医学和生信笔记,专注R语言在临床医学中的使用,R语言数据分析和可视化。主要分享R语言做医学统计学、meta分析、网络药理学、临床预测模型、机器学习、生物信息学等。

    本文由 mdnice 多平台发布

  • 相关阅读:
    ASP.NET Core3.1 API 创建(Swagger配置、数据库连接Sql Server)、开发、部署
    img标签如何将<svg></svg>数据渲染出来
    获取iOS和Android的app下载渠道和相关参数的方式
    Qgis根据区域划分点、线面
    什么是 DeGods NFT 系列?
    建模助手:Revit中捕捉点设置问题和楼层排序设置
    二维数组与稀疏数组转换(java)
    相机图像质量研究(32)常见问题总结:图像处理对成像的影响--振铃效应
    Android sdk工程搭建(aar)
    【JAVA并发】AQS原理详解
  • 原文地址:https://blog.csdn.net/Ayue0616/article/details/126869606