3  Tuning Hyperparameters and Overfitting

Set seed and load packages.

set.seed(1337)

library("tidymodels")
tidymodels::tidymodels_prefer()
library("dials")

Load data.

data("iris")
iris <- iris |>
  tibble::as_tibble() |> 
  filter(Species != "setosa") |> 
  droplevels()

iris_folds <- vfold_cv(iris)

3.1 Introduction to Hyperparameters

A classic, simple example of a hyperparameter is the number of neighbors, usually denoted k, in the k-nearest neighbors (KNN) algorithm. It is a hyperparameter because it is not estimated during model fitting: it must be specified a priori, so it cannot be optimized as part of parameter estimation. Instead, candidate values are compared on resampled data, e.g. with cross-validation.
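
To make the distinction concrete, the following is a minimal base-R sketch of a KNN classifier. The function knn_predict() and the object two_species are illustrative names, not part of tidymodels or kknn, and unweighted Euclidean distance with majority voting is a simplifying assumption.

```r
# Minimal k-nearest-neighbors classifier in base R (illustrative only).
# `k` is supplied by the caller; nothing in the function estimates it.
knn_predict <- function(train_x, train_y, new_x, k) {
  train_x <- as.matrix(train_x)
  new_x <- as.matrix(new_x)
  apply(new_x, 1, function(row) {
    # Euclidean distance from this observation to every training observation
    dists <- sqrt(colSums((t(train_x) - row)^2))
    # Majority vote among the k closest training observations
    nearest <- order(dists)[seq_len(k)]
    names(which.max(table(train_y[nearest])))
  })
}

# Same two-species subset of iris as used in this chapter
two_species <- droplevels(iris[iris$Species != "setosa", ])
x <- two_species[, 1:4]
y <- two_species$Species

# The same data give different classifiers depending on the chosen k
pred_k1  <- knn_predict(x, y, x, k = 1)
pred_k15 <- knn_predict(x, y, x, k = 15)

mean(pred_k1 == y)   # training accuracy with k = 1
mean(pred_k15 == y)  # training accuracy with k = 15
```

With k = 1 every training observation is its own nearest neighbor, so accuracy on the training data says little about performance on new data. This is why candidate values of k are compared on held-out folds rather than on the data used for fitting.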

3.2 Setting Up Tuning

In tidymodels, hyperparameters are marked for tuning in the model specification by setting them to tune(). As an example, both the number of nearest neighbors and the weight function are tuned here.

knn_spec <- nearest_neighbor(neighbors = tune(),
                             weight_func = tune()) |> 
  set_engine(engine = "kknn",
             trace = 0) |> 
  set_mode("classification")
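
As a small sketch of what tune() changes, the snippet below compares a specification with fixed values to one with a tuned argument. The object names fixed_spec and tuned_spec and the custom id "k" are illustrative choices, not taken from the text above.

```r
library(tidymodels)

# Both arguments fixed: nothing is left to tune
fixed_spec <- nearest_neighbor(neighbors = 5, weight_func = "rectangular") |>
  set_engine("kknn") |>
  set_mode("classification")

# `neighbors` marked for tuning; "k" is an optional, user-chosen identifier
tuned_spec <- nearest_neighbor(neighbors = tune("k"), weight_func = "rectangular") |>
  set_engine("kknn") |>
  set_mode("classification")

fixed_params <- extract_parameter_set_dials(fixed_spec)
tuned_params <- extract_parameter_set_dials(tuned_spec)

nrow(fixed_params)  # number of parameters marked for tuning in fixed_spec
tuned_params$id     # identifier(s) of the parameters marked in tuned_spec
```

Only arguments set to tune() appear in the parameter set; fixed arguments are passed to the engine unchanged.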

Next, the recipe is set up. As no preprocessing (e.g. log-transformation) is applied, it is quite simple.

knn_rec <- recipe(Species ~ ., # Use all other columns as features (called predictors in tidymodels)
                  data = iris)

The specification and recipe are then combined into a workflow:

knn_wflow <- workflow() |> 
  add_model(knn_spec) |> 
  add_recipe(knn_rec)

It is possible to inspect which hyperparameters are being tuned, check which values are tested, and change those values. This is done with the dials package.

# Check hyperparameters
knn_spec |> extract_parameter_set_dials()
Collection of 2 parameters for tuning

  identifier        type    object
   neighbors   neighbors nparam[+]
 weight_func weight_func dparam[+]
# Check values tested
knn_spec |> extract_parameter_set_dials() |> 
  extract_parameter_dials("weight_func")
Distance Weighting Function  (qualitative)
10 possible values include:
'rectangular', 'triangular', 'epanechnikov', 'biweight', 'triweight', 'cos', ... 
# Change values, save in new object
knn_params <- knn_spec |>
  extract_parameter_set_dials() |>
  update(weight_func = weight_func(c("cos", "inv", "gaussian")),
         neighbors = neighbors(c(1, 15)))

# Check that it is updated
knn_params |>
  extract_parameter_dials("weight_func")
Distance Weighting Function  (qualitative)
3 possible values include:
'cos', 'inv' and 'gaussian' 
knn_params |>
  extract_parameter_dials("neighbors")
# Nearest Neighbors (quantitative)
Range: [1, 15]

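Different grid_* functions exist to combine the hyperparameters, e.g. grid_random() and grid_regular(). As shown below, grid_regular() forms every combination of the parameter values, with levels setting how many values are taken from each parameter's range. A qualitative parameter with fewer values than levels, such as weight_func here, contributes all of its values, which is why the grid has 4 × 3 = 12 rows.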

grid_regular(knn_params,
             levels = 4)
# A tibble: 12 × 2
   neighbors weight_func
       <int> <chr>      
 1         1 cos        
 2         5 cos        
 3        10 cos        
 4        15 cos        
 5         1 inv        
 6         5 inv        
 7        10 inv        
 8        15 inv        
 9         1 gaussian   
10         5 gaussian   
11        10 gaussian   
12        15 gaussian   
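
For comparison, the sketch below builds a regular grid with a different number of levels per parameter and a random grid from the same ranges. The parameter objects are recreated directly with dials so the snippet runs on its own; the names k_param, w_param, reg_grid and rnd_grid are illustrative.

```r
library(dials)
set.seed(1337)

# Same ranges as in knn_params above, created as stand-alone parameter objects
k_param <- neighbors(c(1L, 15L))
w_param <- weight_func(c("cos", "inv", "gaussian"))

# Regular grid: 2 values of neighbors crossed with 3 weight functions
reg_grid <- grid_regular(k_param, w_param, levels = c(2, 3))
reg_grid

# Random grid: up to `size` unique combinations sampled from the ranges
rnd_grid <- grid_random(k_param, w_param, size = 6)
rnd_grid
```

A regular grid grows multiplicatively with the number of parameters, whereas the size of a random grid is set directly, which can be more practical when many hyperparameters are tuned at once.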

3.3 Measure Performance of Tuning

A metric is needed to measure the performance of each hyperparameter combination. The area under the ROC curve (roc_auc) is used. The workflow is then tuned over the regular grid:

# Performance metric
roc <- metric_set(roc_auc)

# Tuning
knn_tune <- knn_wflow |> 
  tune_grid(iris_folds,
            grid = knn_params |> grid_regular(levels = 4),
            metrics = roc)
knn_tune
# Tuning results
# 10-fold cross-validation 
# A tibble: 10 × 4
   splits          id     .metrics          .notes          
   <list>          <chr>  <list>            <list>          
 1 <split [90/10]> Fold01 <tibble [12 × 6]> <tibble [0 × 3]>
 2 <split [90/10]> Fold02 <tibble [12 × 6]> <tibble [0 × 3]>
 3 <split [90/10]> Fold03 <tibble [12 × 6]> <tibble [0 × 3]>
 4 <split [90/10]> Fold04 <tibble [12 × 6]> <tibble [0 × 3]>
 5 <split [90/10]> Fold05 <tibble [12 × 6]> <tibble [0 × 3]>
 6 <split [90/10]> Fold06 <tibble [12 × 6]> <tibble [0 × 3]>
 7 <split [90/10]> Fold07 <tibble [12 × 6]> <tibble [0 × 3]>
 8 <split [90/10]> Fold08 <tibble [12 × 6]> <tibble [0 × 3]>
 9 <split [90/10]> Fold09 <tibble [12 × 6]> <tibble [0 × 3]>
10 <split [90/10]> Fold10 <tibble [12 × 6]> <tibble [0 × 3]>

To visualize the performance:

knn_tune |> 
  unnest(cols = .metrics) |> 
  select(id, .metric, neighbors, weight_func, .estimate) |>
  group_by(neighbors, weight_func) |>
  summarise(estimate_avg = mean(.estimate), .groups = "drop") |>
  ggplot(aes(x = neighbors,
             y = estimate_avg)) +
  geom_point(size = 3) +
  geom_line(linewidth = 0.7) +
  scale_x_continuous(breaks = c(1, 5, 10, 15)) +
  facet_wrap(~ weight_func) +
  theme(text=element_text(size=13))
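
The averaging across folds can also be left to the tune package. The following is a self-contained sketch under simplifying assumptions: only neighbors is tuned, a formula is used instead of a recipe, and the object names (two_species, wflow, res, avg_metrics, tuning_plot) are illustrative.

```r
library(tidymodels)
set.seed(1337)

two_species <- iris |>
  tibble::as_tibble() |>
  dplyr::filter(Species != "setosa") |>
  droplevels()

# Only `neighbors` is tuned here to keep the sketch short
wflow <- workflow() |>
  add_formula(Species ~ .) |>
  add_model(
    nearest_neighbor(neighbors = tune()) |>
      set_engine("kknn") |>
      set_mode("classification")
  )

res <- tune_grid(
  wflow,
  resamples = vfold_cv(two_species),
  grid = tibble::tibble(neighbors = c(1L, 5L, 10L, 15L)),
  metrics = metric_set(roc_auc)
)

# One row per candidate: mean and standard error of the metric across folds
avg_metrics <- collect_metrics(res)
avg_metrics

# Default ggplot2 display of the same averages
tuning_plot <- autoplot(res)
tuning_plot
```

collect_metrics() additionally reports the standard error across folds, which helps judge whether small differences between candidates are meaningful.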

3.4 Finalize Hyperparameter Selection

There appears to be no visible difference between the weight functions. For the number of neighbors, performance is highest at 10 and 15 neighbors. Preferably, the simpler of the two models is chosen; here that is taken to be the one with fewer neighbors, 10.

final_hyperparams <- tibble(weight_func = "gaussian",
                            neighbors = 10)

final_knn_wflow <- knn_wflow |> 
  finalize_workflow(final_hyperparams)
final_knn_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: nearest_neighbor()

── Preprocessor ────────────────────────────────────────────────────────────────
0 Recipe Steps

── Model ───────────────────────────────────────────────────────────────────────
K-Nearest Neighbor Model Specification (classification)

Main Arguments:
  neighbors = 10
  weight_func = gaussian

Engine-Specific Arguments:
  trace = 0

Computational engine: kknn 
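
The hyperparameters can also be selected programmatically instead of typing them into a tibble. The sketch below is self-contained and makes simplifying assumptions: only neighbors is tuned, a formula replaces the recipe, and all object names are illustrative. Sorting by neighbors in ascending order mirrors the preference for fewer neighbors stated above.

```r
library(tidymodels)
set.seed(1337)

two_species <- iris |>
  tibble::as_tibble() |>
  dplyr::filter(Species != "setosa") |>
  droplevels()

wflow <- workflow() |>
  add_formula(Species ~ .) |>
  add_model(
    nearest_neighbor(neighbors = tune()) |>
      set_engine("kknn") |>
      set_mode("classification")
  )

res <- tune_grid(
  wflow,
  resamples = vfold_cv(two_species),
  grid = tibble::tibble(neighbors = c(1L, 5L, 10L, 15L)),
  metrics = metric_set(roc_auc)
)

# Candidate with the highest mean ROC AUC
best_auc <- select_best(res, metric = "roc_auc")

# Candidate with the fewest neighbors whose mean ROC AUC lies within one
# standard error of the best; use desc(neighbors) to prefer more neighbors
within_1se <- select_by_one_std_err(res, neighbors, metric = "roc_auc")

# Plug the selected value into the workflow
final_wflow <- finalize_workflow(wflow, within_1se)
final_wflow
```

After finalize_workflow() no tune() placeholders remain, so the workflow can be passed directly to fit().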

The model can now be fit to the data and used for prediction.

final_knn_fit <- final_knn_wflow |> 
  fit(iris)
final_knn_fit
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: nearest_neighbor()

── Preprocessor ────────────────────────────────────────────────────────────────
0 Recipe Steps

── Model ───────────────────────────────────────────────────────────────────────

Call:
kknn::train.kknn(formula = ..y ~ ., data = data, ks = min_rows(10,     data, 5), kernel = ~"gaussian", trace = ~0)

Type of response variable: nominal
Minimal misclassification: 0.07
Best kernel: gaussian
Best k: 10
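
As a brief sketch of the prediction step, the snippet below fits a workflow with the selected values and predicts both classes and class probabilities. It is written to run on its own, so the workflow is rebuilt with a formula instead of the recipe; the names two_species, fit10 and new_obs are illustrative, and the first three rows merely stand in for new observations.

```r
library(tidymodels)

two_species <- iris |>
  tibble::as_tibble() |>
  dplyr::filter(Species != "setosa") |>
  droplevels()

# Workflow with the hyperparameters chosen above, fitted to all rows
fit10 <- workflow() |>
  add_formula(Species ~ .) |>
  add_model(
    nearest_neighbor(neighbors = 10, weight_func = "gaussian") |>
      set_engine("kknn") |>
      set_mode("classification")
  ) |>
  fit(two_species)

# Stand-in for new observations: predictors only, no outcome column
new_obs <- two_species |>
  dplyr::slice(1:3) |>
  dplyr::select(-Species)

class_pred <- predict(fit10, new_data = new_obs)                 # hard class predictions
prob_pred  <- predict(fit10, new_data = new_obs, type = "prob")  # one probability column per class

dplyr::bind_cols(class_pred, prob_pred)
```

Because these rows were also used for fitting, the predictions are optimistic; an honest performance estimate requires data that the model has not seen.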

4 Session Info

sessioninfo::session_info()
─ Session info ───────────────────────────────────────────────────────────────
 setting  value
 version  R version 4.3.3 (2024-02-29 ucrt)
 os       Windows 11 x64 (build 22631)
 system   x86_64, mingw32
 ui       RTerm
 language (EN)
 collate  English_United Kingdom.utf8
 ctype    English_United Kingdom.utf8
 tz       Europe/Copenhagen
 date     2024-05-30
 pandoc   3.1.11 @ C:/Program Files/RStudio/resources/app/bin/quarto/bin/tools/ (via rmarkdown)

─ Packages ───────────────────────────────────────────────────────────────────
 package      * version    date (UTC) lib source
 backports      1.4.1      2021-12-13 [1] CRAN (R 4.3.1)
 broom        * 1.0.5      2023-06-09 [1] CRAN (R 4.3.3)
 cachem         1.0.8      2023-05-01 [1] CRAN (R 4.3.3)
 class          7.3-22     2023-05-03 [2] CRAN (R 4.3.3)
 cli            3.6.2      2023-12-11 [1] CRAN (R 4.3.3)
 codetools      0.2-19     2023-02-01 [2] CRAN (R 4.3.3)
 colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.3.3)
 conflicted     1.2.0      2023-02-01 [1] CRAN (R 4.3.3)
 data.table     1.15.4     2024-03-30 [1] CRAN (R 4.3.3)
 dials        * 1.2.1      2024-02-22 [1] CRAN (R 4.3.3)
 DiceDesign     1.10       2023-12-07 [1] CRAN (R 4.3.3)
 digest         0.6.35     2024-03-11 [1] CRAN (R 4.3.3)
 dplyr        * 1.1.4      2023-11-17 [1] CRAN (R 4.3.2)
 ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.3.3)
 evaluate       0.23       2023-11-01 [1] CRAN (R 4.3.3)
 fansi          1.0.6      2023-12-08 [1] CRAN (R 4.3.3)
 farver         2.1.1      2022-07-06 [1] CRAN (R 4.3.3)
 fastmap        1.1.1      2023-02-24 [1] CRAN (R 4.3.3)
 foreach        1.5.2      2022-02-02 [1] CRAN (R 4.3.3)
 furrr          0.3.1      2022-08-15 [1] CRAN (R 4.3.3)
 future         1.33.2     2024-03-26 [1] CRAN (R 4.3.3)
 future.apply   1.11.2     2024-03-28 [1] CRAN (R 4.3.3)
 generics       0.1.3      2022-07-05 [1] CRAN (R 4.3.3)
 ggplot2      * 3.5.1      2024-04-23 [1] CRAN (R 4.3.3)
 globals        0.16.3     2024-03-08 [1] CRAN (R 4.3.3)
 glue           1.7.0      2024-01-09 [1] CRAN (R 4.3.3)
 gower          1.0.1      2022-12-22 [1] CRAN (R 4.3.1)
 GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.3.3)
 gtable         0.3.5      2024-04-22 [1] CRAN (R 4.3.3)
 hardhat        1.3.1      2024-02-02 [1] CRAN (R 4.3.3)
 htmltools      0.5.8.1    2024-04-04 [1] CRAN (R 4.3.3)
 htmlwidgets    1.6.4      2023-12-06 [1] CRAN (R 4.3.3)
 igraph         2.0.3      2024-03-13 [1] CRAN (R 4.3.3)
 infer        * 1.0.7      2024-03-25 [1] CRAN (R 4.3.3)
 ipred          0.9-14     2023-03-09 [1] CRAN (R 4.3.3)
 iterators      1.0.14     2022-02-05 [1] CRAN (R 4.3.3)
 jsonlite       1.8.8      2023-12-04 [1] CRAN (R 4.3.3)
 kknn         * 1.3.1      2016-03-26 [1] CRAN (R 4.3.3)
 knitr          1.46       2024-04-06 [1] CRAN (R 4.3.3)
 labeling       0.4.3      2023-08-29 [1] CRAN (R 4.3.1)
 lattice        0.22-5     2023-10-24 [2] CRAN (R 4.3.3)
 lava           1.8.0      2024-03-05 [1] CRAN (R 4.3.3)
 lhs            1.1.6      2022-12-17 [1] CRAN (R 4.3.3)
 lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.3.3)
 listenv        0.9.1      2024-01-29 [1] CRAN (R 4.3.3)
 lubridate      1.9.3      2023-09-27 [1] CRAN (R 4.3.3)
 magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.3.3)
 MASS           7.3-60.0.1 2024-01-13 [2] CRAN (R 4.3.3)
 Matrix         1.6-5      2024-01-11 [2] CRAN (R 4.3.3)
 memoise        2.0.1      2021-11-26 [1] CRAN (R 4.3.3)
 modeldata    * 1.3.0      2024-01-21 [1] CRAN (R 4.3.3)
 munsell        0.5.1      2024-04-01 [1] CRAN (R 4.3.3)
 nnet           7.3-19     2023-05-03 [2] CRAN (R 4.3.3)
 parallelly     1.37.1     2024-02-29 [1] CRAN (R 4.3.3)
 parsnip      * 1.2.1      2024-03-22 [1] CRAN (R 4.3.3)
 pillar         1.9.0      2023-03-22 [1] CRAN (R 4.3.3)
 pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.3.3)
 prodlim        2023.08.28 2023-08-28 [1] CRAN (R 4.3.3)
 purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.3.3)
 R6             2.5.1      2021-08-19 [1] CRAN (R 4.3.3)
 Rcpp           1.0.12     2024-01-09 [1] CRAN (R 4.3.3)
 recipes      * 1.0.10     2024-02-18 [1] CRAN (R 4.3.3)
 rlang          1.1.3      2024-01-10 [1] CRAN (R 4.3.3)
 rmarkdown      2.26       2024-03-05 [1] CRAN (R 4.3.3)
 rpart          4.1.23     2023-12-05 [2] CRAN (R 4.3.3)
 rsample      * 1.2.1      2024-03-25 [1] CRAN (R 4.3.3)
 rstudioapi     0.16.0     2024-03-24 [1] CRAN (R 4.3.3)
 scales       * 1.3.0      2023-11-28 [1] CRAN (R 4.3.3)
 sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.3.3)
 survival       3.5-8      2024-02-14 [2] CRAN (R 4.3.3)
 tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.3.3)
 tidymodels   * 1.2.0      2024-03-25 [1] CRAN (R 4.3.3)
 tidyr        * 1.3.1      2024-01-24 [1] CRAN (R 4.3.3)
 tidyselect     1.2.1      2024-03-11 [1] CRAN (R 4.3.3)
 timechange     0.3.0      2024-01-18 [1] CRAN (R 4.3.3)
 timeDate       4032.109   2023-12-14 [1] CRAN (R 4.3.2)
 tune         * 1.2.1      2024-04-18 [1] CRAN (R 4.3.3)
 utf8           1.2.4      2023-10-22 [1] CRAN (R 4.3.3)
 vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.3.3)
 withr          3.0.0      2024-01-16 [1] CRAN (R 4.3.3)
 workflows    * 1.1.4      2024-02-19 [1] CRAN (R 4.3.3)
 workflowsets * 1.1.0      2024-03-21 [1] CRAN (R 4.3.3)
 xfun           0.43       2024-03-25 [1] CRAN (R 4.3.3)
 yaml           2.3.8      2023-12-11 [1] CRAN (R 4.3.2)
 yardstick    * 1.3.1      2024-03-21 [1] CRAN (R 4.3.3)

 [1] C:/Users/Willi/AppData/Local/R/win-library/4.3
 [2] C:/Program Files/R/R-4.3.3/library

──────────────────────────────────────────────────────────────────────────────