Metric types
There are five main metric types in yardstick: class, class probability, numeric, static survival, and dynamic survival. Each type of metric has a standardized argument syntax, and all metrics return the same kind of output: a tibble with the columns `.metric`, `.estimator`, and `.estimate` (dynamic survival metrics add an `.eval_time` column). This standardization makes it easy to combine metrics with `metric_set()` and to use them with grouped data frames to compute results for multiple resamples at once. Below are the five types of metrics, along with the types of inputs they take.
- Class metrics (hard predictions)
  - `truth` - factor
  - `estimate` - factor
- Class probability metrics (soft predictions)
  - `truth` - factor
  - `estimate / ...` - multiple numeric columns containing class probabilities
- Numeric metrics
  - `truth` - numeric
  - `estimate` - numeric
- Static survival metrics
  - `truth` - Surv
  - `estimate` - numeric
- Dynamic survival metrics
  - `truth` - Surv
  - `...` - list of data.frames, each containing the 3 columns `.eval_time`, `.pred_survival`, and `.weight_censored`
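As a minimal sketch of the shared argument syntax, the toy data below is made up purely for illustration (the column names `truth` and `estimate` are arbitrary choices). A class metric and a numeric metric are called the same way, and both return a tibble with the columns `.metric`, `.estimator`, and `.estimate`.

```r
library(yardstick)

# Made-up hard predictions: truth and estimate are factors with the same levels
class_df <- data.frame(
  truth    = factor(c("yes", "no", "yes", "yes"), levels = c("yes", "no")),
  estimate = factor(c("yes", "no", "no", "yes"), levels = c("yes", "no"))
)

# Made-up numeric predictions: truth and estimate are numeric columns
num_df <- data.frame(
  truth    = c(1, 2, 3),
  estimate = c(1.5, 2, 2.5)
)

# Same call pattern for both: data, then the truth column, then the estimate column
class_res <- accuracy(class_df, truth, estimate)  # 3 of 4 predictions match
num_res   <- rmse(num_df, truth, estimate)

class_res
num_res
```

Only the column types differ between the two calls; the shape of the output is identical.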
Example
In the following example, the `hpc_cv` data set is used. It contains class probabilities and class predictions for a linear discriminant analysis model fit to the HPC data set of Kuhn and Johnson (2013). The model was fit with 10-fold cross-validation, and the predictions for all folds are included.
library(yardstick)
library(dplyr)
data("hpc_cv")
hpc_cv %>%
group_by(Resample) %>%
slice(1:3)
#> # A tibble: 30 × 7
#> # Groups: Resample [10]
#> obs pred VF F M L Resample
#> <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr>
#> 1 VF VF 0.914 0.0779 0.00848 0.0000199 Fold01
#> 2 VF VF 0.938 0.0571 0.00482 0.0000101 Fold01
#> 3 VF VF 0.947 0.0495 0.00316 0.00000500 Fold01
#> 4 VF VF 0.941 0.0544 0.00441 0.0000123 Fold02
#> 5 VF VF 0.948 0.0483 0.00347 0.00000792 Fold02
#> 6 VF VF 0.958 0.0395 0.00236 0.00000310 Fold02
#> 7 VF VF 0.939 0.0556 0.00513 0.00000790 Fold03
#> 8 VF VF 0.928 0.0642 0.00777 0.0000148 Fold03
#> 9 VF VF 0.927 0.0653 0.00786 0.0000150 Fold03
#> 10 VF VF 0.949 0.0469 0.00398 0.00000935 Fold04
#> # ℹ 20 more rows
1 metric, 1 resample
hpc_cv %>%
filter(Resample == "Fold01") %>%
accuracy(obs, pred)
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 accuracy multiclass 0.726
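Accuracy is the proportion of rows where the predicted class equals the observed class. As a quick sanity check, computing that proportion directly in base R should agree with yardstick's vector interface, `accuracy_vec()`, for the same fold.

```r
library(yardstick)
library(dplyr)

data("hpc_cv")
fold01 <- filter(hpc_cv, Resample == "Fold01")

# Proportion of rows whose predicted class matches the observed class
manual_acc <- mean(fold01$obs == fold01$pred)

# The *_vec() variant takes vectors instead of a data frame and returns a number
vec_acc <- accuracy_vec(fold01$obs, fold01$pred)

c(manual = manual_acc, yardstick = vec_acc)
```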
1 metric, 10 resamples
hpc_cv %>%
group_by(Resample) %>%
accuracy(obs, pred)
#> # A tibble: 10 × 4
#> Resample .metric .estimator .estimate
#> <chr> <chr> <chr> <dbl>
#> 1 Fold01 accuracy multiclass 0.726
#> 2 Fold02 accuracy multiclass 0.712
#> 3 Fold03 accuracy multiclass 0.758
#> 4 Fold04 accuracy multiclass 0.712
#> 5 Fold05 accuracy multiclass 0.712
#> 6 Fold06 accuracy multiclass 0.697
#> 7 Fold07 accuracy multiclass 0.675
#> 8 Fold08 accuracy multiclass 0.721
#> 9 Fold09 accuracy multiclass 0.673
#> 10 Fold10 accuracy multiclass 0.699
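Because the grouped result is an ordinary tibble, it can be summarized further with dplyr. The sketch below averages accuracy across the 10 folds and adds a simple standard error; the summary column names `mean_accuracy`, `std_err`, and `n_folds` are illustrative choices, not part of yardstick.

```r
library(yardstick)
library(dplyr)

data("hpc_cv")

# One accuracy estimate per fold
fold_acc <- hpc_cv %>%
  group_by(Resample) %>%
  accuracy(obs, pred)

# Resampling estimate: mean across folds plus a simple standard error
acc_summary <- fold_acc %>%
  summarise(
    mean_accuracy = mean(.estimate),
    std_err       = sd(.estimate) / sqrt(n()),
    n_folds       = n()
  )

acc_summary
```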
2 metrics, 10 resamples
class_metrics <- metric_set(accuracy, kap)
hpc_cv %>%
group_by(Resample) %>%
class_metrics(obs, estimate = pred)
#> # A tibble: 20 × 4
#> Resample .metric .estimator .estimate
#> <chr> <chr> <chr> <dbl>
#> 1 Fold01 accuracy multiclass 0.726
#> 2 Fold02 accuracy multiclass 0.712
#> 3 Fold03 accuracy multiclass 0.758
#> 4 Fold04 accuracy multiclass 0.712
#> 5 Fold05 accuracy multiclass 0.712
#> 6 Fold06 accuracy multiclass 0.697
#> 7 Fold07 accuracy multiclass 0.675
#> 8 Fold08 accuracy multiclass 0.721
#> 9 Fold09 accuracy multiclass 0.673
#> 10 Fold10 accuracy multiclass 0.699
#> 11 Fold01 kap multiclass 0.533
#> 12 Fold02 kap multiclass 0.512
#> 13 Fold03 kap multiclass 0.594
#> 14 Fold04 kap multiclass 0.511
#> 15 Fold05 kap multiclass 0.514
#> 16 Fold06 kap multiclass 0.486
#> 17 Fold07 kap multiclass 0.454
#> 18 Fold08 kap multiclass 0.531
#> 19 Fold09 kap multiclass 0.454
#> 20 Fold10 kap multiclass 0.492
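`metric_set()` can also mix class metrics with class probability metrics. Following the calling convention shown in the `metric_set()` documentation, the probability columns are passed through `...` (here the `VF:L` column range) and the hard class predictions through the named `estimate` argument. The choice of `roc_auc` and `mn_log_loss` below is just an example.

```r
library(yardstick)
library(dplyr)

data("hpc_cv")

# Combine a hard-prediction metric with two class probability metrics
mixed_metrics <- metric_set(accuracy, roc_auc, mn_log_loss)

mixed_res <- hpc_cv %>%
  group_by(Resample) %>%
  # Probability columns go through `...`; class predictions via `estimate =`
  mixed_metrics(obs, VF:L, estimate = pred)

mixed_res
```

With 10 resamples and 3 metrics, the result has one row per fold and metric.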
Metrics
Below is a table of all of the metrics available in yardstick, grouped by type.
| type | metric |
|---|---|
| class | `accuracy()` |
| class | `bal_accuracy()` |
| class | `detection_prevalence()` |
| class | `f_meas()` |
| class | `j_index()` |
| class | `kap()` |
| class | `mcc()` |
| class | `npv()` |
| class | `ppv()` |
| class | `precision()` |
| class | `recall()` |
| class | `sens()` |
| class | `sensitivity()` |
| class | `spec()` |
| class | `specificity()` |
| class prob | `average_precision()` |
| class prob | `brier_class()` |
| class prob | `classification_cost()` |
| class prob | `gain_capture()` |
| class prob | `mn_log_loss()` |
| class prob | `pr_auc()` |
| class prob | `roc_auc()` |
| class prob | `roc_aunp()` |
| class prob | `roc_aunu()` |
| numeric | `ccc()` |
| numeric | `huber_loss()` |
| numeric | `huber_loss_pseudo()` |
| numeric | `iic()` |
| numeric | `mae()` |
| numeric | `mape()` |
| numeric | `mase()` |
| numeric | `mpe()` |
| numeric | `msd()` |
| numeric | `poisson_log_loss()` |
| numeric | `rmse()` |
| numeric | `rpd()` |
| numeric | `rpiq()` |
| numeric | `rsq()` |
| numeric | `rsq_trad()` |
| numeric | `smape()` |
| dynamic survival | `brier_survival()` |
| dynamic survival | `roc_auc_survival()` |
| static survival | `concordance_survival()` |
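The numeric and survival rows of the table follow the same calling pattern. The sketch below assumes the example data sets bundled with recent yardstick versions, `solubility_test` (columns `solubility` and `prediction`) and `lung_surv` (a `Surv` column `surv_obj`, a numeric `.pred_time`, and a list column `.pred`), and that the survival package is installed for the survival metrics.

```r
library(yardstick)

data("solubility_test")
data("lung_surv")

# Numeric metrics: numeric truth and numeric estimate
reg_metrics <- metric_set(rmse, rsq, mae)
reg_res <- reg_metrics(solubility_test, truth = solubility, estimate = prediction)

# Static survival metric: Surv truth and a numeric predicted time
conc_res <- concordance_survival(lung_surv, truth = surv_obj, estimate = .pred_time)

# Dynamic survival metric: Surv truth plus the list column of predictions,
# passed through `...`; the result has one row per evaluation time
brier_res <- brier_survival(lung_surv, truth = surv_obj, .pred)

reg_res
conc_res
brier_res
```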