These functions can be used to create a machine learning model based on different 'engines', and to predict outcomes from such models. They are wrappers around tidymodels packages (especially parsnip, recipes, rsample, tune, and yardstick) created by Posit (formerly RStudio).

ml_xg_boost(
  .data,
  outcome,
  predictors = everything(),
  training_fraction = 0.75,
  strata = NULL,
  na_threshold = 0.01,
  correlation_threshold = 0.9,
  centre = TRUE,
  scale = TRUE,
  engine = "xgboost",
  mode = c("classification", "regression", "unknown"),
  trees = 15,
  ...
)

ml_decision_trees(
  .data,
  outcome,
  predictors = everything(),
  training_fraction = 0.75,
  strata = NULL,
  na_threshold = 0.01,
  correlation_threshold = 0.9,
  centre = TRUE,
  scale = TRUE,
  engine = "rpart",
  mode = c("classification", "regression", "unknown"),
  tree_depth = 30,
  ...
)

ml_random_forest(
  .data,
  outcome,
  predictors = everything(),
  training_fraction = 0.75,
  strata = NULL,
  na_threshold = 0.01,
  correlation_threshold = 0.9,
  centre = TRUE,
  scale = TRUE,
  engine = "ranger",
  mode = c("classification", "regression", "unknown"),
  trees = 500,
  ...
)

ml_neural_network(
  .data,
  outcome,
  predictors = everything(),
  training_fraction = 0.75,
  strata = NULL,
  na_threshold = 0.01,
  correlation_threshold = 0.9,
  centre = TRUE,
  scale = TRUE,
  engine = "nnet",
  mode = c("classification", "regression", "unknown"),
  penalty = 0,
  epochs = 100,
  ...
)

ml_nearest_neighbour(
  .data,
  outcome,
  predictors = everything(),
  training_fraction = 0.75,
  strata = NULL,
  na_threshold = 0.01,
  correlation_threshold = 0.9,
  centre = TRUE,
  scale = TRUE,
  engine = "kknn",
  mode = c("classification", "regression", "unknown"),
  neighbors = 5,
  weight_func = "triangular",
  ...
)

ml_linear_regression(
  .data,
  outcome,
  predictors = everything(),
  training_fraction = 0.75,
  strata = NULL,
  na_threshold = 0.01,
  correlation_threshold = 0.9,
  centre = TRUE,
  scale = TRUE,
  engine = "lm",
  mode = "regression",
  ...
)

ml_logistic_regression(
  .data,
  outcome,
  predictors = everything(),
  training_fraction = 0.75,
  strata = NULL,
  na_threshold = 0.01,
  correlation_threshold = 0.9,
  centre = TRUE,
  scale = TRUE,
  engine = "glm",
  mode = "classification",
  penalty = 0.1,
  ...
)

# S3 method for class 'certestats_ml'
confusion_matrix(data, ...)

# S3 method for class 'certestats_ml'
predict(object, new_data, type = NULL, ...)

apply_model_to(
  object,
  new_data,
  add_certainty = TRUE,
  only_prediction = FALSE,
  correct_mistakes = TRUE,
  impute_algorithm = "mice",
  ...
)

feature_importances(object, ...)

feature_importance_plot(object, ...)

roc_plot(object, ...)

gain_plot(object, ...)

tree_plot(object, ...)

correlation_plot(
  data,
  add_values = TRUE,
  cols = everything(),
  correlation_threshold = 0.9
)

get_metrics(object)

get_accuracy(object)

get_kappa(object)

get_recipe(object)

get_specification(object)

get_rows_testing(object)

get_rows_training(object)

get_original_data(object)

get_roc_data(object)

get_coefficients(object)

get_model_variables(object)

get_variable_weights(object)

tune_parameters(object, ..., only_params_in_model = FALSE, levels = 5, k = 10)

check_testing_predictions(object)

# S3 method for class 'certestats_ml'
autoplot(object, plot_type = "roc", ...)

# S3 method for class 'certestats_feature_importances'
autoplot(object, ...)

# S3 method for class 'certestats_tuning'
autoplot(object, type = c("marginals", "parameters", "performance"), ...)

Arguments

.data

Data set to train the model on

outcome

Outcome variable, also called the response variable or the dependent variable; the variable that must be predicted. The value will be evaluated in select() and thus supports the tidyselect language. In case of classification prediction, this variable will be coerced to a factor.

predictors

Explanatory variables, also called the predictors or the independent variables; the variables that are used to predict outcome. These variables will be transformed using as.double() (factors will be transformed to characters first). This value defaults to everything() and supports the tidyselect language.

training_fraction

Fraction of rows to be used for training, defaults to 75%. The rest will be used for testing. If given a number over 1, the number will be considered to be the required number of rows for training.
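
For example, both forms (a minimal sketch, using the esbl_tests data from the Examples below):

esbl_tests |> ml_xg_boost(esbl, where(is.double), training_fraction = 0.8) # use 80% of rows for training
esbl_tests |> ml_xg_boost(esbl, where(is.double), training_fraction = 400) # use 400 rows for training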

strata

A variable in data (single character or name) used to conduct stratified sampling. When not NULL, each resample is created within the stratification variable. Numeric strata are binned into quartiles.

na_threshold

Maximum fraction of NA values (defaults to 0.01) allowed in a predictor before it is removed from the model, using recipes::step_select()

correlation_threshold

A value (default 0.9) to indicate the correlation threshold. Predictors with a correlation higher than this value will be removed from the model, using recipes::step_corr()

centre

A logical to indicate whether the predictors should be transformed so that their mean will be 0, using recipes::step_center(). Binary columns will be skipped.

scale

A logical to indicate whether the predictors should be transformed so that their standard deviation will be 1, using recipes::step_scale(). Binary columns will be skipped.

engine

R package or function name to be used for the model, which will be passed on to parsnip::set_engine()

mode

Type of predicted value; defaults to "classification", but can also be "unknown" or "regression"

trees

An integer for the number of trees contained in the ensemble.

...

Arguments to be passed on to the parsnip functions, see Model Functions.

For the tune_parameters() function, these must be dials package calls, such as dials::trees() (see Examples).

For predict(), these must be arguments passed on to parsnip::predict.model_fit()

tree_depth

An integer for maximum depth of the tree.

penalty

A non-negative number representing the total amount of regularization (specific engines only).

epochs

An integer for the number of training iterations.

neighbors

A single integer for the number of neighbors to consider (often called k). For kknn, a value of 5 is used if neighbors is not specified.

weight_func

A single character for the type of kernel function used to weight distances between samples. Valid choices are: "rectangular", "triangular", "epanechnikov", "biweight", "triweight", "cos", "inv", "gaussian", "rank", or "optimal".

object, data

A machine learning model as returned by one of the ml_*() functions

new_data

A rectangular data object, such as a data frame.

type

A single character value or NULL. Possible values are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", "time", "hazard", "survival", or "raw". When NULL, predict() will choose an appropriate value based on the model's mode.

add_certainty

a logical to indicate whether certainties should be added to the output data.frame

only_prediction

a logical to indicate whether predictions must be returned as a vector; otherwise, a data.frame is returned

correct_mistakes

a logical to indicate whether missing variables and missing values should be added to new_data

impute_algorithm

the algorithm to use in impute() if correct_mistakes = TRUE. Can be "mice" (default) for the Multivariate Imputations by Chained Equations (MICE) algorithm, or "single-point" for a trained median.

add_values

a logical to indicate whether values must be printed in the tiles

cols

columns to use for correlation plot, defaults to everything()

only_params_in_model

a logical to indicate whether only parameters in the model should be tuned

levels

An integer for the number of values of each parameter to use to make the regular grid. levels can be a single integer or a vector of integers that is the same length as the number of parameters in .... levels can be a named integer vector, with names that match the id values of parameters.

k

The number of partitions of the data set for the K-fold cross-validation

plot_type

the plot type, can be "roc" (default), "gain", "lift" or "pr". These functions rely on yardstick::roc_curve(), yardstick::gain_curve(), yardstick::lift_curve() and yardstick::pr_curve() to construct the curves.

Value

A machine learning model of class certestats_ml / ... / model_fit.

Details

To predict regression (numeric values), the function ml_logistic_regression() cannot be used.

To predict classifications (character values), the function ml_linear_regression() cannot be used.
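
For example, a minimal regression sketch on the built-in mtcars data (the chosen outcome and predictors are just illustrative):

# predict fuel use (mpg) from weight, horsepower and displacement:
car_model <- mtcars |> ml_linear_regression(mpg, c(wt, hp, disp))
car_model |> get_metrics() # regression metrics such as RMSE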

The workflow of the ml_*() functions is basically as follows (saving you from typing a lot of tidymodels functions yourself):


                      .data
                        |
              rsample::initial_split()
                    /        \
    rsample::training() rsample::testing()
            |                 |
     recipes::recipe()        |
            |                 |
    recipes::step_corr()      |
            |                 |
   recipes::step_center()     |
            |                 |
    recipes::step_scale()     |
            |                 |
      recipes::prep()         |
        /           \         |
recipes::bake()       recipes::bake()
       |                      |
generics::fit()      yardstick::metrics()
       |                      |
    output            attributes(output)
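
As a rough illustration, here is what that workflow looks like when written out in raw tidymodels code, using the esbl_tests data from the Examples below (a minimal sketch, not the exact internals of ml_xg_boost(); the intermediate names dat, split, rec, fit and baked_test are illustrative):

# keep the logical outcome and the numeric predictors; classification needs a factor outcome:
dat <- esbl_tests |>
  dplyr::select(esbl, where(is.double)) |>
  dplyr::mutate(esbl = factor(esbl))

split    <- rsample::initial_split(dat, prop = 0.75)
training <- rsample::training(split)
testing  <- rsample::testing(split)

# preprocessing recipe: drop highly correlated predictors, then centre and scale:
rec <- recipes::recipe(esbl ~ ., data = training) |>
  recipes::step_corr(recipes::all_numeric_predictors(), threshold = 0.9) |>
  recipes::step_center(recipes::all_numeric_predictors()) |>
  recipes::step_scale(recipes::all_numeric_predictors()) |>
  recipes::prep()

# fit the model specification on the baked training data:
fit <- parsnip::boost_tree(trees = 15) |>
  parsnip::set_engine("xgboost") |>
  parsnip::set_mode("classification") |>
  generics::fit(esbl ~ ., data = recipes::bake(rec, new_data = NULL))

# metrics on the baked testing data (stored by the ml_*() functions as an attribute):
baked_test <- recipes::bake(rec, new_data = testing)
predict(fit, new_data = baked_test) |>
  dplyr::bind_cols(baked_test["esbl"]) |>
  yardstick::metrics(truth = esbl, estimate = .pred_class)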

The predict() function can be used to apply a model to a new data set. Its wrapper apply_model_to() works in the same way, but can also detect and fix missing variables, missing data points, and data type differences between the trained data and the input data.

Use feature_importances() to get the importance of all features/variables. Use autoplot() afterwards to plot the results. These two functions are combined in feature_importance_plot().

Use correlation_plot() to plot the correlation between all variables, even character variables. If the input is a certestats ML model, the training data of the model will be plotted.

Use the get_model_variables() function to return a zero-row data.frame with the variables that were used for training, even before the recipe steps.

Use the get_variable_weights() function to determine the (rough) estimated weights of each variable in the model. This is not as reliable as retrieving coefficients, but it does work for any model. The weights are determined by running the model over all the highest and lowest values of each variable in the trained data. The function returns a data set with 1 row, of which the values sum up to 1.
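
For example, with the model1 object created in the Examples below (a minimal sketch; output omitted):

model1 |> get_model_variables()  # zero-row data.frame of the trained variables
model1 |> get_variable_weights() # one-row data.frame of which the values sum up to 1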

Use the tune_parameters() function to analyse the tuning parameters of any ml_*() function. Without any parameters manually defined, it will try to tune all parameters of the underlying ML model. The tuning will be based on a K-fold cross-validation, of which the number of partitions can be set with k. The number of levels will be used to split the range of the parameters. For example, a range of 1-10 with levels = 2 will lead to [1, 10], while levels = 5 will lead to [1, 3, 5, 7, 9]. The resulting data.frame will be sorted from best to worst. These results can also be plotted using autoplot().

The check_testing_predictions() function combines the testing rows of the original data with their predictions, so the original data can be reviewed per prediction.
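
For example (a sketch, again using model1 from the Examples below; output omitted):

model1 |> check_testing_predictions() # the testing rows of the original data plus their predictions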

Use autoplot() on a model to plot the receiver operating characteristic (ROC) curve, the gain curve, the lift curve, or the precision-recall (PR) curve. For the ROC curve, the (overall) area under the curve (AUC) will be printed as the subtitle.

Attributes

The ml_*() functions return the following attributes:

  • properties: a list with model properties: the ML function, engine package, training size, testing size, strata size, mode, and the different ML function-specific properties (such as tree_depth in ml_decision_trees())

  • recipe: a recipe as generated with recipes::prep(), to be used for training and testing

  • data_original: a data.frame containing the original data, possibly without invalid strata

  • data_structure: a data.frame containing the original data structure (only trained variables) with zero rows

  • data_means: a data.frame containing the means of the original data (only trained variables)

  • data_training: a data.frame containing the training data of data_original

  • data_testing: a data.frame containing the testing data of data_original

  • rows_training: an integer vector of rows used for training in data_original

  • rows_testing: an integer vector of rows used for testing in data_original

  • predictions: a data.frame containing predicted values based on the testing data

  • metrics: a data.frame with model metrics as returned by yardstick::metrics()

  • correlation_threshold: a logical indicating whether recipes::step_corr() has been applied

  • centre: a logical indicating whether recipes::step_center() has been applied

  • scale: a logical indicating whether recipes::step_scale() has been applied
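
Next to the dedicated get_*() helpers, all attributes can be retrieved with base R; a minimal sketch, using model1 from the Examples below:

attr(model1, "properties") # list with model properties
attr(model1, "recipe")     # the prepped recipe, also available via get_recipe()
attr(model1, "metrics")    # the model metrics, also available via get_metrics()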

Model Functions

These are the parsnip functions called under the hood. Arguments set in ... will be passed on to these parsnip functions:

  • ml_xg_boost(): parsnip::boost_tree()

  • ml_decision_trees(): parsnip::decision_tree()

  • ml_random_forest(): parsnip::rand_forest()

  • ml_neural_network(): parsnip::mlp()

  • ml_nearest_neighbour(): parsnip::nearest_neighbor()

  • ml_linear_regression(): parsnip::linear_reg()

  • ml_logistic_regression(): parsnip::logistic_reg()

Examples

# 'esbl_tests' is an included data set, see ?esbl_tests
print(esbl_tests, n = 5)
#> # A tibble: 500 × 19
#>   esbl  genus    AMC   AMP   TZP   CXM   FOX   CTX   CAZ   GEN   TOB   TMP   SXT
#>   <lgl> <chr>  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 FALSE Esche…    32    32     4    64    64  8     8        1     1  16      20
#> 2 FALSE Esche…    32    32     4    64    64  4     8        1     1  16     320
#> 3 FALSE Esche…     4     2    64     8     4  8     0.12    16    16   0.5    20
#> 4 FALSE Klebs…    32    32    16    64    64  8     8        1     1   0.5    20
#> 5 FALSE Esche…    32    32     4     4     4  0.25  2        1     1  16     320
#> # ℹ 495 more rows
#> # ℹ 6 more variables: NIT <dbl>, FOS <dbl>, CIP <dbl>, IPM <dbl>, MEM <dbl>,
#> #   COL <dbl>

esbl_tests |> correlation_plot(add_values = FALSE) # red will be removed from model


# predict ESBL test outcome based on MICs using 2 different models
model1 <- esbl_tests |> ml_xg_boost(esbl, where(is.double))
model2 <- esbl_tests |> ml_decision_trees(esbl, where(is.double))


# Assessing A Model ----------------------------------------------------

model1 |> get_metrics()
#>    .metric .estimator .estimate
#> 1 accuracy     binary 0.9680000
#> 2      kap     binary 0.9359795
model2 |> get_metrics()
#>    .metric .estimator .estimate
#> 1 accuracy     binary 0.9280000
#> 2      kap     binary 0.8558616

model1 |> confusion_matrix()
#> 
#> ── Confusion Matrix ────────────────────────────────────────────────────────────
#>        
#>         TRUE FALSE
#>   TRUE    62     1
#>   FALSE    3    59
#> 
#> 
#> ── Model Metrics ───────────────────────────────────────────────────────────────
#> 
#> Accuracy                                           0.968
#> Area under the Precision Recall Curve (APRC)       0.998
#> Area under the Receiver Operator Curve (AROC)      0.998
#> Balanced Accuracy                                  0.968
#> Brier Score for Classification Models (BSCM)       0.020
#> Concordance Correlation Coefficient (CCC)         -0.309
#> Costs Function for Poor Classification (CFPC)      0.065
#> F Measure                                          0.969
#> Gain Capture                                       0.996
#> Huber Loss                                         0.709
#> Index of Ideality of Correlation (IIC)               NaN
#> J-Index                                            0.936
#> Kappa                                              0.936
#> Matthews Correlation Coefficient (MCC)             0.936
#> Mean Absolute Error (MAE)                          0.981
#> Mean Absolute Percent Error (MAPE)                50.377
#> Mean Absolute Scaled Error (MASE)                121.592
#> Mean log Loss for Multinomial Data (MLMD)          0.083
#> Mean log Loss for Poisson Data (MLPD)              4.253
#> Mean Percentage Error (MPE)                       50.377
#> Mean Signed Deviation (MSD)                        0.981
#> Negative Predictive Value (NPV)                    0.983
#> Positive Predictive Value (PPV)                    0.954
#> Precision                                          0.954
#> Prevalence                                         0.520
#> Pseudo-Huber Loss (PHL)                            0.583
#> R Squared                                          0.924
#> R Squared - Traditional (RST)                     -6.402
#> Ratio of Performance to Deviation (RPD)            0.369
#> Ratio of Performance to Inter-Quartile (RPIQ)      0.735
#> Recall                                             0.984
#> Root Mean Squared Error (RMSE)                     1.360
#> Sensitivity                                        0.984
#> Specificity                                        0.952
#> Symmetric Mean Absolute Percentage Error (SMAPE)  95.769

# a correlation plot of a model shows the training data
model1 |> correlation_plot(add_values = FALSE)


model1 |> feature_importances()
#> # A tibble: 14 × 5
#>    feature    gain   cover frequency importance
#>  * <chr>     <dbl>   <dbl>     <dbl>      <dbl>
#>  1 CTX     0.666   0.324     0.152      0.495  
#>  2 FOX     0.131   0.163     0.137      0.139  
#>  3 TZP     0.0236  0.135     0.152      0.0715 
#>  4 GEN     0.0502  0.0970    0.0508     0.0597 
#>  5 CIP     0.0331  0.0470    0.117      0.0526 
#>  6 NIT     0.0259  0.0422    0.0761     0.0392 
#>  7 IPM     0.0175  0.0829    0.0558     0.0383 
#>  8 CAZ     0.0148  0.0433    0.102      0.0378 
#>  9 TMP     0.0250  0.0178    0.0609     0.0308 
#> 10 AMC     0.00437 0.0178    0.0305     0.0123 
#> 11 FOS     0.00317 0.0133    0.0254     0.00964
#> 12 TOB     0.00141 0.00480   0.0203     0.00587
#> 13 CXM     0.00267 0.00598   0.0152     0.00584
#> 14 COL     0.00131 0.00650   0.00508    0.00310
model1 |> feature_importances() |> autoplot()

model2 |> feature_importance_plot()


# decision trees can also have a tree plot
model2 |> tree_plot()



# Applying A Model -----------------------------------------------------
 
# simply use base R `predict()` to apply a model:
model1 |> predict(esbl_tests)
#> # A tibble: 500 × 1
#>    .pred_class
#>    <fct>      
#>  1 FALSE      
#>  2 FALSE      
#>  3 FALSE      
#>  4 FALSE      
#>  5 FALSE      
#>  6 FALSE      
#>  7 FALSE      
#>  8 FALSE      
#>  9 FALSE      
#> 10 FALSE      
#> # ℹ 490 more rows

# but apply_model_to() contains more info and can apply corrections:
model1 |> apply_model_to(esbl_tests)
#> # A tibble: 500 × 4
#>    predicted certainty .pred_TRUE .pred_FALSE
#>    <lgl>         <dbl>      <dbl>       <dbl>
#>  1 FALSE         0.990    0.00958       0.990
#>  2 FALSE         0.990    0.00958       0.990
#>  3 FALSE         0.655    0.345         0.655
#>  4 FALSE         0.984    0.0157        0.984
#>  5 FALSE         0.986    0.0141        0.986
#>  6 FALSE         0.924    0.0759        0.924
#>  7 FALSE         0.982    0.0185        0.982
#>  8 FALSE         0.985    0.0150        0.985
#>  9 FALSE         0.987    0.0133        0.987
#> 10 FALSE         0.946    0.0539        0.946
#> # ℹ 490 more rows
model1 |> apply_model_to(esbl_tests[, 1:15])
#> # A tibble: 500 × 4
#>    predicted certainty .pred_TRUE .pred_FALSE
#>    <lgl>         <dbl>      <dbl>       <dbl>
#>  1 FALSE         0.990    0.00958       0.990
#>  2 FALSE         0.990    0.00958       0.990
#>  3 TRUE          0.676    0.676         0.324
#>  4 FALSE         0.984    0.0157        0.984
#>  5 FALSE         0.982    0.0177        0.982
#>  6 FALSE         0.924    0.0759        0.924
#>  7 FALSE         0.978    0.0215        0.978
#>  8 FALSE         0.649    0.351         0.649
#>  9 FALSE         0.987    0.0133        0.987
#> 10 FALSE         0.983    0.0174        0.983
#> # ℹ 490 more rows
esbl_tests2 <- esbl_tests
esbl_tests2[2, "CIP"] <- NA
esbl_tests2[5, "AMC"] <- NA
# with XGBoost, nothing will be changed (it can handle missing values natively):
model1 |> apply_model_to(esbl_tests2)
#> # A tibble: 500 × 4
#>    predicted certainty .pred_TRUE .pred_FALSE
#>    <lgl>         <dbl>      <dbl>       <dbl>
#>  1 FALSE         0.990    0.00958       0.990
#>  2 FALSE         0.990    0.00958       0.990
#>  3 FALSE         0.655    0.345         0.655
#>  4 FALSE         0.984    0.0157        0.984
#>  5 FALSE         0.973    0.0275        0.973
#>  6 FALSE         0.924    0.0759        0.924
#>  7 FALSE         0.982    0.0185        0.982
#>  8 FALSE         0.985    0.0150        0.985
#>  9 FALSE         0.987    0.0133        0.987
#> 10 FALSE         0.946    0.0539        0.946
#> # ℹ 490 more rows
# with decision trees (or other engines), missing values will be imputed:
model2 |> apply_model_to(esbl_tests2)
#> Generating MICE using m = 5 multiple imputations... 
#> OK.
#> Imputed variable 'AMC' using MICE (method: predictive mean matching) in row 5
#> Imputed variable 'CIP' using MICE (method: predictive mean matching) in row 2
#> # A tibble: 500 × 4
#>    predicted certainty .pred_TRUE .pred_FALSE
#>    <lgl>         <dbl>      <dbl>       <dbl>
#>  1 FALSE         0.940     0.0595       0.940
#>  2 FALSE         0.940     0.0595       0.940
#>  3 TRUE          0.769     0.769        0.231
#>  4 FALSE         0.940     0.0595       0.940
#>  5 FALSE         0.940     0.0595       0.940
#>  6 FALSE         0.75      0.25         0.75 
#>  7 FALSE         0.940     0.0595       0.940
#>  8 FALSE         0.940     0.0595       0.940
#>  9 FALSE         0.940     0.0595       0.940
#> 10 FALSE         0.940     0.0595       0.940
#> # ℹ 490 more rows


# Tuning A Model -------------------------------------------------------
 
# tune the parameters of a model (will take some time)
tuning <- model2 |> 
  tune_parameters(k = 5, levels = 3)
#> Assuming tuning analysis for the 3 parameters 'cost_complexity', 'tree_depth', 'min_n'.
#> Use e.g. `cost_complexity = dials::cost_complexity()` to specify tuning for fewer parameters.
#> 
#> These parameters will be tuned with these values:
#>   - cost_complexity: 0.000000000100, 0.000003162278, 0.100000000000
#>   - tree_depth: 1, 8, 15
#>   - min_n: 2, 21, 40
#> [2024-11-05 19:31:20] Running tuning analysis using a 5-fold cross-validation for 27 combinations...
#> [2024-11-05 19:31:26] Done.
autoplot(tuning)


# tuning analysis by specifying (some) parameters
iris |> 
  ml_xg_boost(Species) |> 
  tune_parameters(mtry = dials::mtry(range = c(1, 3)),
                  trees = dials::trees())
#> 
#> These parameters will be tuned with these values:
#>   - mtry: 1, 2, 3
#>   - trees: 1, 500, 1000, 1500, 2000
#> [2024-11-05 19:31:27] Running tuning analysis using a 10-fold cross-validation for 15 combinations...
#> [2024-11-05 19:31:45] Done.
#> # A tibble: 15 × 10
#>     mtry trees     n .config            accuracy brier_class roc_auc accuracy_se
#>  * <int> <int> <int> <chr>                 <dbl>       <dbl>   <dbl>       <dbl>
#>  1     2     1    10 Preprocessor1_Mod…    0.955      0.220    0.975      0.0202
#>  2     3     1    10 Preprocessor1_Mod…    0.955      0.219    0.972      0.0202
#>  3     3  1500    10 Preprocessor1_Mod…    0.911      0.0773   0.971      0.0192
#>  4     3  2000    10 Preprocessor1_Mod…    0.911      0.0779   0.971      0.0192
#>  5     3   500    10 Preprocessor1_Mod…    0.911      0.0733   0.970      0.0192
#>  6     3  1000    10 Preprocessor1_Mod…    0.911      0.0762   0.970      0.0192
#>  7     2   500    10 Preprocessor1_Mod…    0.911      0.0741   0.969      0.0192
#>  8     2  1000    10 Preprocessor1_Mod…    0.911      0.0768   0.967      0.0192
#>  9     2  1500    10 Preprocessor1_Mod…    0.902      0.0784   0.967      0.0164
#> 10     2  2000    10 Preprocessor1_Mod…    0.902      0.0795   0.967      0.0164
#> 11     1  1000    10 Preprocessor1_Mod…    0.892      0.0847   0.965      0.0184
#> 12     1   500    10 Preprocessor1_Mod…    0.892      0.0812   0.963      0.0184
#> 13     1  1500    10 Preprocessor1_Mod…    0.892      0.0875   0.960      0.0184
#> 14     1  2000    10 Preprocessor1_Mod…    0.892      0.0888   0.956      0.0184
#> 15     1     1    10 Preprocessor1_Mod…    0.876      0.241    0.954      0.0333
#> # ℹ 2 more variables: brier_class_se <dbl>, roc_auc_se <dbl>


# Practical Example #1 --------------------------------------------------

# this is what the iris data set looks like:
head(iris)
#>   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
#> 1          5.1         3.5          1.4         0.2  setosa
#> 2          4.9         3.0          1.4         0.2  setosa
#> 3          4.7         3.2          1.3         0.2  setosa
#> 4          4.6         3.1          1.5         0.2  setosa
#> 5          5.0         3.6          1.4         0.2  setosa
#> 6          5.4         3.9          1.7         0.4  setosa
# create a model to predict the species:
iris_model <- iris |> ml_xg_boost(Species)
iris_model_rf <- iris |> ml_random_forest(Species)
# is it a bit reliable?
get_metrics(iris_model)
#>    .metric .estimator .estimate
#> 1 accuracy multiclass 0.9473684
#> 2      kap multiclass 0.9203354

# now try to predict species from an arbitrary data set:
to_predict <- data.frame(Sepal.Length = 5,
                         Sepal.Width = 3,
                         Petal.Length = 1.5,
                         Petal.Width = 0.5)
to_predict
#>   Sepal.Length Sepal.Width Petal.Length Petal.Width
#> 1            5           3          1.5         0.5

# should be 'setosa' in the 'predicted' column with huge certainty:
iris_model |> apply_model_to(to_predict)
#> # A tibble: 1 × 5
#>   predicted certainty .pred_setosa .pred_versicolor .pred_virginica
#>   <fct>         <dbl>        <dbl>            <dbl>           <dbl>
#> 1 setosa        0.980        0.980           0.0110         0.00865

# which variables are generally important (only trained variables)?
iris_model |> feature_importances()
#> # A tibble: 3 × 5
#>   feature        gain cover frequency importance
#> * <chr>         <dbl> <dbl>     <dbl>      <dbl>
#> 1 Petal.Width  0.945  0.695     0.466     0.800 
#> 2 Sepal.Length 0.0236 0.164     0.269     0.101 
#> 3 Sepal.Width  0.0310 0.141     0.264     0.0996

# how would the model do without the 'Sepal.Length' column?
to_predict <- to_predict[, c("Sepal.Width", "Petal.Width", "Petal.Length")]
to_predict
#>   Sepal.Width Petal.Width Petal.Length
#> 1           3         0.5          1.5
iris_model |> apply_model_to(to_predict)
#> # A tibble: 1 × 5
#>   predicted certainty .pred_setosa .pred_versicolor .pred_virginica
#>   <fct>         <dbl>        <dbl>            <dbl>           <dbl>
#> 1 setosa        0.979        0.979           0.0125         0.00864

# now compare that with a random forest model that requires imputation:
iris_model_rf |> apply_model_to(to_predict)
#> Adding missing variable as median (= 5.8): Sepal.Length
#> # A tibble: 1 × 5
#>   predicted certainty .pred_setosa .pred_versicolor .pred_virginica
#>   <fct>         <dbl>        <dbl>            <dbl>           <dbl>
#> 1 setosa        0.504        0.504            0.460          0.0353

# the certainty is very different.


# Practical Example #2 -------------------------------------------------

# this example shows plotting methods for a model

# train model to predict genus based on MICs:
genus <- esbl_tests |> ml_xg_boost(genus, everything())
genus |> get_metrics()
#>    .metric .estimator .estimate
#> 1 accuracy multiclass 0.9040000
#> 2      kap multiclass 0.8149747
genus |> feature_importance_plot()

genus |> autoplot()

genus |> autoplot(plot_type = "gain")

genus |> autoplot(plot_type = "pr")