These functions can be used to create a machine learning model based on different 'engines' and to generalise predicting outcomes based on such models. These functions are wrappers around tidymodels
packages (especially parsnip
, recipes
, rsample
, tune
, and yardstick
) created by RStudio.
ml_xg_boost(
.data,
outcome,
predictors = everything(),
training_fraction = 0.75,
strata = NULL,
na_threshold = 0.01,
correlation_threshold = 0.9,
centre = TRUE,
scale = TRUE,
engine = "xgboost",
mode = c("classification", "regression", "unknown"),
trees = 15,
...
)
ml_decision_trees(
.data,
outcome,
predictors = everything(),
training_fraction = 0.75,
strata = NULL,
na_threshold = 0.01,
correlation_threshold = 0.9,
centre = TRUE,
scale = TRUE,
engine = "rpart",
mode = c("classification", "regression", "unknown"),
tree_depth = 30,
...
)
ml_random_forest(
.data,
outcome,
predictors = everything(),
training_fraction = 0.75,
strata = NULL,
na_threshold = 0.01,
correlation_threshold = 0.9,
centre = TRUE,
scale = TRUE,
engine = "ranger",
mode = c("classification", "regression", "unknown"),
trees = 500,
...
)
ml_neural_network(
.data,
outcome,
predictors = everything(),
training_fraction = 0.75,
strata = NULL,
na_threshold = 0.01,
correlation_threshold = 0.9,
centre = TRUE,
scale = TRUE,
engine = "nnet",
mode = c("classification", "regression", "unknown"),
penalty = 0,
epochs = 100,
...
)
ml_nearest_neighbour(
.data,
outcome,
predictors = everything(),
training_fraction = 0.75,
strata = NULL,
na_threshold = 0.01,
correlation_threshold = 0.9,
centre = TRUE,
scale = TRUE,
engine = "kknn",
mode = c("classification", "regression", "unknown"),
neighbors = 5,
weight_func = "triangular",
...
)
ml_linear_regression(
.data,
outcome,
predictors = everything(),
training_fraction = 0.75,
strata = NULL,
na_threshold = 0.01,
correlation_threshold = 0.9,
centre = TRUE,
scale = TRUE,
engine = "lm",
mode = "regression",
...
)
ml_logistic_regression(
.data,
outcome,
predictors = everything(),
training_fraction = 0.75,
strata = NULL,
na_threshold = 0.01,
correlation_threshold = 0.9,
centre = TRUE,
scale = TRUE,
engine = "glm",
mode = "classification",
penalty = 0.1,
...
)
# S3 method for class 'certestats_ml'
confusion_matrix(data, ...)
# S3 method for class 'certestats_ml'
predict(object, new_data, type = NULL, ...)
apply_model_to(
object,
new_data,
add_certainty = TRUE,
only_prediction = FALSE,
correct_mistakes = TRUE,
impute_algorithm = "mice",
...
)
feature_importances(object, ...)
feature_importance_plot(object, ...)
roc_plot(object, ...)
gain_plot(object, ...)
tree_plot(object, ...)
correlation_plot(
data,
add_values = TRUE,
cols = everything(),
correlation_threshold = 0.9
)
get_metrics(object)
get_accuracy(object)
get_kappa(object)
get_recipe(object)
get_specification(object)
get_rows_testing(object)
get_rows_training(object)
get_original_data(object)
get_roc_data(object)
get_coefficients(object)
get_model_variables(object)
get_variable_weights(object)
tune_parameters(object, ..., only_params_in_model = FALSE, levels = 5, k = 10)
check_testing_predictions(object)
# S3 method for class 'certestats_ml'
autoplot(object, plot_type = "roc", ...)
# S3 method for class 'certestats_feature_importances'
autoplot(object, ...)
# S3 method for class 'certestats_tuning'
autoplot(object, type = c("marginals", "parameters", "performance"), ...)
Data set to train
Outcome variable, also called the response variable or the dependent variable; the variable that must be predicted. The value will be evaluated in select()
and thus supports the tidyselect
language. In case of classification prediction, this variable will be coerced to a factor.
Explanatory variables, also called the predictors or the independent variables; the variables that are used to predict outcome
. These variables will be transformed using as.double()
(factors will be transformed to characters first). This value defaults to everything()
and supports the tidyselect
language.
Fraction of rows to be used for training, defaults to 75%. The rest will be used for testing. If given a number over 1, the number will be considered to be the required number of rows for training.
A variable in data
(single character or name) used to conduct
stratified sampling. When not NULL
, each resample is created within the
stratification variable. Numeric strata
are binned into quartiles.
Maximum fraction of NA
values (defaults to 0.01
) of the predictors
before they are removed from the model, using recipes::step_select()
A value (default 0.9) to indicate the correlation threshold. Predictors with a correlation higher than this value will be removed from the model, using recipes::step_corr()
A logical to indicate whether the predictors
should be transformed so that their mean will be 0
, using recipes::step_center()
. Binary columns will be skipped.
A logical to indicate whether the predictors
should be transformed so that their standard deviation will be 1
, using recipes::step_scale()
. Binary columns will be skipped.
R package or function name to be used for the model, will be passed on to parsnip::set_engine()
Type of predicted value - defaults to "classification"
, but can also be "unknown"
or "regression"
An integer for the number of trees contained in the ensemble.
Arguments to be passed on to the parsnip
functions, see Model Functions.
For the tune_parameters()
function, these must be dials
package calls, such as dials::trees()
(see Examples).
For predict()
, these must be arguments passed on to parsnip::predict.model_fit()
An integer for maximum depth of the tree.
A non-negative number representing the total amount of regularization (specific engines only).
An integer for the number of training iterations.
A single integer for the number of neighbors
to consider (often called k
). For kknn, a value of 5
is used if neighbors
is not specified.
A single character for the type of kernel function used
to weight distances between samples. Valid choices are: "rectangular"
,
"triangular"
, "epanechnikov"
, "biweight"
, "triweight"
,
"cos"
, "inv"
, "gaussian"
, "rank"
, or "optimal"
.
outcome of machine learning model
A rectangular data object, such as a data frame.
A single character value or NULL
. Possible values
are "numeric"
, "class"
, "prob"
, "conf_int"
, "pred_int"
,
"quantile"
, "time"
, "hazard"
, "survival"
, or "raw"
. When NULL
,
predict()
will choose an appropriate value based on the model's mode.
a logical to indicate whether certainties should be added to the output data.frame
a logical to indicate whether predictions must be returned as vector, otherwise returns a data.frame
a logical to indicate whether missing variables and missing values should be added to new_data
the algorithm to use in impute()
if correct_mistakes = TRUE
. Can be "mice"
(default) for the Multivariate Imputations by Chained Equations (MICE) algorithm, or "single-point"
for a trained median.
a logical to indicate whether values must be printed in the tiles
columns to use for correlation plot, defaults to everything()
a logical to indicate whether only parameters in the model should be tuned
An integer for the number of values of each parameter to use
to make the regular grid. levels
can be a single integer or a vector of
integers that is the same length as the number of parameters in ...
.
levels
can be a named integer vector, with names that match the id values
of parameters.
The number of partitions of the data set
the plot type, can be "roc"
(default), "gain"
, "lift"
or "pr"
. These functions rely on yardstick::roc_curve()
, yardstick::gain_curve()
, yardstick::lift_curve()
and yardstick::pr_curve()
to construct the curves.
A machine learning model of class certestats_ml
/ ... / model_fit
.
To predict regression (numeric values), the function ml_logistic_regression()
cannot be used.
To predict classifications (character values), the function ml_linear_regression()
cannot be used.
The workflow of the ml_*()
functions is basically like this (thus saving a lot of tidymodels
functions to type):
.data
|
rsample::initial_split()
/ \
rsample::training() rsample::testing()
| |
recipes::recipe() |
| |
recipes::step_corr() |
| |
recipes::step_center() |
| |
recipes::step_scale() |
| |
recipes::prep() |
/ \ |
recipes::bake() recipes::bake()
| |
generics::fit() yardstick::metrics()
| |
output attributes(output)
The predict()
function can be used to apply a trained model to a new data set. Its wrapper apply_model_to()
works in the same way, but can also detect and fix missing variables, missing data points, and data type differences between the trained data and the input data.
Use feature_importances()
to get the importance of all features/variables. Use autoplot()
afterwards to plot the results. These two functions are combined in feature_importance_plot()
.
Use correlation_plot()
to plot the correlation between all variables, even characters. If the input is a certestats
ML model, the training data of the model will be plotted.
Use the get_model_variables()
function to return a zero-row data.frame with the variables that were used for training, even before the recipe steps.
Use the get_variable_weights()
function to determine the (rough) estimated weights of each variable in the model. This is not as reliable as retrieving coefficients, but it does work for any model. The weights are determined by running the model over all the highest and lowest values of each variable in the trained data. The function returns a data set with 1 row, of which the values sum up to 1.
Use the tune_parameters()
function to analyse tuning parameters of any ml_*()
function. Without any parameters manually defined, it will try to tune all parameters of the underlying ML model. The tuning will be based on a K-fold cross-validation, of which the number of partitions can be set with k
. The number of levels
will be used to split the range of the parameters. For example, a range of 1-10 with levels = 2
will lead to [1, 10]
, while levels = 5
will lead to [1, 3, 5, 7, 9]
. The resulting data.frame will be sorted from best to worst. These results can also be plotted using autoplot()
.
The check_testing_predictions()
function combines the data used for testing from the original data with its predictions, so the original data can be reviewed per prediction.
Use autoplot()
on a model to plot the receiver operating characteristic (ROC) curve, the gain curve, the lift curve, or the precision-recall (PR) curve. For the ROC curve, the (overall) area under the curve (AUC) will be printed as subtitle.
The ml_*()
functions return the following attributes:
properties
: a list with model properties: the ML function, engine package, training size, testing size, strata size, mode, and the different ML function-specific properties (such as tree_depth
in ml_decision_trees()
)
recipe
: a recipe as generated with recipes::prep()
, to be used for training and testing
data_original
: a data.frame containing the original data, possibly without invalid strata
data_structure
: a data.frame containing the original data structure (only trained variables) with zero rows
data_means
: a data.frame containing the means of the original data (only trained variables)
data_training
: a data.frame containing the training data of data_original
data_testing
: a data.frame containing the testing data of data_original
rows_training
: an integer vector of rows used for training in data_original
rows_testing
: an integer vector of rows used for testing in data_original
predictions
: a data.frame containing predicted values based on the testing data
metrics
: a data.frame with model metrics as returned by yardstick::metrics()
correlation_threshold
: a logical indicating whether recipes::step_corr()
has been applied
centre
: a logical indicating whether recipes::step_center()
has been applied
scale
: a logical indicating whether recipes::step_scale()
has been applied
These are the called functions from the parsnip
package. Arguments set in ...
will be passed on to these parsnip
functions:
ml_decision_trees
: parsnip::decision_tree()
ml_linear_regression
: parsnip::linear_reg()
ml_logistic_regression
: parsnip::logistic_reg()
ml_neural_network
: parsnip::mlp()
ml_nearest_neighbour
: parsnip::nearest_neighbor()
ml_random_forest
: parsnip::rand_forest()
ml_xg_boost
: parsnip::xgb_train()
# 'esbl_tests' is an included data set, see ?esbl_tests
print(esbl_tests, n = 5)
#> # A tibble: 500 × 19
#> esbl genus AMC AMP TZP CXM FOX CTX CAZ GEN TOB TMP SXT
#> <lgl> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 FALSE Esche… 32 32 4 64 64 8 8 1 1 16 20
#> 2 FALSE Esche… 32 32 4 64 64 4 8 1 1 16 320
#> 3 FALSE Esche… 4 2 64 8 4 8 0.12 16 16 0.5 20
#> 4 FALSE Klebs… 32 32 16 64 64 8 8 1 1 0.5 20
#> 5 FALSE Esche… 32 32 4 4 4 0.25 2 1 1 16 320
#> # ℹ 495 more rows
#> # ℹ 6 more variables: NIT <dbl>, FOS <dbl>, CIP <dbl>, IPM <dbl>, MEM <dbl>,
#> # COL <dbl>
esbl_tests |> correlation_plot(add_values = FALSE) # red will be removed from model
# predict ESBL test outcome based on MICs using 2 different models
model1 <- esbl_tests |> ml_xg_boost(esbl, where(is.double))
model2 <- esbl_tests |> ml_decision_trees(esbl, where(is.double))
# Assessing A Model ----------------------------------------------------
model1 |> get_metrics()
#> .metric .estimator .estimate
#> 1 accuracy binary 0.9680000
#> 2 kap binary 0.9359795
model2 |> get_metrics()
#> .metric .estimator .estimate
#> 1 accuracy binary 0.9280000
#> 2 kap binary 0.8558616
model1 |> confusion_matrix()
#>
#> ── Confusion Matrix ────────────────────────────────────────────────────────────
#>
#> TRUE FALSE
#> TRUE 62 1
#> FALSE 3 59
#>
#>
#> ── Model Metrics ───────────────────────────────────────────────────────────────
#>
#> Accuracy 0.968
#> Area under the Precision Recall Curve (APRC) 0.998
#> Area under the Receiver Operator Curve (AROC) 0.998
#> Balanced Accuracy 0.968
#> Brier Score for Classification Models (BSCM) 0.020
#> Concordance Correlation Coefficient (CCC) -0.309
#> Costs Function for Poor Classification (CFPC) 0.065
#> F Measure 0.969
#> Gain Capture 0.996
#> Huber Loss 0.709
#> Index of Ideality of Correlation (IIC) NaN
#> J-Index 0.936
#> Kappa 0.936
#> Matthews Correlation Coefficient (MCC) 0.936
#> Mean Absolute Error (MAE) 0.981
#> Mean Absolute Percent Error (MAPE) 50.377
#> Mean Absolute Scaled Error (MASE) 121.592
#> Mean log Loss for Multinomial Data (MLMD) 0.083
#> Mean log Loss for Poisson Data (MLPD) 4.253
#> Mean Percentage Error (MPE) 50.377
#> Mean Signed Deviation (MSD) 0.981
#> Negative Predictive Value (NPV) 0.983
#> Positive Predictive Value (PPV) 0.954
#> Precision 0.954
#> Prevalence 0.520
#> Psuedo-Huber Loss (PHL) 0.583
#> R Squared 0.924
#> R Squared - Traditional (RST) -6.402
#> Ratio of Performance to Deviation (RPD) 0.369
#> Ratio of Performance to Inter-Quartile (RPIQ) 0.735
#> Recall 0.984
#> Root Mean Squared Error (RMSE) 1.360
#> Sensitivity 0.984
#> Specificity 0.952
#> Symmetric Mean Absolute Percentage Error (SMAPE) 95.769
# a correlation plot of a model shows the training data
model1 |> correlation_plot(add_values = FALSE)
model1 |> feature_importances()
#> # A tibble: 14 × 5
#> feature gain cover frequency importance
#> * <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 CTX 0.666 0.324 0.152 0.495
#> 2 FOX 0.131 0.163 0.137 0.139
#> 3 TZP 0.0236 0.135 0.152 0.0715
#> 4 GEN 0.0502 0.0970 0.0508 0.0597
#> 5 CIP 0.0331 0.0470 0.117 0.0526
#> 6 NIT 0.0259 0.0422 0.0761 0.0392
#> 7 IPM 0.0175 0.0829 0.0558 0.0383
#> 8 CAZ 0.0148 0.0433 0.102 0.0378
#> 9 TMP 0.0250 0.0178 0.0609 0.0308
#> 10 AMC 0.00437 0.0178 0.0305 0.0123
#> 11 FOS 0.00317 0.0133 0.0254 0.00964
#> 12 TOB 0.00141 0.00480 0.0203 0.00587
#> 13 CXM 0.00267 0.00598 0.0152 0.00584
#> 14 COL 0.00131 0.00650 0.00508 0.00310
model1 |> feature_importances() |> autoplot()
model2 |> feature_importance_plot()
# decision trees can also have a tree plot
model2 |> tree_plot()
# Applying A Model -----------------------------------------------------
# simply use base R `predict()` to apply a model:
model1 |> predict(esbl_tests)
#> # A tibble: 500 × 1
#> .pred_class
#> <fct>
#> 1 FALSE
#> 2 FALSE
#> 3 FALSE
#> 4 FALSE
#> 5 FALSE
#> 6 FALSE
#> 7 FALSE
#> 8 FALSE
#> 9 FALSE
#> 10 FALSE
#> # ℹ 490 more rows
# but apply_model_to() contains more info and can apply corrections:
model1 |> apply_model_to(esbl_tests)
#> # A tibble: 500 × 4
#> predicted certainty .pred_TRUE .pred_FALSE
#> <lgl> <dbl> <dbl> <dbl>
#> 1 FALSE 0.990 0.00958 0.990
#> 2 FALSE 0.990 0.00958 0.990
#> 3 FALSE 0.655 0.345 0.655
#> 4 FALSE 0.984 0.0157 0.984
#> 5 FALSE 0.986 0.0141 0.986
#> 6 FALSE 0.924 0.0759 0.924
#> 7 FALSE 0.982 0.0185 0.982
#> 8 FALSE 0.985 0.0150 0.985
#> 9 FALSE 0.987 0.0133 0.987
#> 10 FALSE 0.946 0.0539 0.946
#> # ℹ 490 more rows
model1 |> apply_model_to(esbl_tests[, 1:15])
#> # A tibble: 500 × 4
#> predicted certainty .pred_TRUE .pred_FALSE
#> <lgl> <dbl> <dbl> <dbl>
#> 1 FALSE 0.990 0.00958 0.990
#> 2 FALSE 0.990 0.00958 0.990
#> 3 TRUE 0.676 0.676 0.324
#> 4 FALSE 0.984 0.0157 0.984
#> 5 FALSE 0.982 0.0177 0.982
#> 6 FALSE 0.924 0.0759 0.924
#> 7 FALSE 0.978 0.0215 0.978
#> 8 FALSE 0.649 0.351 0.649
#> 9 FALSE 0.987 0.0133 0.987
#> 10 FALSE 0.983 0.0174 0.983
#> # ℹ 490 more rows
esbl_tests2 <- esbl_tests
esbl_tests2[2, "CIP"] <- NA
esbl_tests2[5, "AMC"] <- NA
# with XGBoost, nothing will be changed (it can correct for missings):
model1 |> apply_model_to(esbl_tests2)
#> # A tibble: 500 × 4
#> predicted certainty .pred_TRUE .pred_FALSE
#> <lgl> <dbl> <dbl> <dbl>
#> 1 FALSE 0.990 0.00958 0.990
#> 2 FALSE 0.990 0.00958 0.990
#> 3 FALSE 0.655 0.345 0.655
#> 4 FALSE 0.984 0.0157 0.984
#> 5 FALSE 0.973 0.0275 0.973
#> 6 FALSE 0.924 0.0759 0.924
#> 7 FALSE 0.982 0.0185 0.982
#> 8 FALSE 0.985 0.0150 0.985
#> 9 FALSE 0.987 0.0133 0.987
#> 10 FALSE 0.946 0.0539 0.946
#> # ℹ 490 more rows
# with decision trees (or other models), missing values will be imputed:
model2 |> apply_model_to(esbl_tests2)
#> Generating MICE using m = 5 multiple imputations...
#> OK.
#> Imputed variable 'AMC' using MICE (method: predictive mean matching) in row 5
#> Imputed variable 'CIP' using MICE (method: predictive mean matching) in row 2
#> # A tibble: 500 × 4
#> predicted certainty .pred_TRUE .pred_FALSE
#> <lgl> <dbl> <dbl> <dbl>
#> 1 FALSE 0.940 0.0595 0.940
#> 2 FALSE 0.940 0.0595 0.940
#> 3 TRUE 0.769 0.769 0.231
#> 4 FALSE 0.940 0.0595 0.940
#> 5 FALSE 0.940 0.0595 0.940
#> 6 FALSE 0.75 0.25 0.75
#> 7 FALSE 0.940 0.0595 0.940
#> 8 FALSE 0.940 0.0595 0.940
#> 9 FALSE 0.940 0.0595 0.940
#> 10 FALSE 0.940 0.0595 0.940
#> # ℹ 490 more rows
# Tuning A Model -------------------------------------------------------
# tune the parameters of a model (will take some time)
tuning <- model2 |>
tune_parameters(k = 5, levels = 3)
#> Assuming tuning analysis for the 3 parameters 'cost_complexity', 'tree_depth', 'min_n'.
#> Use e.g. `cost_complexity = dials::cost_complexity()` to specify tuning for less parameters.
#>
#> These parameters will be tuned with these values:
#> - cost_complexity: 0.000000000100, 0.000003162278, 0.100000000000
#> - tree_depth: 1, 8, 15
#> - min_n: 2, 21, 40
#> [2024-11-05 19:31:20] Running tuning analysis using a 5-fold cross-validation for 27 combinations...
#> [2024-11-05 19:31:26] Done.
autoplot(tuning)
# tuning analysis by specifying (some) parameters
iris |>
ml_xg_boost(Species) |>
tune_parameters(mtry = dials::mtry(range = c(1, 3)),
trees = dials::trees())
#>
#> These parameters will be tuned with these values:
#> - mtry: 1, 2, 3
#> - trees: 1, 500, 1000, 1500, 2000
#> [2024-11-05 19:31:27] Running tuning analysis using a 10-fold cross-validation for 15 combinations...
#> [2024-11-05 19:31:45] Done.
#> # A tibble: 15 × 10
#> mtry trees n .config accuracy brier_class roc_auc accuracy_se
#> * <int> <int> <int> <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 2 1 10 Preprocessor1_Mod… 0.955 0.220 0.975 0.0202
#> 2 3 1 10 Preprocessor1_Mod… 0.955 0.219 0.972 0.0202
#> 3 3 1500 10 Preprocessor1_Mod… 0.911 0.0773 0.971 0.0192
#> 4 3 2000 10 Preprocessor1_Mod… 0.911 0.0779 0.971 0.0192
#> 5 3 500 10 Preprocessor1_Mod… 0.911 0.0733 0.970 0.0192
#> 6 3 1000 10 Preprocessor1_Mod… 0.911 0.0762 0.970 0.0192
#> 7 2 500 10 Preprocessor1_Mod… 0.911 0.0741 0.969 0.0192
#> 8 2 1000 10 Preprocessor1_Mod… 0.911 0.0768 0.967 0.0192
#> 9 2 1500 10 Preprocessor1_Mod… 0.902 0.0784 0.967 0.0164
#> 10 2 2000 10 Preprocessor1_Mod… 0.902 0.0795 0.967 0.0164
#> 11 1 1000 10 Preprocessor1_Mod… 0.892 0.0847 0.965 0.0184
#> 12 1 500 10 Preprocessor1_Mod… 0.892 0.0812 0.963 0.0184
#> 13 1 1500 10 Preprocessor1_Mod… 0.892 0.0875 0.960 0.0184
#> 14 1 2000 10 Preprocessor1_Mod… 0.892 0.0888 0.956 0.0184
#> 15 1 1 10 Preprocessor1_Mod… 0.876 0.241 0.954 0.0333
#> # ℹ 2 more variables: brier_class_se <dbl>, roc_auc_se <dbl>
# Practical Example #1 --------------------------------------------------
# this is what iris data set looks like:
head(iris)
#> Sepal.Length Sepal.Width Petal.Length Petal.Width Species
#> 1 5.1 3.5 1.4 0.2 setosa
#> 2 4.9 3.0 1.4 0.2 setosa
#> 3 4.7 3.2 1.3 0.2 setosa
#> 4 4.6 3.1 1.5 0.2 setosa
#> 5 5.0 3.6 1.4 0.2 setosa
#> 6 5.4 3.9 1.7 0.4 setosa
# create a model to predict the species:
iris_model <- iris |> ml_xg_boost(Species)
iris_model_rf <- iris |> ml_random_forest(Species)
# is it a bit reliable?
get_metrics(iris_model)
#> .metric .estimator .estimate
#> 1 accuracy multiclass 0.9473684
#> 2 kap multiclass 0.9203354
# now try to predict species from an arbitrary data set:
to_predict <- data.frame(Sepal.Length = 5,
Sepal.Width = 3,
Petal.Length = 1.5,
Petal.Width = 0.5)
to_predict
#> Sepal.Length Sepal.Width Petal.Length Petal.Width
#> 1 5 3 1.5 0.5
# should be 'setosa' in the 'predicted' column with huge certainty:
iris_model |> apply_model_to(to_predict)
#> # A tibble: 1 × 5
#> predicted certainty .pred_setosa .pred_versicolor .pred_virginica
#> <fct> <dbl> <dbl> <dbl> <dbl>
#> 1 setosa 0.980 0.980 0.0110 0.00865
# which variables are generally important (only trained variables)?
iris_model |> feature_importances()
#> # A tibble: 3 × 5
#> feature gain cover frequency importance
#> * <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 Petal.Width 0.945 0.695 0.466 0.800
#> 2 Sepal.Length 0.0236 0.164 0.269 0.101
#> 3 Sepal.Width 0.0310 0.141 0.264 0.0996
# how would the model do without the 'Sepal.Length' column?
to_predict <- to_predict[, c("Sepal.Width", "Petal.Width", "Petal.Length")]
to_predict
#> Sepal.Width Petal.Width Petal.Length
#> 1 3 0.5 1.5
iris_model |> apply_model_to(to_predict)
#> # A tibble: 1 × 5
#> predicted certainty .pred_setosa .pred_versicolor .pred_virginica
#> <fct> <dbl> <dbl> <dbl> <dbl>
#> 1 setosa 0.979 0.979 0.0125 0.00864
# now compare that with a random forest model that requires imputation:
iris_model_rf |> apply_model_to(to_predict)
#> Adding missing variable as median (= 5.8): Sepal.Length
#> # A tibble: 1 × 5
#> predicted certainty .pred_setosa .pred_versicolor .pred_virginica
#> <fct> <dbl> <dbl> <dbl> <dbl>
#> 1 setosa 0.504 0.504 0.460 0.0353
# the certainty is very different.
# Practical Example #2 -------------------------------------------------
# this example shows plotting methods for a model
# train model to predict genus based on MICs:
genus <- esbl_tests |> ml_xg_boost(genus, everything())
genus |> get_metrics()
#> .metric .estimator .estimate
#> 1 accuracy multiclass 0.9040000
#> 2 kap multiclass 0.8149747
genus |> feature_importance_plot()
genus |> autoplot()
genus |> autoplot(plot_type = "gain")
genus |> autoplot(plot_type = "pr")