Loss class
Use this to define a custom loss class. Note that in most cases you do not need to subclass Loss to define a custom loss: you can also pass a bare R function, or a named R function defined with custom_metric(), as a loss function to compile().
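The alternatives mentioned above can be sketched as follows, assuming the keras3 package is installed. The names my_mse, my_named_mse, and "my_mse" are illustrative only; the function returns one loss value per sample, just like the call() method of a Loss subclass.

```r
library(keras3)

# A bare R function of (y_true, y_pred) returning per-sample losses
my_mse <- function(y_true, y_pred) {
  op_mean(op_square(y_pred - y_true), axis = -1)
}

# The same function, given an explicit name via custom_metric()
my_named_mse <- custom_metric("my_mse", my_mse)

# Either form can be passed straight to compile(); no Loss subclass is needed
model <- keras_model_sequential(input_shape = 10) |> layer_dense(1)
model |> compile(optimizer = "adam", loss = my_named_mse)
```

A Loss subclass is mainly worth the extra code when the loss needs constructor arguments, per-instance state, or a custom reduction.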
Loss(
classname,
call = NULL,
...,
public = list(),
private = list(),
inherit = NULL,
parent_env = parent.frame()
)
Value: a function that returns Loss instances, similar to the built-in loss functions.
Arguments:
classname: String, the name of the custom class (conventionally CamelCase).
call: function(y_true, y_pred). Method to be implemented by subclasses: a function that contains the logic for loss calculation using y_true and y_pred.
..., public: Additional methods or public members of the custom class.
private: Named list of R objects (typically, functions) to include in instance private environments. private methods will have all the same symbols in scope as public methods (see section "Symbols in Scope"). Each instance will have its own private environment. Any objects in private will be invisible from the Keras framework and the Python runtime.
inherit: What the custom class will subclass. By default, the base Keras class.
parent_env: The R environment that all class methods will have as a grandparent.
Methods defined by base class:
initialize(name=NULL, reduction="sum_over_batch_size", dtype=NULL)
Args:
name: Optional name for the loss instance.
reduction: Type of reduction to apply to the loss. In almost all cases this should be "sum_over_batch_size". Supported options are "sum", "sum_over_batch_size", "mean", "mean_with_sample_weight", "none", or NULL. "sum" sums the loss; "sum_over_batch_size" and "mean" sum the loss and divide by the sample size; "mean_with_sample_weight" sums the loss and divides by the sum of the sample weights; "none" and NULL perform no aggregation. Defaults to "sum_over_batch_size".
dtype: The dtype of the loss's computations. Defaults to NULL, which means using config_floatx(). config_floatx() returns "float32" unless set to a different value (via config_set_floatx()). If a keras$DTypePolicy is provided, its compute_dtype will be used.
__call__(y_true, y_pred, sample_weight=NULL)
Call the loss instance as a function, optionally with sample_weight.
get_config()
Read-only properties:
dtype
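The reduction options documented for initialize() boil down to simple arithmetic on a vector of per-sample losses. The base-R sketch below mirrors them with a hypothetical helper, reduce_loss(), which is not part of keras3. It assumes that sample weights scale each per-sample loss before aggregation, and the per-sample values are illustrative.

```r
# Hypothetical helper mirroring the documented reduction rules.
# `values` are per-sample losses, `w` optional per-sample weights.
reduce_loss <- function(values, w = NULL, reduction = "sum_over_batch_size") {
  # Sample weights scale each per-sample loss before any aggregation
  weighted <- if (is.null(w)) values else values * w
  # NULL and "none": no aggregation, return the weighted per-sample losses
  if (is.null(reduction) || identical(reduction, "none")) return(weighted)
  total <- sum(weighted)
  switch(reduction,
    sum = total,
    sum_over_batch_size = ,               # falls through to "mean"
    mean = total / length(values),        # divide by the sample size
    mean_with_sample_weight =
      if (is.null(w)) total / length(values) else total / sum(w),
    stop("unknown reduction: ", reduction)
  )
}

per_sample <- c(6, 51, 146, 291)
w <- c(0.25, 0.25, 1, 1)

reduce_loss(per_sample, w, "sum")                      # 451.25
reduce_loss(per_sample, w, "sum_over_batch_size")      # 112.8125
reduce_loss(per_sample, w, "mean_with_sample_weight")  # 180.5
```

The only difference between "sum_over_batch_size" and "mean_with_sample_weight" is the divisor: the number of samples versus the total weight, which matters whenever the weights do not sum to the sample count.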
Symbols in Scope:
All R function custom methods (public and private) will have the following symbols in scope:
self: The custom class instance.
super: The custom class superclass.
private: An R environment specific to the class instance. Any objects assigned here are invisible to the Keras framework.
__class__ and as.symbol(classname): The custom class type object.
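A sketch of how these symbols are typically combined, assuming keras3 is installed. The hypothetical loss_scaled_mse() stores a scale factor in private during initialize(), forwards the remaining arguments (name, reduction, dtype) to the base class through super, and reads the factor back in call(). It does not override get_config(), so the factor would not survive serialization; a complete implementation would likely need to handle that too.

```r
library(keras3)

# Hypothetical loss (not part of keras3): MSE multiplied by a fixed factor
loss_scaled_mse <- Loss(
  classname = "ScaledMeanSquaredError",

  # Extra public method supplied through `...`
  initialize = function(scale = 1, ...) {
    # `super`: forward name, reduction and dtype to the base Loss initializer
    super$initialize(...)
    # `private`: per-instance storage, invisible to Keras and Python
    private$scale <- scale
  },

  call = function(y_true, y_pred) {
    per_sample <- op_mean(op_square(y_pred - y_true), axis = -1)
    op_multiply(per_sample, private$scale)
  }
)

scaled <- loss_scaled_mse(scale = 2, name = "scaled_mse")
y_true <- op_arange(20) |> op_reshape(c(4, 5))
y_pred <- y_true * 2
loss <- scaled(y_true, y_pred)
```

Keeping scale in private rather than on self keeps it out of the object that Keras tracks, at the cost of having to save it explicitly if the loss must be serialized.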
Example subclass implementation:
library(keras3)

loss_custom_mse <- Loss(
  classname = "CustomMeanSquaredError",
  call = function(y_true, y_pred) {
    op_mean(op_square(y_pred - y_true), axis = -1)
  }
)

# Usage in compile()
model <- keras_model_sequential(input_shape = 10) |> layer_dense(10)
model |> compile(loss = loss_custom_mse())

# Standalone usage
mse <- loss_custom_mse(name = "my_custom_mse_instance")

y_true <- op_arange(20) |> op_reshape(c(4, 5))
y_pred <- op_arange(20) |> op_reshape(c(4, 5)) * 2
(loss <- mse(y_true, y_pred))
## tf.Tensor(123.5, shape=(), dtype=float32)

loss2 <- (y_pred - y_true)^2 |>
  op_mean(axis = -1) |>
  op_mean()
stopifnot(all.equal(as.array(loss), as.array(loss2)))

sample_weight <- array(c(.25, .25, 1, 1))
(weighted_loss <- mse(y_true, y_pred, sample_weight = sample_weight))
## tf.Tensor(112.8125, shape=(), dtype=float32)

weighted_loss2 <- (y_true - y_pred)^2 |>
  op_mean(axis = -1) |>
  op_multiply(sample_weight) |>
  op_mean()
stopifnot(all.equal(as.array(weighted_loss),
                    as.array(weighted_loss2)))
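The two printed tensor values can be cross-checked with base R alone. This sketch assumes op_reshape() fills the 4 x 5 tensor row by row (C order), hence byrow = TRUE; the *_r variable names are illustrative.

```r
# Rebuild the example tensors as plain R matrices, row-major like op_reshape()
y_true_r <- matrix(0:19, nrow = 4, byrow = TRUE)
y_pred_r <- y_true_r * 2

# Per-sample MSE: mean over the last axis (the 5 columns of each row)
per_sample_r <- rowMeans((y_pred_r - y_true_r)^2)  # 6, 51, 146, 291

# Default "sum_over_batch_size" reduction: sum divided by the 4 samples
mean(per_sample_r)                                 # 123.5

# Weighted: scale each sample's loss, then divide by the 4 samples
w_r <- c(.25, .25, 1, 1)
mean(per_sample_r * w_r)                           # 112.8125
```

Because the weights sum to 2.5 rather than 4, the weighted result is an average over the batch size, not a weighted mean of the per-sample losses.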
Other losses:
loss_binary_crossentropy()
loss_binary_focal_crossentropy()
loss_categorical_crossentropy()
loss_categorical_focal_crossentropy()
loss_categorical_hinge()
loss_circle()
loss_cosine_similarity()
loss_ctc()
loss_dice()
loss_hinge()
loss_huber()
loss_kl_divergence()
loss_log_cosh()
loss_mean_absolute_error()
loss_mean_absolute_percentage_error()
loss_mean_squared_error()
loss_mean_squared_logarithmic_error()
loss_poisson()
loss_sparse_categorical_crossentropy()
loss_squared_hinge()
loss_tversky()
metric_binary_crossentropy()
metric_binary_focal_crossentropy()
metric_categorical_crossentropy()
metric_categorical_focal_crossentropy()
metric_categorical_hinge()
metric_hinge()
metric_huber()
metric_kl_divergence()
metric_log_cosh()
metric_mean_absolute_error()
metric_mean_absolute_percentage_error()
metric_mean_squared_error()
metric_mean_squared_logarithmic_error()
metric_poisson()
metric_sparse_categorical_crossentropy()
metric_squared_hinge()