• tidymodels用于机器学习的一些使用细节


    R语言做机器学习的当红辣子鸡R包:mlr3tidymodels,之前用十几篇推文详细介绍过mlr3

    今天学习下tidymodels的使用,其实之前在介绍临床预测模型时已经用过这个包了:

    但是对于很多没接触过这个包的朋友来说有些地方还是不好理解,所以今天专门写一篇推文介绍下tidymodels的一些使用细节,帮助大家更上一层楼。

    不得不说,比mlr3简单多了!

    设计理念

    tidymodels是max kuhn加入rstudio之后和Julia silge等人共同开发的机器学习R包,类似于mlr3caret,也是一个整合包,只提供统一的API,让大家可以通过统一的语法调用R语言里各种现成的机器学习算法R包,并不发明新的算法。

    这样做对用户来说最大的好处是不用记那么多R包的用法了,只需要记住tidymodels一个包的用法及参数就够了。同时得益于tidyverse系列的加持,在tidymodels中进行的各种操作以及产生的各种结果都是遵循tidy系列的设计理念的。所以非常有规律,很容易记住!

    tidymodels类似于tidyverse,是一系列R包的合集,其中主要的包括:

    • parsnip:提供统一的语法来选择模型(算法)
    • recipes:数据预处理
    • rsample:重抽样
    • dials:设置超参数
    • tune:调整超参数
    • yardstick:评价模型
    • broom:可以把各种模型的结果以整洁tibble格式返回,支持R语言所有内置模型!还有大部分第三方R包的模型!
    • infer:统计推断
    • workflows:联合数据预处理和算法

    除此之外,还包括ggplot2/purrr/dplyr/tibble等R包。

    真正在用的时候并不需要刻意记住,只需加载tidymodels就可得到全部~

    因为和tidyverse系列是一脉相承的,所以也是支持管道符的,这样的操作看起来非常的流程,也比较容易理解,对于初学者来说比mlr3那种面向对象的编程,简单多了。

    但是一个很大的问题是速度,因为底层也是基于tibble,所以速度没那么快,尤其是在调参的时候,非常慢,运算量一大就得好久时间才能出结果!

    安装

    目前发展还是很快,经常变更版本,所以时不时会遇到一些小问题,但总体来说瑕不掩瑜,学了不吃亏。

    # 2选1
    install.packages("tidymodels")

    library("devtools")
    install_github("tidymodels/tidymodels")
    • 1

    基本使用

    基本使用步骤和大家像想象中的差不多:

    • 选择算法(模型)
    • 数据预处理
    • 训练集建模
    • 测试集看效果

    在建模的过程中可能会同时出现重抽样、超参数调整等步骤,但基本步骤就是这样的。

    library(tidyverse)
    ## ── Attaching packages ───────────────────────────── tidyverse 1.3.2 ──
    ## ✔ ggplot2 3.3.6     ✔ purrr   0.3.4
    ## ✔ tibble  3.1.7     ✔ dplyr   1.0.9
    ## ✔ tidyr   1.2.0     ✔ stringr 1.4.0
    ## ✔ readr   2.1.2     ✔ forcats 0.5.1
    ## ── Conflicts ──────────────────────────────── tidyverse_conflicts() ──
    ## ✖ dplyr::filter() masks stats::filter()
    ## ✖ dplyr::lag()    masks stats::lag()
    library(tidymodels)
    ## ── Attaching packages ──────────────────────────── tidymodels 1.0.0 ──
    ## ✔ broom        1.0.0     ✔ rsample      1.0.0
    ## ✔ dials        1.0.0     ✔ tune         1.0.0
    ## ✔ infer        1.0.2     ✔ workflows    1.0.0
    ## ✔ modeldata    1.0.0     ✔ workflowsets 1.0.0
    ## ✔ parsnip      1.0.0     ✔ yardstick    1.0.0
    ## ✔ recipes      1.0.1     
    ## ── Conflicts ─────────────────────────────── tidymodels_conflicts() ──
    ## ✖ scales::discard() masks purrr::discard()
    ## ✖ dplyr::filter()   masks stats::filter()
    ## ✖ recipes::fixed()  masks stringr::fixed()
    ## ✖ dplyr::lag()      masks stats::lag()
    ## ✖ yardstick::spec() masks readr::spec()
    ## ✖ recipes::step()   masks stats::step()
    ## • Use tidymodels_prefer() to resolve common conflicts.
    tidymodels_prefer() # 防止函数冲突
    • 1

    探索数据

    我们用一个结果变量是二分类变量的数据集来做一个简单的演示。这个数据集是关于大人住旅馆会不会带孩子一起。。。

    rm(list = ls())
    load(file = "../datasets/hotels_df.rdata")
    • 1

    简单查看一下数据:

    hotels_df |> glimpse()
    ## Rows: 75,166
    ## Columns: 10
    ## $ children                     none, none, none, none, none, none, none, …
    ## $ hotel                        Resort Hotel, Resort Hotel, Resort Hotel, …
    ## $ arrival_date_month           July, July, July, July, July, July, July, …
    ## $ meal                         BB, BB, BB, BB, BB, BB, BB, FB, HB, BB, HB…
    ## $ adr                          0.00, 0.00, 75.00, 75.00, 98.00, 98.00, 10…
    ## $ adults                       2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, …
    ## $ required_car_parking_spaces  none, none, none, none, none, none, none, …
    ## $ total_of_special_requests    0, 0, 0, 0, 1, 1, 0, 1, 0, 3, 1, 0, 3, 0, …
    ## $ stays_in_week_nights         0, 0, 1, 1, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, …
    ## $ stays_in_weekend_nights      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
    • 1

    这个数据一共10列,75166行,其中children这一列是结果变量,是二分类的,其余9列都是预测变量。

    我们的目的是用9列预测变量预测结果变量(感觉好绕啊)。。

    hotels_df |> count(children)
    ## # A tibble: 2 × 2
    ##   children     n
    ##       
    ## 1 children  6073
    ## 2 none     69093
    • 1

    可以看到结果变量两种分类很不均衡,差了10倍多!

    模型选择

    模型选择的部分需要大家记住tidymodels里面的一些名字,例如,对于决策树就是decision_tree(),大家可以去这个网址[1]查看所有支持的模型以及它们在tidymodels中的名字。

    模型选择在这里就是3步走:

    • 选择模型
    • 使用哪个R包
    • 回归还是分类(还有其他的自己看)
    tree_spec <- decision_tree() |> 
      set_engine("rpart") |> 
      set_mode("classification")
    • 1

    当然你如果没有其他需求,可以这样写:

    tree_spec <- decision_tree(engine = "rpart",mode = "classification")
    • 1

    效果一模一样!

    大家都知道很多算法都是有超参数的,R里面有很多R包都可以实现同一种算法,但是支持的超参数却不一样!

    所以,对于一些R包都有的超参数,大家可以把超参数写在选择模型这一步,对于一些R包特有的超参数(算法本身有但是其他包不支持)就要写在set_engine()这里面

    就像下面这样:

    rf_spec <- rand_forest(trees = 1000, min_n = 5) |> # 这两个参数大家都有
      set_engine("ranger", verbose = TRUE) |> # verbose参数只有ranger有,其他做随机森林的R包没有
      set_mode("classification")
    • 1

    数据划分

    tidymodels中数据划分非常简单。

    set.seed(12# 划分数据是随机的,设置种子数方便复现

    hotel_split <- hotels_df |>  
      initial_split(prop = 0.7# 需要根据某个变量分层只要加 strata = xxx即可

    hotel_train <- training(hotel_split) # 训练集
    hotel_test <- testing(hotel_split) # 测试集
    • 1

    这个是最常用的划分方法,还有很多,包括时间序列的划分等,大家可以自行学习。

    划分好的数据长这样:

    hotel_train |> glimpse()
    ## Rows: 52,616
    ## Columns: 10
    ## $ children                     none, none, none, none, none, none, none, …
    ## $ hotel                        Resort Hotel, City Hotel, City Hotel, City…
    ## $ arrival_date_month           December, March, April, July, August, June…
    ## $ meal                         BB, HB, SC, BB, SC, SC, HB, BB, BB, HB, BB…
    ## $ adr                          43.57, 53.50, 0.00, 139.51, 94.50, 72.25, …
    ## $ adults                       1, 2, 0, 2, 2, 2, 1, 2, 1, 2, 1, 2, 2, 1, …
    ## $ required_car_parking_spaces  none, none, none, none, none, none, none, …
    ## $ total_of_special_requests    2, 1, 1, 1, 1, 0, 0, 1, 2, 0, 1, 0, 0, 1, …
    ## $ stays_in_week_nights         3, 1, 3, 3, 2, 3, 5, 5, 0, 1, 2, 3, 2, 1, …
    ## $ stays_in_weekend_nights      0, 2, 2, 0, 0, 1, 2, 0, 2, 0, 0, 0, 0, 0, …
    • 1

    数据预处理

    为了让结果更准确,所以我们需要一些数据预处理步骤。

    首先就是这个结果变量的类不平衡,我们可以用downsample的方式解决,然后对于预测变量,我们需要对分类变量做哑变量处理,去除近零方差变量,还要对数值型变量标准化!

    这个事情在tidymodels中是这样操作的:

    hotel_rec <- recipe(children ~ ., data = hotel_train) |> 
      themis::step_downsample(children) |>
      step_dummy(all_nominal(), -all_outcomes()) |>
      step_zv(all_numeric()) |>
      step_normalize(all_numeric()) |>
      prep() # 最后一定别忘记这个
    • 1

    看起来非常舒服,简单易懂,一步一步的下来即可,并且采用了tidyselect的做法,支持all_nominal()这种选择语法,非常方便的选择想要执行操作的列。

    数据预处理之后,其实你不用把处理过的数据单独拿出来,就像之前介绍过的mlr3一样,可以直接进行到下一步训练模型,但是考虑到有些人就是要看到数据,你可以这样操作:

    # 提取处理好的训练集和测试集
    train_proc <- bake(hotel_rec, new_data = NULL# 训练集
    test_proc <- bake(hotel_rec, new_data = hotel_test) # 测试集

    train_proc |> glimpse()
    ## Rows: 8,486
    ## Columns: 23
    ## $ adr                                  -0.27844916, 1.95719142, 0.4760089…
    ## $ adults                               0.2314058, 0.2314058, 0.2314058, 0…
    ## $ total_of_special_requests            0.108877, 0.108877, 0.108877, -0.9…
    ## $ stays_in_week_nights                 1.3026010, -0.8268292, -1.3591868,…
    ## $ stays_in_weekend_nights              -0.99994313, -0.99994313, 1.038374…
    ## $ children                             children, children, children, chil…
    ## $ hotel_Resort.Hotel                   -0.830067, -0.830067, -0.830067, -…
    ## $ arrival_date_month_August            -0.4522368, -0.4522368, -0.4522368…
    ## $ arrival_date_month_December          -0.2502043, -0.2502043, -0.2502043…
    ## $ arrival_date_month_February          -0.2735849, -0.2735849, -0.2735849…
    ## $ arrival_date_month_January           -0.2244361, -0.2244361, -0.2244361…
    ## $ arrival_date_month_July              -0.4044283, 2.4723349, 2.4723349, …
    ## $ arrival_date_month_June              -0.2853468, -0.2853468, -0.2853468…
    ## $ arrival_date_month_March             -0.2785286, -0.2785286, -0.2785286…
    ## $ arrival_date_month_May               -0.2951308, -0.2951308, -0.2951308…
    ## $ arrival_date_month_November          -0.222693, -0.222693, -0.222693, -…
    ## $ arrival_date_month_October           -0.2948949, -0.2948949, -0.2948949…
    ## $ arrival_date_month_September         -0.2760646, -0.2760646, -0.2760646…
    ## $ meal_FB                              -0.08920343, -0.08920343, -0.08920…
    ## $ meal_HB                              -0.4054139, -0.4054139, -0.4054139…
    ## $ meal_SC                              -0.2440307, -0.2440307, -0.2440307…
    ## $ meal_Undefined                       -0.09998221, -0.09998221, -0.09998…
    ## $ required_car_parking_spaces_parking  -0.4002771, 2.4979750, -0.4002771,…
    • 1

    建立workflow

    这一步并不是必须要,建议对于有数据预处理步骤的,用workflow,如果没有数据预处理步骤,不用这一步更简单!

    tidymodels中增加了一个workflow函数,可以把模型选择和数据预处理这两部连接起来,形成一个对象,这个类似于mlr3的pipeline,但是只做这一件事!

    tree_wf <- workflow() |> 
      add_recipe(hotel_rec) |> 
      add_model(tree_spec) 
    • 1

    这里有多种方式构造workflow,但是一定要记住,add_model(xxxx)这一步是必须的!

    初次用这个的时候碰到很多问题,后来才发现,顺序、formula、variable等都是随便加就行,唯独add_model(xxxx)这一步必不可少!

    如果你熟练以后也可以这样写:

    tree_wf <- workflow(preprocessor = hotel_rec,
                        spec = tree_spec
                        )
    • 1

    这个workflow对象里面很多东西都是可以通过extract_xxx()提取的,但其实没啥用,一般情况下我们都知道自己前面干了什么。。

    tree_wf |> 
      extract_preprocessor()
    ## Recipe
    ## 
    ## Inputs:
    ## 
    ##       role #variables
    ##    outcome          1
    ##  predictor          9
    ## 
    ## Training data contained 52616 data points and no missing data.
    ## 
    ## Operations:
    ## 
    ## Down-sampling based on children [trained]
    ## Dummy variables from hotel, arrival_date_month, meal, required_car_parking_spaces [trained]
    ## Zero variance filter removed  [trained]
    ## Centering and scaling for adr, adults, total_of_special_requests, stays_i... [trained]

    tree_wf |> 
      extract_spec_parsnip()
    ## Decision Tree Model Specification (classification)
    ## 
    ## Computational engine: rpart
    • 1

    选择重抽样方法

    也是支持非常多的方法,常见的交叉验证,重复交叉验证,留一法,bootstrap,蒙特卡洛等,都是支持的。

    所有支持的重抽样方法可以在这里[2]查看。

    我们就选择一个简单的,10折交叉验证:

    set.seed(123)

    cv <- vfold_cv(hotel_train, v = 10)
    • 1

    训练模型(无重抽样)

    如果没有任何重抽样方法,那就非常简单了,直接fit(),然后再predict()就行了。

    给大家演示下:

    ## 建模
    tree_fit <- tree_wf |> 
      fit(hotel_train)

    # 测试集预测
    tree_pred <- predict(tree_fit, hotel_test)

    # 查看结果
    head(tree_pred)
    ## # A tibble: 6 × 1
    ##   .pred_class
    ##         
    ## 1 none       
    ## 2 none       
    ## 3 children   
    ## 4 none       
    ## 5 none       
    ## 6 none
    • 1

    如果是崭新的、没有结果变量的数据集,也是可以通过这种方式预测的:

    # 构造一个没有结果变量的数据集
    tmp <- hotel_test |> 
      select(-children) |> 
      slice_sample(n=5)

    glimpse(tmp)
    ## Rows: 5
    ## Columns: 9
    ## $ hotel                        City Hotel, Resort Hotel, City Hotel, Reso…
    ## $ arrival_date_month           May, October, February, June, May
    ## $ meal                         BB, BB, BB, HB, BB
    ## $ adr                          130.0, 46.5, 88.4, 88.7, 132.6
    ## $ adults                       1, 2, 2, 2, 1
    ## $ required_car_parking_spaces  none, none, none, none, none
    ## $ total_of_special_requests    0, 0, 0, 0, 0
    ## $ stays_in_week_nights         2, 0, 4, 8, 4
    ## $ stays_in_weekend_nights      0, 1, 0, 2, 2
    • 1

    预测结果只需要添加new_data = tmp即可:

    predict(tree_fit, new_data = tmp)
    ## # A tibble: 5 × 1
    ##   .pred_class
    ##         
    ## 1 children   
    ## 2 none       
    ## 3 none       
    ## 4 none       
    ## 5 children
    • 1

    得益于tidy系列的理念,这个predict()函数进行了很多优化。比如:

    现在很多R包的predict()用到的参数是不一样的: 各种predict

    所以用起来就很烦,经常不知道写什么,tidymodels也进行了统一,对于二分类变量来说,就是两个选项:

    • type = "prob"算概率
    • type = "class"算类别

    预测的结果也是有规律的:

    • 如果是数值型变量,那预测结果列名必定是 .pred
    • 如果是二分类变量,那预测结果列名必定是 .pred_class
      • 如果你选择了计算概率(prob),那结果列名就是 .pred_你的第一个分类.pred_你的第二个分类

    有了这个规律,用起来就方便多了。所以对于这种预测结果的评价,一般是和原来的真实结果结合起来,然后进行各种操作:

    tree_pred <- select(hotel_test, children) %>% 
      bind_cols(predict(tree_fit, hotel_test, type = "prob")) %>% 
      bind_cols(predict(tree_fit, hotel_test))

    head(tree_pred)
    ## # A tibble: 6 × 4
    ##   children .pred_children .pred_none .pred_class
    ##                             
    ## 1 none              0.251      0.749 none       
    ## 2 none              0.251      0.749 none       
    ## 3 none              0.583      0.417 children   
    ## 4 none              0.251      0.749 none       
    ## 5 none              0.251      0.749 none       
    ## 6 none              0.251      0.749 none
    • 1

    得到这个结果之后,就可以进行各种模型评价了:

    查看模型表现的操作也是非常遵循tidy理念的,模型评价是通过yardstick包实现的。比如下面这个AUC:

    tree_pred %>% roc_auc(truth = children, estimate = .pred_children)
    ## # A tibble: 1 × 3
    ##   .metric .estimator .estimate
    ##                
    ## 1 roc_auc binary         0.739
    • 1

    想要看什么指标直接写名字,一般都能自动补全出来,所有支持的指标可以在这里[3]查看。

    yardstick的第一个参数永远是你的数据集(tree_pred),第二个参数永远是真实结果,第三个参数永远是预测结果!

    可以说是非常的有规律了!

    # 选择多种评价指标
    metricsets <- metric_set(accuracy, mcc, f_meas, j_index)

    tree_pred %>% metricsets(truth = children, estimate = .pred_class)
    ## # A tibble: 4 × 3
    ##   .metric  .estimator .estimate
    ##                 
    ## 1 accuracy binary         0.692
    ## 2 mcc      binary         0.260
    ## 3 f_meas   binary         0.288
    ## 4 j_index  binary         0.455
    • 1

    可视化结果也是一模一样的设计理念:

    tree_pred %>% roc_curve(truth = children, estimate = .pred_children) %>% 
      autoplot()
    • 1
    ROC
    ROC

    训练模型(有重抽样)

    不过我们是有交叉验证这一步的,下面就来演示~

    在训练集中训练模型,因为这个算法不复杂,我们也没进行特别复杂的操作,所以还是很快的,在我电脑上大概2秒钟。。。

    # 控制计算过程的一些设置
    keep_pred <- control_resamples(save_pred = T, verbose = T)

    set.seed(456)

    library(doParallel) 
    cl <- makePSOCKcluster(12# 加速,用12个线程
    registerDoParallel(cl)

    tree_res <- fit_resamples(tree_wf, 
                              resamples = cv, 
                              control = keep_pred)
    ## i Fold01: preprocessor 1/1
    ## ✓ Fold01: preprocessor 1/1
    ## i Fold01: preprocessor 1/1, model 1/1
    ## ✓ Fold01: preprocessor 1/1, model 1/1
    ## i Fold01: preprocessor 1/1, model 1/1 (predictions)
    ## i Fold02: preprocessor 1/1
    ## ✓ Fold02: preprocessor 1/1
    ## i Fold02: preprocessor 1/1, model 1/1
    ## ✓ Fold02: preprocessor 1/1, model 1/1
    ## i Fold02: preprocessor 1/1, model 1/1 (predictions)
    ## i Fold03: preprocessor 1/1
    ## ✓ Fold03: preprocessor 1/1
    ## i Fold03: preprocessor 1/1, model 1/1
    ## ✓ Fold03: preprocessor 1/1, model 1/1
    ## i Fold03: preprocessor 1/1, model 1/1 (predictions)
    ## i Fold04: preprocessor 1/1
    ## ✓ Fold04: preprocessor 1/1
    ## i Fold04: preprocessor 1/1, model 1/1
    ## ✓ Fold04: preprocessor 1/1, model 1/1
    ## i Fold04: preprocessor 1/1, model 1/1 (predictions)
    ## i Fold05: preprocessor 1/1
    ## ✓ Fold05: preprocessor 1/1
    ## i Fold05: preprocessor 1/1, model 1/1
    ## ✓ Fold05: preprocessor 1/1, model 1/1
    ## i Fold05: preprocessor 1/1, model 1/1 (predictions)
    ## i Fold06: preprocessor 1/1
    ## ✓ Fold06: preprocessor 1/1
    ## i Fold06: preprocessor 1/1, model 1/1
    ## ✓ Fold06: preprocessor 1/1, model 1/1
    ## i Fold06: preprocessor 1/1, model 1/1 (predictions)
    ## i Fold07: preprocessor 1/1
    ## ✓ Fold07: preprocessor 1/1
    ## i Fold07: preprocessor 1/1, model 1/1
    ## ✓ Fold07: preprocessor 1/1, model 1/1
    ## i Fold07: preprocessor 1/1, model 1/1 (predictions)
    ## i Fold08: preprocessor 1/1
    ## ✓ Fold08: preprocessor 1/1
    ## i Fold08: preprocessor 1/1, model 1/1
    ## ✓ Fold08: preprocessor 1/1, model 1/1
    ## i Fold08: preprocessor 1/1, model 1/1 (predictions)
    ## i Fold09: preprocessor 1/1
    ## ✓ Fold09: preprocessor 1/1
    ## i Fold09: preprocessor 1/1, model 1/1
    ## ✓ Fold09: preprocessor 1/1, model 1/1
    ## i Fold09: preprocessor 1/1, model 1/1 (predictions)
    ## i Fold10: preprocessor 1/1
    ## ✓ Fold10: preprocessor 1/1
    ## i Fold10: preprocessor 1/1, model 1/1
    ## ✓ Fold10: preprocessor 1/1, model 1/1
    ## i Fold10: preprocessor 1/1, model 1/1 (predictions)
    stop(cl)
    • 1

    查看模型表现,不管你换什么模型、什么数据集,结果的列名都是这几个,比如.metric\.estimator这些,这也是tidy的理念~

    tree_res |>  
      collect_metrics()
    ## # A tibble: 2 × 6
    ##   .metric  .estimator  mean     n std_err .config             
    ##                                 
    ## 1 accuracy binary     0.727    10 0.00830 Preprocessor1_Model1
    ## 2 roc_auc  binary     0.739    10 0.00322 Preprocessor1_Model1
    • 1

    想要查看每一折的表现也是可以的:

    tree_res |> 
      collect_metrics(summarize=F)
    ## # A tibble: 20 × 5
    ##    id     .metric  .estimator .estimate .config             
    ##                                    
    ##  1 Fold01 accuracy binary         0.740 Preprocessor1_Model1
    ##  2 Fold01 roc_auc  binary         0.743 Preprocessor1_Model1
    ##  3 Fold02 accuracy binary         0.757 Preprocessor1_Model1
    ##  4 Fold02 roc_auc  binary         0.732 Preprocessor1_Model1
    ##  5 Fold03 accuracy binary         0.730 Preprocessor1_Model1
    ##  6 Fold03 roc_auc  binary         0.727 Preprocessor1_Model1
    ##  7 Fold04 accuracy binary         0.675 Preprocessor1_Model1
    ##  8 Fold04 roc_auc  binary         0.723 Preprocessor1_Model1
    ##  9 Fold05 accuracy binary         0.719 Preprocessor1_Model1
    ## 10 Fold05 roc_auc  binary         0.755 Preprocessor1_Model1
    ## 11 Fold06 accuracy binary         0.720 Preprocessor1_Model1
    ## 12 Fold06 roc_auc  binary         0.747 Preprocessor1_Model1
    ## 13 Fold07 accuracy binary         0.719 Preprocessor1_Model1
    ## 14 Fold07 roc_auc  binary         0.747 Preprocessor1_Model1
    ## 15 Fold08 accuracy binary         0.753 Preprocessor1_Model1
    ## 16 Fold08 roc_auc  binary         0.743 Preprocessor1_Model1
    ## 17 Fold09 accuracy binary         0.758 Preprocessor1_Model1
    ## 18 Fold09 roc_auc  binary         0.743 Preprocessor1_Model1
    ## 19 Fold10 accuracy binary         0.702 Preprocessor1_Model1
    ## 20 Fold10 roc_auc  binary         0.731 Preprocessor1_Model1
    • 1

    查看具体的结果,这个结果的列名也是很有规律的:

    • 第一列永远是 id
    • 第二列是 .pred_你的第一个分类
    • 第三列是 .pred_你的第二个分类
    • 第四列是 .pred_xxx,其中 xxx是你的结果变量的列名。
    tree_res |> 
      collect_predictions()
    ## # A tibble: 52,616 × 7
    ##    id     .pred_children .pred_none  .row .pred_class children .config          
    ##                                              
    ##  1 Fold01          0.267      0.733    19 none        none     Preprocessor1_Mo…
    ##  2 Fold01          0.267      0.733    23 none        none     Preprocessor1_Mo…
    ##  3 Fold01          0.267      0.733    29 none        none     Preprocessor1_Mo…
    ##  4 Fold01          0.756      0.244    39 children    children Preprocessor1_Mo…
    ##  5 Fold01          0.267      0.733    51 none        none     Preprocessor1_Mo…
    ##  6 Fold01          0.267      0.733    69 none        none     Preprocessor1_Mo…
    ##  7 Fold01          0.267      0.733    79 none        none     Preprocessor1_Mo…
    ##  8 Fold01          0.267      0.733    86 none        none     Preprocessor1_Mo…
    ##  9 Fold01          0.607      0.393    91 children    none     Preprocessor1_Mo…
    ## 10 Fold01          0.756      0.244   112 children    none     Preprocessor1_Mo…
    ## # … with 52,606 more rows
    ## # ℹ Use `print(n = ...)` to see more rows
    • 1

    如果你有调参的过程,这里又会多好几步,主要是用来选择合适的超参数,但是我们没有这一步。

    用于测试集

    注意这里不是直接predict()哦,而是用last_fit()这个函数,而且它的第二个参数不是测试集,而是hotel_split

    tree_pred <- last_fit(tree_wf, hotel_split)
    • 1

    你想探索这个测试集的模型表现,也是和上面一样的:

    tree_pred |> collect_metrics()
    ## # A tibble: 2 × 4
    ##   .metric  .estimator .estimate .config             
    ##                                 
    ## 1 accuracy binary         0.692 Preprocessor1_Model1
    ## 2 roc_auc  binary         0.739 Preprocessor1_Model1
    • 1
    test_pred <- tree_pred |> collect_predictions()
    head(test_pred)
    ## # A tibble: 6 × 7
    ##   id               .pred_children .pred_none  .row .pred_class children .config 
    ##                                              
    ## 1 train/test split          0.251      0.749     4 none        none     Preproc…
    ## 2 train/test split          0.251      0.749     8 none        none     Preproc…
    ## 3 train/test split          0.583      0.417    10 children    none     Preproc…
    ## 4 train/test split          0.251      0.749    12 none        none     Preproc…
    ## 5 train/test split          0.251      0.749    18 none        none     Preproc…
    ## 6 train/test split          0.251      0.749    21 none        none     Preproc…
    roc_auc(test_pred, truth = children, .estimate = .pred_children)
    ## # A tibble: 1 × 3
    ##   .metric .estimator .estimate
    ##                
    ## 1 roc_auc binary         0.739
    • 1

    进阶

    以上是关于tidymodels的基础使用,大家在实际使用中经常会遇到更加复杂的情况,比如:多个模型的比较,多个模型在多个数据集并配合不同的预处理步骤,超参数调优等等。

    关于多个模型比较的部分大家可以翻看我之前的推文。

    另外,还可以去我的个人博客:https://www.yuque.com/ayueme , 查看更多内容,我的博客里给出了非常多tidymodels使用的例子,这些内容目前还没有搬到公众号上来,可以帮助大家更进一步了解这个包。

    总结

    总体来看,tidymodels在统一使用方式方面做的非常棒,各个步骤中都有tidy理念的影子,这样一旦你熟悉了其基本语法,使用起来是很舒服的,因为代码基本不用变,连列名都是固定的!

    有点难度的地方是数据预处理步骤,因为太多了,所有的预处理步骤大家可以去这里[4]看。

    另外,对于超参数调优的部分感觉不如mlr3做得好,很多超参数的名字、类型、取值等很难记住,并且没有明确给出查看这些信息的函数,经常要不断的用?xxx来看帮助文档。。。

    还有一个就是速度,基于tibble,并且各种fit_xxx()函数也是基于purrr包,这就导致它速度一般。但是目前我还没接触到需要好几个小时的数据,一般也就顶多半小时!

    如果你是新手,建议你先学tidymodels,因为简单,mlr3的R6语法太反人类了。。。

    今日示例数据已上传QQ群,需要的加群自取即可↓

    参考资料

    [1]

    支持的模型: https://www.tidymodels.org/find/parsnip/

    [2]

    重抽样方法: https://rsample.tidymodels.org/reference/index.html

    [3]

    评价指标: https://yardstick.tidymodels.org/reference/index.html

    [4]

    预处理: https://recipes.tidymodels.org/reference/index.html

    本文由 mdnice 多平台发布

  • 相关阅读:
    分布式原理
    WR | 水源水耐药基因稳定赋存的关键:以致病菌为“源”,群落构建主导菌为“汇”...
    计算机视觉基础(9)——相机标定与对极几何
    《深度学习进阶 自然语言处理》第五章:RNN通俗介绍
    微服务到底该怎么样部署呢?
    『忘了再学』Shell基础 — 12、用户自定义变量
    uniapp相关记录
    这些好用的设计素材网,你一定要知道。
    一条Select语句在MySQL-Server层的执行过程
    Python 和Java 哪个更适合做自动化测试?
  • 原文地址:https://blog.csdn.net/Ayue0616/article/details/126452500