• tidymodels绘制校准曲线


    本文首发于公众号:医学和生信笔记

    医学和生信笔记,专注R语言在临床医学中的使用,R语言数据分析和可视化。主要分享R语言做医学统计学、meta分析、网络药理学、临床预测模型、机器学习、生物信息学等。

    很多人都开始用tidymodels了,但是很多人还没意识到,tidymodels目前还不支持一键绘制校准曲线!相同类型的mlr3也是不支持的,都说在开发中!开发了1年多了,还没开发好!

    大家可以去项目的github相关的issue里面留言,引起开发者重视。。。

    总的来说,在临床预测模型这个领域,目前还是一些分散的R包更好用,尤其是涉及到时间依赖性的生存数据时,tidymodelsmlr3目前还无法满足大家的需求~

    但是很多朋友想要用这俩包画校准曲线曲线,其实还是可以搞一下的,挺简单的,之前介绍过很多次了,校准曲线就是散点图,横坐标是预测概率,纵坐标是实际概率(换过来也行!)。不理解的赶紧看这里:一文搞定临床预测模型评价

    今天先介绍下tidymodels的校准曲线画法,之前也介绍过:使用tidymodels完成多个模型评价和比较

    加载数据和R包

    没有安装的R包的自己安装下~

    suppressPackageStartupMessages(library(tidyverse))
    suppressPackageStartupMessages(library(tidymodels))
    tidymodels_prefer()
    • 1

    由于要做演示用,肯定要一份比较好的数据才能说明问题,今天用的这份数据,结果变量是一个二分类的。

    一共有91976行,26列,其中play_type是结果变量,因子型,其余列都是预测变量。

    all_plays <- read_rds("../000files/all_plays.rds")
    glimpse(all_plays)
    ## Rows: 91,976
    ## Columns: 26
    ## $ game_id                     2017090700, 2017090700, 2017090700, 2017090…
    ## $ posteam                     "NE", "NE", "NE", "NE", "NE", "NE", "NE", "…
    ## $ play_type                   pass, pass, run, run, pass, run, pass, pass…
    ## $ yards_gained                0, 8, 8, 3, 19, 5, 16, 0, 2, 7, 0, 3, 10, 0…
    ## $ ydstogo                     10, 10, 2, 10, 7, 10, 5, 2, 2, 10, 10, 10, …
    ## $ down                        1, 2, 3, 1, 2, 1, 2, 1, 2, 1, 1, 2, 3, 1, 2…
    ## $ game_seconds_remaining      3595, 3589, 3554, 3532, 3506, 3482, 3455, 3…
    ## $ yardline_100                73, 73, 65, 57, 54, 35, 30, 2, 2, 75, 32, 3…
    ## $ qtr                         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
    ## $ posteam_score               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 7, 7, 7, 7…
    ## $ defteam                     "KC", "KC", "KC", "KC", "KC", "KC", "KC", "…
    ## $ defteam_score               0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0…
    ## $ score_differential          0, 0, 0, 0, 0, 0, 0, 0, 0, -7, 7, 7, 7, 7, …
    ## $ shotgun                     0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0…
    ## $ no_huddle                   0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0…
    ## $ posteam_timeouts_remaining  3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3…
    ## $ defteam_timeouts_remaining  3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3…
    ## $ wp                          0.5060180, 0.4840546, 0.5100098, 0.5529816,…
    ## $ goal_to_go                  0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0…
    ## $ half_seconds_remaining      1795, 1789, 1754, 1732, 1706, 1682, 1655, 1…
    ## $ total_runs                  0, 0, 0, 1, 2, 2, 3, 3, 3, 0, 4, 4, 4, 5, 5…
    ## $ total_pass                  0, 1, 2, 2, 2, 3, 3, 4, 5, 0, 5, 6, 7, 7, 8…
    ## $ previous_play               First play of Drive, pass, pass, run, run, …
    ## $ in_red_zone                 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1…
    ## $ in_fg_range                 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1…
    ## $ two_min_drill               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
    • 1

    数据划分

    把75%的数据用于训练集,剩下的做测试集。

    set.seed(20220520)

    # 数据划分,根据play_type分层
    split_pbp <- initial_split(all_plays, 0.75, strata = play_type)

    train_data <- training(split_pbp) # 训练集
    test_data <- testing(split_pbp) # 测试集
    • 1

    数据预处理

    使用recipe包进行数据预处理,如果你认真学习过caret,那这个包你应该不陌生。

    pbp_rec <- recipe(play_type ~ ., data = train_data)  %>%
      step_rm(half_seconds_remaining,yards_gained, game_id) %>% # 移除这3列
      step_string2factor(posteam, defteam) %>%  # 变为因子类型
      #update_role(yards_gained, game_id, new_role = "ID") %>% 
      # 去掉高度相关的变量
      step_corr(all_numeric(), threshold = 0.7) %>% 
      step_center(all_numeric()) %>%  # 中心化
      step_zv(all_predictors())  # 去掉零方差变量
    • 1

    建立模型

    就以经常用的随机森林进行演示,这里就不演示调参了,因为也不一定比默认参数的结果好......

    选择随机森林,建立workflow

    rf_spec <- rand_forest(mode = "classification") %>% 
      set_engine("ranger",importance = "permutation")
    rf_wflow <- workflow() %>% 
      add_recipe(pbp_rec) %>% 
      add_model(rf_spec)
    • 1

    在训练集建模:

    fit_rf <- rf_wflow %>% 
      fit(train_data)
    • 1

    模型评价

    应用于测试集:

    pred_rf <- test_data %>% select(play_type) %>% 
      bind_cols(predict(fit_rf, test_data, type = "prob")) %>% 
      bind_cols(predict(fit_rf, test_data, type = "class"))
    • 1

    这个pred_rf就是接下来一系列操作的基础,非常重要!!

    head(pred_rf)
    ## # A tibble: 6 × 4
    ##   play_type .pred_pass .pred_run .pred_class
    ##                         
    ## 1 pass           0.312     0.688 run        
    ## 2 pass           0.829     0.171 pass       
    ## 3 pass           0.806     0.194 pass       
    ## 4 pass           0.678     0.322 pass       
    ## 5 run            0.184     0.816 run        
    ## 6 run            0.544     0.456 pass
    • 1

    查看模型表现:

    你知道的又或者不知道的指标基本上都有:

    metricsets <- metric_set(accuracy, mcc, f_meas, j_index)

    pred_rf %>% metricsets(truth = play_type, estimate = .pred_class)
    ## # A tibble: 4 × 3
    ##   .metric  .estimator .estimate
    ##                 
    ## 1 accuracy binary         0.731
    ## 2 mcc      binary         0.441
    ## 3 f_meas   binary         0.774
    ## 4 j_index  binary         0.439
    • 1

    混淆矩阵:

    pred_rf %>% conf_mat(truth = play_type, estimate = .pred_class)
    ##           Truth
    ## Prediction  pass   run
    ##       pass 10622  3226
    ##       run   2962  6185
    • 1

    混淆矩阵图形版:

    pred_rf %>% 
      conf_mat(play_type,.pred_class) %>% 
      autoplot()
    • 1
    plot of chunk unnamed-chunk-11
    plot of chunk unnamed-chunk-11

    大家最喜欢的AUC:

    pred_rf %>% roc_auc(truth = play_type, .pred_pass)
    ## # A tibble: 1 × 3
    ##   .metric .estimator .estimate
    ##                
    ## 1 roc_auc binary         0.799
    • 1

    可视化结果,首先是大家喜闻乐见的ROC曲线:

    pred_rf %>% roc_curve(truth = play_type, .pred_pass) %>% 
      autoplot()
    • 1
    plot of chunk unnamed-chunk-13
    plot of chunk unnamed-chunk-13

    pr曲线:

    pred_rf %>% pr_curve(truth = play_type, .pred_pass) %>% 
      autoplot()
    • 1
    plot of chunk unnamed-chunk-14
    plot of chunk unnamed-chunk-14

    gain_curve:

    pred_rf %>% gain_curve(truth = play_type, .pred_pass) %>% 
      autoplot()
    • 1
    plot of chunk unnamed-chunk-15
    plot of chunk unnamed-chunk-15

    lift_curve:

    pred_rf %>% lift_curve(truth = play_type, .pred_pass) %>% 
      autoplot()
    • 1
    plot of chunk unnamed-chunk-16
    plot of chunk unnamed-chunk-16

    就是没有校准曲线!!

    校准曲线

    下面给大家手动画一个校准曲线

    两种画法,差别不大,主要是分组方法不一样,第2种分组方法是大家常见的哦~

    如果你还不懂为什么我说校准曲线是散点图,建议你先看看一些基础知识:x一文搞定临床预测模型的评价,看了不吃亏。

    calibration_df <- pred_rf %>% 
       mutate(pass = if_else(play_type == "pass"10),
              pred_rnd = round(.pred_pass, 2)
              ) %>% 
      group_by(pred_rnd) %>% 
      summarize(mean_pred = mean(.pred_pass),
                mean_obs = mean(pass),
                n = n()
                )

    ggplot(calibration_df, aes(mean_pred, mean_obs))+ 
      geom_point(aes(size = n), alpha = 0.5)+
      geom_abline(linetype = "dashed")+
      theme_minimal()
    • 1
    plot of chunk unnamed-chunk-17
    plot of chunk unnamed-chunk-17

    第2种方法:

    cali_df <- pred_rf %>% 
      arrange(.pred_pass) %>% 
      mutate(pass = if_else(play_type == "pass"10),
             group = c(rep(1:249,each=92), rep(250,87))
             ) %>% 
      group_by(group) %>% 
      summarise(mean_pred = mean(.pred_pass),
                mean_obs = mean(pass)
                )


    cali_plot <- ggplot(cali_df, aes(mean_pred, mean_obs))+ 
      geom_point(alpha = 0.5)+
      geom_abline(linetype = "dashed")+
      theme_minimal()

    cali_plot
    • 1
    plot of chunk unnamed-chunk-18
    plot of chunk unnamed-chunk-18

    两种方法差别不大,效果都是很好的,这就说明,好就是好,不管你用什么方法,都是好!如果你的数据很烂,那大概率你的结果也是很烂!不管用什么方法都是烂!

    最后,随机森林这种方法是可以计算变量重要性的,当然也是能把结果可视化的。

    顺手给大家演示下如何可视化随机森林结果的变量重要性:

    library(vip)

    fit_rf %>% 
      extract_fit_parsnip() %>% 
      vip(num_features = 10)
    • 1
    plot of chunk unnamed-chunk-19
    plot of chunk unnamed-chunk-19

    所以,校准曲线的画法,你学会了吗?

    有问题欢迎评论区留言!

    加群即可免费获得示例数据!

    本文首发于公众号:医学和生信笔记

    医学和生信笔记,专注R语言在临床医学中的使用,R语言数据分析和可视化。主要分享R语言做医学统计学、meta分析、网络药理学、临床预测模型、机器学习、生物信息学等。

    本文由 mdnice 多平台发布

  • 相关阅读:
    专升本三本计科仔学习java到实习之路4
    视频号的链接在哪,视频号视频链接地址获取办法!
    把二叉搜索树转换为累加树
    Linux0.11——操作系统怎么把自己从硬盘搬到内存
    VSLAM视觉里程计总结
    关于Reactor模型,我们需要知道哪些内容
    【接口幂等性】使用token,Redis保证接口幂等性
    WebShell 木马免杀过WAF
    Redis-Redis持久化,主从哨兵架构详解
    编译openMVG出现的错误的解决
  • 原文地址:https://blog.csdn.net/Ayue0616/article/details/126869686