Set seed and load packages.
set.seed(1337)
library("tidymodels")
tidymodels::tidymodels_prefer()
library("dials")

Load data.
data("iris")
# Keep two species for a binary classification problem
iris <- iris |>
  tibble::as_tibble() |>
  filter(Species != "setosa") |>
  droplevels()
# 10-fold cross-validation (v = 10 is the default)
iris_folds <- vfold_cv(iris)

A classical and simple example of a hyperparameter is the number of neighbors, usually denoted k, in a k-nearest neighbors (KNN) algorithm. k is a hyperparameter because it is not estimated when the model is fitted; it has to be specified a priori and therefore cannot be optimized as part of parameter estimation.
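To make the distinction concrete, below is a minimal base-R sketch of an unweighted (majority-vote) KNN classifier. The helper knn_predict() is hypothetical and not part of tidymodels or kknn. Note that k is an argument supplied by the user: nothing in the function estimates it from the data, which is why it has to be tuned separately.

```r
# Hypothetical helper (not from kknn/tidymodels): unweighted KNN, Euclidean distance
knn_predict <- function(train_x, train_y, new_x, k) {
  train_x <- as.matrix(train_x)
  new_x <- as.matrix(new_x)
  apply(new_x, 1, function(point) {
    # Distance from 'point' to every training observation
    d <- sqrt(colSums((t(train_x) - point)^2))
    nearest <- order(d)[seq_len(k)]
    # Majority vote among the k nearest labels; ties go to the first label in table order
    votes <- table(train_y[nearest])
    names(votes)[which.max(votes)]
  })
}

# Usage on the two-species iris data: k is chosen by the user, not estimated
dat <- droplevels(subset(iris, Species != "setosa"))
knn_predict(dat[, 1:4], dat$Species, dat[c(1, 51), 1:4], k = 5)
knn_predict(dat[, 1:4], dat$Species, dat[c(1, 51), 1:4], k = 15)
```

Changing k can change the predictions, yet the fitting step offers no way to choose it; that choice is what cross-validated tuning provides.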
In the tidymodels universe, hyperparameters are marked for tuning with tune() in the model specification. As an example, both the number of nearest neighbors and the weight function are tuned.
knn_spec <- nearest_neighbor(neighbors = tune(),
                             weight_func = tune()) |>
  set_engine(engine = "kknn",
             trace = 0) |>
  set_mode("classification")

Next, the recipe is set up. Since no preprocessing (e.g. log-transformation) is applied, it is quite simple.
knn_rec <- recipe(Species ~ ., # Use all other columns as features (called predictors in tidymodels)
                  data = iris)

The model specification and recipe are then combined into a workflow:
knn_wflow <- workflow() |>
  add_model(knn_spec) |>
  add_recipe(knn_rec)

It is possible to inspect which hyperparameters are being tuned, check which values are tested, and change those values. This is done with the dials package.
# Check hyperparameters
knn_spec |> extract_parameter_set_dials()

Collection of 2 parameters for tuning
identifier type object
neighbors neighbors nparam[+]
weight_func weight_func dparam[+]
# Check values tested
knn_spec |>
  extract_parameter_set_dials() |>
  extract_parameter_dials("weight_func")

Distance Weighting Function (qualitative)
10 possible values include:
'rectangular', 'triangular', 'epanechnikov', 'biweight', 'triweight', 'cos', ...
# Change values, save in new object
knn_params <- knn_spec |>
  extract_parameter_set_dials() |>
  update(weight_func = weight_func(c("cos", "inv", "gaussian")),
         neighbors = neighbors(c(1, 15)))
# Check that it is updated
knn_params |>
  extract_parameter_dials("weight_func")

Distance Weighting Function (qualitative)
3 possible values include:
'cos', 'inv' and 'gaussian'
knn_params |>
  extract_parameter_dials("neighbors")

# Nearest Neighbors (quantitative)
Range: [1, 15]
Several grid_* functions exist for combining hyperparameter values, e.g. grid_random() and grid_regular(). As shown below, grid_regular() creates all combinations of the parameter values, where the levels argument sets the number of values per parameter. Because weight_func only has three allowed values, the grid has 4 × 3 = 12 rows.
grid_regular(knn_params,
             levels = 4)

# A tibble: 12 × 2
neighbors weight_func
<int> <chr>
1 1 cos
2 5 cos
3 10 cos
4 15 cos
5 1 inv
6 5 inv
7 10 inv
8 15 inv
9 1 gaussian
10 5 gaussian
11 10 gaussian
12 15 gaussian
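For intuition, the same 12 combinations can be reproduced in base R with expand.grid(), i.e. a full factorial design. The neighbor values 1, 5, 10 and 15 are copied from the output above rather than derived; how dials places four levels within [1, 15] is not reproduced here. As in the grid_regular() output, the first column varies fastest.

```r
# Values taken from the grid_regular() output above
neighbor_values <- c(1L, 5L, 10L, 15L)
weight_values <- c("cos", "inv", "gaussian")

# Every neighbor value crossed with every weight function
regular_grid <- expand.grid(neighbors = neighbor_values,
                            weight_func = weight_values,
                            stringsAsFactors = FALSE)
nrow(regular_grid)  # 4 x 3 = 12
regular_grid
```

A random grid (grid_random()) instead samples combinations from the parameter ranges, so the number of candidates is set directly rather than by the number of levels.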
A metric is needed to measure the performance of each hyperparameter combination. Here the area under the ROC curve (ROC AUC) is used. The model is tuned over the regular grid:
# Performance metric
roc <- metric_set(roc_auc)
# Tuning
knn_tune <- knn_wflow |>
  tune_grid(iris_folds,
            grid = knn_params |> grid_regular(levels = 4),
            metrics = roc)
knn_tune

# Tuning results
# 10-fold cross-validation
# A tibble: 10 × 4
splits id .metrics .notes
<list> <chr> <list> <list>
1 <split [90/10]> Fold01 <tibble [12 × 6]> <tibble [0 × 3]>
2 <split [90/10]> Fold02 <tibble [12 × 6]> <tibble [0 × 3]>
3 <split [90/10]> Fold03 <tibble [12 × 6]> <tibble [0 × 3]>
4 <split [90/10]> Fold04 <tibble [12 × 6]> <tibble [0 × 3]>
5 <split [90/10]> Fold05 <tibble [12 × 6]> <tibble [0 × 3]>
6 <split [90/10]> Fold06 <tibble [12 × 6]> <tibble [0 × 3]>
7 <split [90/10]> Fold07 <tibble [12 × 6]> <tibble [0 × 3]>
8 <split [90/10]> Fold08 <tibble [12 × 6]> <tibble [0 × 3]>
9 <split [90/10]> Fold09 <tibble [12 × 6]> <tibble [0 × 3]>
10 <split [90/10]> Fold10 <tibble [12 × 6]> <tibble [0 × 3]>
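Conceptually, tune_grid() fits every candidate on the analysis part of each fold and scores it on the held-out part, giving the 12 metric rows per fold seen above. The base-R sketch below mimics that loop for k only. It is an illustration under several assumptions: unweighted voting (no kernel weighting, no predictor scaling), a rank-based (Mann-Whitney) ROC AUC, folds stratified by species so every held-out fold contains both classes, versicolor (the first factor level, which is yardstick's default event) as the event, and hypothetical helper names roc_auc_rank() and knn_prob(). Its numbers are not expected to match the kknn tuning results.

```r
# ROC AUC via ranks: probability that a random positive case scores higher
# than a random negative case, with ties counting one half
roc_auc_rank <- function(truth, score, positive) {
  is_pos <- truth == positive
  n_pos <- sum(is_pos)
  n_neg <- sum(!is_pos)
  r <- rank(score)  # average ranks handle ties
  (sum(r[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Share of the k nearest training neighbors that belong to 'positive'
knn_prob <- function(train_x, train_y, new_x, k, positive) {
  apply(new_x, 1, function(point) {
    d <- sqrt(colSums((t(train_x) - point)^2))
    mean(train_y[order(d)[seq_len(k)]] == positive)
  })
}

set.seed(1337)
dat <- droplevels(subset(iris, Species != "setosa"))
x <- as.matrix(dat[, 1:4])
y <- dat$Species
positive <- "versicolor"

# Stratified assignment of rows to 10 folds (an assumption; vfold_cv above is not stratified)
folds <- integer(nrow(dat))
for (lvl in levels(y)) {
  idx <- which(y == lvl)
  folds[idx] <- sample(rep(1:10, length.out = length(idx)))
}

k_values <- c(1, 5, 10, 15)

# Rows: folds; columns: candidate k values
auc <- sapply(k_values, function(k) {
  sapply(1:10, function(f) {
    held_out <- folds == f
    p <- knn_prob(x[!held_out, ], y[!held_out], x[held_out, , drop = FALSE],
                  k, positive)
    roc_auc_rank(y[held_out], p, positive)
  })
})

# Average over folds: one value per candidate
data.frame(neighbors = k_values, mean_auc = colMeans(auc))
```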
To visualize the performance:
knn_tune |>
  unnest(cols = .metrics) |>
  # Average the ROC AUC over the 10 folds for each candidate
  group_by(neighbors, weight_func) |>
  summarise(estimate_avg = mean(.estimate), .groups = "drop") |>
  ggplot(aes(x = neighbors,
             y = estimate_avg)) +
  geom_point(size = 3) +
  geom_line(linewidth = 0.7) +
  scale_x_continuous(breaks = c(1, 5, 10, 15)) +
  facet_wrap(~ weight_func) +
  theme(text = element_text(size = 13))
There appears to be no visible difference between the weight functions. For the number of neighbors, performance is highest at 10 and 15. Since the two perform comparably, 10 neighbors is chosen, together with the gaussian weight function.
final_hyperparams <- tibble(weight_func = "gaussian",
                            neighbors = 10)
final_knn_wflow <- knn_wflow |>
  finalize_workflow(final_hyperparams)
final_knn_wflow

══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: nearest_neighbor()
── Preprocessor ────────────────────────────────────────────────────────────────
0 Recipe Steps
── Model ───────────────────────────────────────────────────────────────────────
K-Nearest Neighbor Model Specification (classification)
Main Arguments:
neighbors = 10
weight_func = gaussian
Engine-Specific Arguments:
trace = 0
Computational engine: kknn
The model can now be fit to the data and used for prediction.
final_knn_fit <- final_knn_wflow |>
  fit(iris)
final_knn_fit

══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: nearest_neighbor()
── Preprocessor ────────────────────────────────────────────────────────────────
0 Recipe Steps
── Model ───────────────────────────────────────────────────────────────────────
Call:
kknn::train.kknn(formula = ..y ~ ., data = data, ks = min_rows(10, data, 5), kernel = ~"gaussian", trace = ~0)
Type of response variable: nominal
Minimal misclassification: 0.07
Best kernel: gaussian
Best k: 10
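Prediction with a fitted workflow uses predict(): by default it returns hard class predictions (.pred_class), and type = "prob" returns one probability column per outcome level. The sketch below rebuilds the finalized workflow so it runs on its own; in the session above one would simply call predict(final_knn_fit, new_data = ...). The first five training rows stand in for new observations purely for illustration (new_flowers is a hypothetical name), so these predictions are optimistic.

```r
library("tidymodels")
tidymodels::tidymodels_prefer()

dat <- iris |>
  tibble::as_tibble() |>
  filter(Species != "setosa") |>
  droplevels()

# Same specification as the finalized workflow: 10 neighbors, gaussian kernel
knn_fit <- workflow() |>
  add_recipe(recipe(Species ~ ., data = dat)) |>
  add_model(
    nearest_neighbor(neighbors = 10, weight_func = "gaussian") |>
      set_engine("kknn") |>
      set_mode("classification")
  ) |>
  fit(data = dat)

# Stand-in for new data; the outcome column is not needed for prediction
new_flowers <- dat |> slice(1:5) |> select(-Species)

predict(knn_fit, new_data = new_flowers)                 # column .pred_class
predict(knn_fit, new_data = new_flowers, type = "prob")  # .pred_versicolor, .pred_virginica
```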
sessioninfo::session_info()

─ Session info ───────────────────────────────────────────────────────────────
setting value
version R version 4.3.3 (2024-02-29 ucrt)
os Windows 11 x64 (build 22631)
system x86_64, mingw32
ui RTerm
language (EN)
collate English_United Kingdom.utf8
ctype English_United Kingdom.utf8
tz Europe/Copenhagen
date 2024-05-30
pandoc 3.1.11 @ C:/Program Files/RStudio/resources/app/bin/quarto/bin/tools/ (via rmarkdown)
─ Packages ───────────────────────────────────────────────────────────────────
package * version date (UTC) lib source
backports 1.4.1 2021-12-13 [1] CRAN (R 4.3.1)
broom * 1.0.5 2023-06-09 [1] CRAN (R 4.3.3)
cachem 1.0.8 2023-05-01 [1] CRAN (R 4.3.3)
class 7.3-22 2023-05-03 [2] CRAN (R 4.3.3)
cli 3.6.2 2023-12-11 [1] CRAN (R 4.3.3)
codetools 0.2-19 2023-02-01 [2] CRAN (R 4.3.3)
colorspace 2.1-0 2023-01-23 [1] CRAN (R 4.3.3)
conflicted 1.2.0 2023-02-01 [1] CRAN (R 4.3.3)
data.table 1.15.4 2024-03-30 [1] CRAN (R 4.3.3)
dials * 1.2.1 2024-02-22 [1] CRAN (R 4.3.3)
DiceDesign 1.10 2023-12-07 [1] CRAN (R 4.3.3)
digest 0.6.35 2024-03-11 [1] CRAN (R 4.3.3)
dplyr * 1.1.4 2023-11-17 [1] CRAN (R 4.3.2)
ellipsis 0.3.2 2021-04-29 [1] CRAN (R 4.3.3)
evaluate 0.23 2023-11-01 [1] CRAN (R 4.3.3)
fansi 1.0.6 2023-12-08 [1] CRAN (R 4.3.3)
farver 2.1.1 2022-07-06 [1] CRAN (R 4.3.3)
fastmap 1.1.1 2023-02-24 [1] CRAN (R 4.3.3)
foreach 1.5.2 2022-02-02 [1] CRAN (R 4.3.3)
furrr 0.3.1 2022-08-15 [1] CRAN (R 4.3.3)
future 1.33.2 2024-03-26 [1] CRAN (R 4.3.3)
future.apply 1.11.2 2024-03-28 [1] CRAN (R 4.3.3)
generics 0.1.3 2022-07-05 [1] CRAN (R 4.3.3)
ggplot2 * 3.5.1 2024-04-23 [1] CRAN (R 4.3.3)
globals 0.16.3 2024-03-08 [1] CRAN (R 4.3.3)
glue 1.7.0 2024-01-09 [1] CRAN (R 4.3.3)
gower 1.0.1 2022-12-22 [1] CRAN (R 4.3.1)
GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.3.3)
gtable 0.3.5 2024-04-22 [1] CRAN (R 4.3.3)
hardhat 1.3.1 2024-02-02 [1] CRAN (R 4.3.3)
htmltools 0.5.8.1 2024-04-04 [1] CRAN (R 4.3.3)
htmlwidgets 1.6.4 2023-12-06 [1] CRAN (R 4.3.3)
igraph 2.0.3 2024-03-13 [1] CRAN (R 4.3.3)
infer * 1.0.7 2024-03-25 [1] CRAN (R 4.3.3)
ipred 0.9-14 2023-03-09 [1] CRAN (R 4.3.3)
iterators 1.0.14 2022-02-05 [1] CRAN (R 4.3.3)
jsonlite 1.8.8 2023-12-04 [1] CRAN (R 4.3.3)
kknn * 1.3.1 2016-03-26 [1] CRAN (R 4.3.3)
knitr 1.46 2024-04-06 [1] CRAN (R 4.3.3)
labeling 0.4.3 2023-08-29 [1] CRAN (R 4.3.1)
lattice 0.22-5 2023-10-24 [2] CRAN (R 4.3.3)
lava 1.8.0 2024-03-05 [1] CRAN (R 4.3.3)
lhs 1.1.6 2022-12-17 [1] CRAN (R 4.3.3)
lifecycle 1.0.4 2023-11-07 [1] CRAN (R 4.3.3)
listenv 0.9.1 2024-01-29 [1] CRAN (R 4.3.3)
lubridate 1.9.3 2023-09-27 [1] CRAN (R 4.3.3)
magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.3.3)
MASS 7.3-60.0.1 2024-01-13 [2] CRAN (R 4.3.3)
Matrix 1.6-5 2024-01-11 [2] CRAN (R 4.3.3)
memoise 2.0.1 2021-11-26 [1] CRAN (R 4.3.3)
modeldata * 1.3.0 2024-01-21 [1] CRAN (R 4.3.3)
munsell 0.5.1 2024-04-01 [1] CRAN (R 4.3.3)
nnet 7.3-19 2023-05-03 [2] CRAN (R 4.3.3)
parallelly 1.37.1 2024-02-29 [1] CRAN (R 4.3.3)
parsnip * 1.2.1 2024-03-22 [1] CRAN (R 4.3.3)
pillar 1.9.0 2023-03-22 [1] CRAN (R 4.3.3)
pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.3.3)
prodlim 2023.08.28 2023-08-28 [1] CRAN (R 4.3.3)
purrr * 1.0.2 2023-08-10 [1] CRAN (R 4.3.3)
R6 2.5.1 2021-08-19 [1] CRAN (R 4.3.3)
Rcpp 1.0.12 2024-01-09 [1] CRAN (R 4.3.3)
recipes * 1.0.10 2024-02-18 [1] CRAN (R 4.3.3)
rlang 1.1.3 2024-01-10 [1] CRAN (R 4.3.3)
rmarkdown 2.26 2024-03-05 [1] CRAN (R 4.3.3)
rpart 4.1.23 2023-12-05 [2] CRAN (R 4.3.3)
rsample * 1.2.1 2024-03-25 [1] CRAN (R 4.3.3)
rstudioapi 0.16.0 2024-03-24 [1] CRAN (R 4.3.3)
scales * 1.3.0 2023-11-28 [1] CRAN (R 4.3.3)
sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.3.3)
survival 3.5-8 2024-02-14 [2] CRAN (R 4.3.3)
tibble * 3.2.1 2023-03-20 [1] CRAN (R 4.3.3)
tidymodels * 1.2.0 2024-03-25 [1] CRAN (R 4.3.3)
tidyr * 1.3.1 2024-01-24 [1] CRAN (R 4.3.3)
tidyselect 1.2.1 2024-03-11 [1] CRAN (R 4.3.3)
timechange 0.3.0 2024-01-18 [1] CRAN (R 4.3.3)
timeDate 4032.109 2023-12-14 [1] CRAN (R 4.3.2)
tune * 1.2.1 2024-04-18 [1] CRAN (R 4.3.3)
utf8 1.2.4 2023-10-22 [1] CRAN (R 4.3.3)
vctrs 0.6.5 2023-12-01 [1] CRAN (R 4.3.3)
withr 3.0.0 2024-01-16 [1] CRAN (R 4.3.3)
workflows * 1.1.4 2024-02-19 [1] CRAN (R 4.3.3)
workflowsets * 1.1.0 2024-03-21 [1] CRAN (R 4.3.3)
xfun 0.43 2024-03-25 [1] CRAN (R 4.3.3)
yaml 2.3.8 2023-12-11 [1] CRAN (R 4.3.2)
yardstick * 1.3.1 2024-03-21 [1] CRAN (R 4.3.3)
[1] C:/Users/Willi/AppData/Local/R/win-library/4.3
[2] C:/Program Files/R/R-4.3.3/library
──────────────────────────────────────────────────────────────────────────────