The AUC (Area under the curve) of the ROC (Receiver operating characteristic; default) or PR (Precision-Recall) curve is a quality measure of binary classifiers. Unlike accuracy, and like the cross-entropy loss, ROC-AUC and PR-AUC evaluate all the operational points of a model.
This class approximates AUCs using a Riemann sum. During the metric accumulation phase, predictions are accumulated within predefined buckets by value. The AUC is then computed by interpolating per-bucket averages. These buckets define the evaluated operational points.
This metric creates four local variables, true_positives, true_negatives, false_positives and false_negatives, that are used to compute the AUC. To discretize the AUC curve, a linearly spaced set of thresholds is used to compute pairs of recall and precision values. The area under the ROC curve is therefore computed using the height of the recall values by the false positive rate, while the area under the PR curve is computed using the height of the precision values by the recall.
This value is ultimately returned as auc, an idempotent operation that computes the area under a discretized curve of precision versus recall values (computed using the aforementioned variables). The num_thresholds variable controls the degree of discretization, with larger numbers of thresholds more closely approximating the true AUC. The quality of the approximation may vary dramatically depending on num_thresholds. The thresholds parameter can be used to manually specify thresholds which split the predictions more evenly.
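As an illustration of this discretization (a minimal base R sketch, not the metric's actual implementation), the ROC-AUC for the toy data used in the standalone example below can be reproduced by hand:

y_true <- c(0, 0, 1, 1)
y_pred <- c(0, 0.5, 0.3, 0.9)
thresholds <- c(0 - 1e-7, 0.5, 1 + 1e-7)

# Per-threshold confusion-matrix counts (a prediction counts as
# positive when it exceeds the threshold)
tp <- sapply(thresholds, function(t) sum(y_pred > t & y_true == 1))
fp <- sapply(thresholds, function(t) sum(y_pred > t & y_true == 0))
fn <- sapply(thresholds, function(t) sum(y_pred <= t & y_true == 1))
tn <- sapply(thresholds, function(t) sum(y_pred <= t & y_true == 0))

recall  <- tp / (tp + fn)  # true positive rate
fp_rate <- fp / (fp + tn)  # false positive rate

# Trapezoidal ("interpolation") Riemann sum over the discretized ROC curve
n <- length(thresholds)
sum((recall[-n] + recall[-1]) / 2 * (fp_rate[-n] - fp_rate[-1]))
## [1] 0.75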
For the best approximation of the real AUC, predictions should be distributed approximately uniformly in the range [0, 1] (if from_logits=FALSE). The quality of the AUC approximation may be poor if this is not the case. Setting summation_method to 'minoring' or 'majoring' can help quantify the error in the approximation by providing a lower or an upper bound estimate of the AUC.
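For example (an illustrative sketch using only the constructor arguments shown in the usage below), a pair of metrics can bracket the true AUC:

# Lower and upper bound estimates of the same AUC
auc_lower <- metric_auc(num_thresholds = 200L, summation_method = "minoring")
auc_upper <- metric_auc(num_thresholds = 200L, summation_method = "majoring")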
If sample_weight is NULL, weights default to 1. Use sample_weight of 0 to mask values.
metric_auc(
  ...,
  num_thresholds = 200L,
  curve = "ROC",
  summation_method = "interpolation",
  name = NULL,
  dtype = NULL,
  thresholds = NULL,
  multi_label = FALSE,
  num_labels = NULL,
  label_weights = NULL,
  from_logits = FALSE
)
A Metric instance is returned. The Metric instance can be passed directly to compile(metrics = ), or used as a standalone object. See ?Metric for example usage.
...: For forward/backward compatibility.
num_thresholds: (Optional) The number of thresholds to use when discretizing the ROC curve. Values must be > 1. Defaults to 200.
curve: (Optional) Specifies the name of the curve to be computed, 'ROC' (default) or 'PR' for the Precision-Recall curve.
summation_method: (Optional) Specifies the Riemann summation method used. 'interpolation' (default) applies the mid-point summation scheme for ROC. For PR-AUC, it interpolates (true/false) positives but not the ratio that is precision (see Davis & Goadrich 2006 for details); 'minoring' applies left summation for increasing intervals and right summation for decreasing intervals; 'majoring' does the opposite.
name: (Optional) String name of the metric instance.
dtype: (Optional) Data type of the metric result.
thresholds: (Optional) A list of floating point values to use as thresholds for discretizing the curve. If set, the num_thresholds parameter is ignored. Values should be in [0, 1]. Endpoint thresholds equal to {-epsilon, 1+epsilon} for a small positive epsilon value will be automatically included with these to correctly handle predictions equal to exactly 0 or 1.
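For instance (illustrative values only), thresholds can be concentrated where the model's predictions cluster:

m <- metric_auc(thresholds = c(0.1, 0.25, 0.5, 0.75, 0.9))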
multi_label: Boolean indicating whether multilabel data should be treated as such, wherein AUC is computed separately for each label and then averaged across labels, or (when FALSE) whether the data should be flattened into a single label before AUC computation. In the latter case, when multilabel data is passed to AUC, each label-prediction pair is treated as an individual data point. Should be set to FALSE for multi-class data.
num_labels: (Optional) The number of labels, used when multi_label is TRUE. If num_labels is not specified, then state variables get created on the first call to update_state.
label_weights: (Optional) List, array, or tensor of non-negative weights used to compute AUCs for multilabel data. When multi_label is TRUE, the weights are applied to the individual label AUCs when they are averaged to produce the multi-label AUC. When it's FALSE, they are used to weight the individual label predictions in computing the confusion matrix on the flattened data. Note that this is unlike class_weights in that class_weights weights the example depending on the value of its label, whereas label_weights depends only on the index of that label before flattening; therefore label_weights should not be used for multi-class data.
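For example (an illustrative sketch of the multilabel arguments described above; the weight values are arbitrary):

# Per-label AUCs averaged with the third label weighted twice as heavily
m <- metric_auc(multi_label = TRUE, num_labels = 3L, label_weights = c(1, 1, 2))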
from_logits: Boolean indicating whether the predictions (y_pred in update_state) are probabilities or sigmoid logits. As a rule of thumb, when using a keras loss, the from_logits constructor argument of the loss should match the AUC from_logits constructor argument.
Standalone usage:
m <- metric_auc(num_thresholds = 3)
m$update_state(c(0, 0, 1, 1),
               c(0, 0.5, 0.3, 0.9))
# threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
# tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
# tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0]
# auc = ((((1 + 0.5) / 2) * (1 - 0)) + (((0.5 + 0) / 2) * (0 - 0)))
#     = 0.75
m$result()
## tf.Tensor(0.75, shape=(), dtype=float32)

m$reset_state()
m$update_state(c(0, 0, 1, 1),
               c(0, 0.5, 0.3, 0.9),
               sample_weight = c(1, 0, 0, 1))
m$result()
## tf.Tensor(1.0, shape=(), dtype=float32)
Usage with compile() API:
# Reports the AUC of a model outputting a probability.
model |> compile(
  optimizer = 'sgd',
  loss = loss_binary_crossentropy(),
  metrics = list(metric_auc())
)

# Reports the AUC of a model outputting a logit.
model |> compile(
  optimizer = 'sgd',
  loss = loss_binary_crossentropy(from_logits = TRUE),
  metrics = list(metric_auc(from_logits = TRUE))
)
Other confusion metrics:
metric_false_negatives()
metric_false_positives()
metric_precision()
metric_precision_at_recall()
metric_recall()
metric_recall_at_precision()
metric_sensitivity_at_specificity()
metric_specificity_at_sensitivity()
metric_true_negatives()
metric_true_positives()
Other metrics:
Metric()
custom_metric()
metric_binary_accuracy()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_binary_iou()
metric_categorical_accuracy()
metric_categorical_crossentropy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_concordance_correlation()
metric_cosine_similarity()
metric_f1_score()
metric_false_negatives()
metric_false_positives()
metric_fbeta_score()
metric_hinge()
metric_huber()
metric_iou()
metric_kl_divergence()
metric_log_cosh()
metric_log_cosh_error()
metric_mean()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_iou()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_mean_wrapper()
metric_one_hot_iou()
metric_one_hot_mean_iou()
metric_pearson_correlation()
metric_poisson()
metric_precision()
metric_precision_at_recall()
metric_r2_score()
metric_recall()
metric_recall_at_precision()
metric_root_mean_squared_error()
metric_sensitivity_at_specificity()
metric_sparse_categorical_accuracy()
metric_sparse_categorical_crossentropy()
metric_sparse_top_k_categorical_accuracy()
metric_specificity_at_sensitivity()
metric_squared_hinge()
metric_sum()
metric_top_k_categorical_accuracy()
metric_true_negatives()
metric_true_positives()