Abstract base class for counterfactual explanation methods for regression tasks.
CounterfactualMethodRegr can only be initialized for regression tasks. Child classes inherit the (public) $find_counterfactuals() method, which calls a (private) $run() method. This $run() method should be implemented by the child classes and return the counterfactuals as a data.table (preferably) or a data.frame.
Child classes: MOCRegr, WhatIfRegr, NICERegr
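The division of labour between the public $find_counterfactuals() and the private $run() is a template-method pattern. Below is a minimal R6 sketch of that pattern, assuming only that the R6 package is installed. The class names ToyMethodRegr and ToyShiftRegr are hypothetical, the toy child ignores any model, and this is not the package's actual source.

```r
library(R6)

# Hypothetical, simplified stand-in for the abstract base class: the public
# method normalizes its inputs and delegates the actual search to private$run().
ToyMethodRegr = R6Class("ToyMethodRegr",
  public = list(
    find_counterfactuals = function(x_interest, desired_outcome) {
      # Assumption: a scalar is widened to the degenerate interval [d, d]
      if (length(desired_outcome) == 1L) {
        desired_outcome = c(desired_outcome, desired_outcome)
      }
      private$x_interest = x_interest
      private$desired_outcome = desired_outcome
      private$run()
    }
  ),
  private = list(
    x_interest = NULL,
    desired_outcome = NULL,
    # Abstract: child classes must override this method
    run = function() stop("abstract method: implement run() in a child class")
  )
)

# Hypothetical child class: sets feature x1 so that the toy model
# f(x) = 2 * x1 hits the midpoint of the desired interval.
ToyShiftRegr = R6Class("ToyShiftRegr",
  inherit = ToyMethodRegr,
  private = list(
    run = function() {
      target = mean(private$desired_outcome)
      cf = private$x_interest
      cf$x1 = target / 2
      cf  # a data.frame here; the documentation prefers a data.table
    }
  )
)

cf = ToyShiftRegr$new()$find_counterfactuals(data.frame(x1 = 1, x2 = 3), desired_outcome = 10)
cf$x1  # 5
```

Calling find_counterfactuals() on ToyMethodRegr itself fails because its $run() is abstract, which matches the role of CounterfactualMethodRegr as a base class that is only used through its children.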
Super class: counterfactuals::CounterfactualMethod -> CounterfactualMethodRegr
new()
Creates a new CounterfactualMethodRegr object.
CounterfactualMethodRegr$new(
predictor,
lower = NULL,
upper = NULL,
distance_function = NULL
)
predictor (Predictor)
The object (created with iml::Predictor$new()) holding the machine learning model and the data.
lower (numeric() | NULL)
Vector of minimum values for numeric features. If NULL (default), the element for each numeric feature in lower is taken as its minimum value in predictor$data$X. If not NULL, it should be named with the corresponding feature names.
upper (numeric() | NULL)
Vector of maximum values for numeric features. If NULL (default), the element for each numeric feature in upper is taken as its maximum value in predictor$data$X. If not NULL, it should be named with the corresponding feature names.
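The documented defaults for lower and upper amount to a per-feature minimum and maximum over the numeric features, optionally overridden by a named vector. The base-R sketch below shows that logic. resolve_bounds() is a hypothetical helper, not part of the package, and it assumes that features not named in a user-supplied vector keep their data-derived bound.

```r
# Hypothetical helper mirroring the documented defaults for lower/upper:
# missing bounds fall back to the observed min/max of each numeric feature in X.
resolve_bounds = function(X, lower = NULL, upper = NULL) {
  X = as.data.frame(X)  # e.g. predictor$data$X may be a data.table
  num_cols = names(X)[vapply(X, is.numeric, logical(1L))]
  lower_default = vapply(X[num_cols], min, numeric(1L), na.rm = TRUE)
  upper_default = vapply(X[num_cols], max, numeric(1L), na.rm = TRUE)
  # User-supplied bounds must be named; they override the defaults by name
  if (!is.null(lower)) {
    stopifnot(!is.null(names(lower)), all(names(lower) %in% num_cols))
    lower_default[names(lower)] = lower
  }
  if (!is.null(upper)) {
    stopifnot(!is.null(names(upper)), all(names(upper) %in% num_cols))
    upper_default[names(upper)] = upper
  }
  list(lower = lower_default, upper = upper_default)
}

X = data.frame(x1 = c(1, 4, 2), x2 = c(10, 30, 20), g = c("a", "b", "a"))
b = resolve_bounds(X, lower = c(x1 = 0))
b$lower  # x1 = 0, x2 = 10
b$upper  # x1 = 4, x2 = 30
```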
distance_function (function() | NULL)
A distance function that may be used by the leaf classes. If specified, the function must have three arguments, x, y, and data, and must return a double matrix with nrow(x) rows and nrow(y) columns.
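A user-supplied distance_function only has to honour the contract above. As an illustration, the sketch below averages range-scaled absolute differences over the numeric features, which is the numeric part of a Gower-type distance. range_l1_dist() is a hypothetical name, ignoring non-numeric columns is a simplifying assumption, and the sketch assumes data holds the feature columns (for example predictor$data$X).

```r
# Hypothetical distance function with the documented signature (x, y, data);
# returns a double matrix with nrow(x) rows and nrow(y) columns.
# Assumption: only numeric features are used, each scaled by its range in `data`.
range_l1_dist = function(x, y, data) {
  x = as.data.frame(x)
  y = as.data.frame(y)
  data = as.data.frame(data)
  num_cols = names(data)[vapply(data, is.numeric, logical(1L))]
  # Per-feature range in `data`; constant features get range 1 to avoid 0/0
  ranges = vapply(data[num_cols], function(col) diff(range(col)), numeric(1L))
  ranges[ranges == 0] = 1
  xm = as.matrix(x[num_cols])
  ym = as.matrix(y[num_cols])
  out = matrix(0, nrow = nrow(xm), ncol = nrow(ym))
  for (i in seq_len(nrow(xm))) {
    for (j in seq_len(nrow(ym))) {
      # Mean of range-scaled absolute differences between row i of x and row j of y
      out[i, j] = mean(abs(xm[i, ] - ym[j, ]) / ranges)
    }
  }
  out
}

d = data.frame(a = c(0, 10), b = c(0, 4), g = c("u", "v"))
range_l1_dist(d[1, ], d, d)  # 1 x 2 matrix: 0 and 1
```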
find_counterfactuals()
Runs the counterfactual method and returns the counterfactuals. It searches for counterfactuals that have a predicted outcome in the interval desired_outcome.
CounterfactualMethodRegr$find_counterfactuals(x_interest, desired_outcome)
x_interest (data.table(1) | data.frame(1))
A single row with the observation of interest.
desired_outcome (numeric(1) | numeric(2))
The desired predicted outcome. It can be a numeric scalar or a vector with two numeric values that specify an outcome interval. A scalar is internally converted to an interval.
Returns: a Counterfactuals object containing the results.
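The handling of desired_outcome can be sketched in a few lines of base R. The helpers as_outcome_interval() and in_outcome_interval() are hypothetical; the sketch assumes that a scalar d becomes the degenerate interval [d, d] and that both interval ends are inclusive, which may differ from the package internals.

```r
# Hypothetical helper: normalizes desired_outcome to a two-element interval.
# Assumption: a scalar d is treated as the degenerate interval [d, d].
as_outcome_interval = function(desired_outcome) {
  stopifnot(is.numeric(desired_outcome), length(desired_outcome) %in% c(1L, 2L))
  if (length(desired_outcome) == 1L) {
    desired_outcome = rep(desired_outcome, 2L)
  }
  stopifnot(desired_outcome[1L] <= desired_outcome[2L])
  desired_outcome
}

# Hypothetical helper: flags which predictions lie inside the (closed) interval
in_outcome_interval = function(pred, desired_outcome) {
  iv = as_outcome_interval(desired_outcome)
  pred >= iv[1L] & pred <= iv[2L]
}

as_outcome_interval(25)                              # 25 25
in_outcome_interval(c(21, 23.5, 27), c(22, 26))      # FALSE TRUE FALSE
```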
clone()
The objects of this class are cloneable with this method.
CounterfactualMethodRegr$clone(deep = FALSE)
deep
Whether to make a deep clone.
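With deep = FALSE, R6 copies the object's fields but nested R6 objects remain shared between original and clone; with deep = TRUE they are cloned as well. The hypothetical Inner and Holder classes below demonstrate this with the R6 package. The nesting is only analogous to a method object that holds an iml Predictor, which is itself an R6 object.

```r
library(R6)

# Hypothetical classes: Holder keeps a nested R6 object in a field.
Inner = R6Class("Inner", public = list(value = 1))
Holder = R6Class("Holder", public = list(
  inner = NULL,
  # Create the nested object in initialize() so each Holder gets its own
  initialize = function() self$inner = Inner$new()
))

h = Holder$new()
shallow_copy = h$clone()            # deep = FALSE: nested object is shared
deep_copy = h$clone(deep = TRUE)    # nested object is cloned too
h$inner$value = 2
shallow_copy$inner$value  # 2
deep_copy$inner$value     # 1
```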