Abstract base class for counterfactual explanation methods for classification tasks.
CounterfactualMethodClassif can only be initialized for classification tasks. Child classes inherit the (public) $find_counterfactuals() method, which calls a (private) $run() method. This $run() method should be implemented by the child classes and return the counterfactuals as a data.table (preferably) or a data.frame.
Child classes: MOCClassif, WhatIfClassif, NICEClassif
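The split between a public $find_counterfactuals() and a private $run() is a template-method pattern. The base class owns the shared steps, and each child class supplies only the search itself. The sketch below illustrates that idea with plain closures rather than R6. new_counterfactual_method() and the toy run() are hypothetical names for illustration only. They are not part of the counterfactuals API, and the input check is an assumed stand-in for whatever validation the real class performs.

```r
# Hypothetical sketch of the pattern described above, using plain closures
# rather than R6. new_counterfactual_method() and the toy run() are
# illustrative names, not part of the counterfactuals API.
new_counterfactual_method <- function(run) {
  stopifnot(is.function(run))
  # `run` is captured in the closure and never returned, so callers cannot
  # reach it directly; this mirrors a private method.
  find_counterfactuals <- function(x_interest) {
    # Shared logic owned by the base class: validate the input (assumed check) ...
    if (!is.data.frame(x_interest) || nrow(x_interest) != 1L) {
      stop("x_interest must be a data.frame with exactly one row")
    }
    # ... then delegate the actual search to the child-specific run().
    res <- run(x_interest)
    # data.table objects also inherit from data.frame, so they pass this check.
    if (!is.data.frame(res)) stop("run() must return a data.frame")
    res
  }
  list(find_counterfactuals = find_counterfactuals)
}

# A toy "child" whose run() shifts the first column of x_interest by -1 and +1.
toy_method <- new_counterfactual_method(function(x_interest) {
  cf <- x_interest[c(1L, 1L), , drop = FALSE]
  cf[[1L]] <- cf[[1L]] + c(-1, 1)
  rownames(cf) <- NULL
  cf
})

toy_method$find_counterfactuals(data.frame(a = 5, b = "x"))
```

Only find_counterfactuals is exposed on the returned object. Replacing the function passed as `run` changes the search without touching the shared validation.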
Super class: counterfactuals::CounterfactualMethod -> CounterfactualMethodClassif
new()
Creates a new CounterfactualMethodClassif object.
CounterfactualMethodClassif$new(
  predictor,
  lower = NULL,
  upper = NULL,
  distance_function = NULL
)
predictor
(Predictor)
The object (created with iml::Predictor$new()) holding the machine learning model and the data.
lower
(numeric() | NULL)
Vector of minimum values for numeric features.
If NULL (default), the element for each numeric feature in lower is taken as its minimum value in predictor$data$X.
If not NULL, it should be named with the corresponding feature names.
upper
(numeric() | NULL)
Vector of maximum values for numeric features.
If NULL (default), the element for each numeric feature in upper is taken as its maximum value in predictor$data$X.
If not NULL, it should be named with the corresponding feature names.
distance_function
(function() | NULL)
A distance function that the child classes may use.
If specified, the function must have three arguments, x, y, and data, and must return a double matrix with nrow(x) rows and nrow(y) columns.
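The constructor arguments above have concrete contracts. Default bounds are per-feature minima and maxima of the data, user-supplied bounds are named by feature, and a distance function takes (x, y, data) and returns an nrow(x) by nrow(y) double matrix. The sketch below illustrates both contracts in base R. resolve_bounds() and gower_dist() are hypothetical helpers, not functions of the counterfactuals package. The rule that features missing from a user-supplied bound fall back to the data minimum or maximum is an assumption. The Gower-style distance is just one example of a function that satisfies the required signature.

```r
# Hypothetical helpers illustrating the documented constructor arguments.
# Neither resolve_bounds() nor gower_dist() is part of the counterfactuals API.

# Default bounds: per numeric feature, min/max over the data. User-supplied
# bounds must be named; ASSUMPTION: features missing from a user-supplied
# vector fall back to the data min/max.
resolve_bounds <- function(X, lower = NULL, upper = NULL) {
  num_cols <- names(X)[vapply(X, is.numeric, logical(1))]
  fill <- function(user, default) {
    if (is.null(user)) return(default)
    if (is.null(names(user)) || !all(names(user) %in% num_cols)) {
      stop("bounds must be named with numeric feature names")
    }
    default[names(user)] <- user
    default
  }
  list(
    lower = fill(lower, vapply(X[num_cols], min, numeric(1))),
    upper = fill(upper, vapply(X[num_cols], max, numeric(1)))
  )
}

# A distance_function with the required (x, y, data) signature, returning a
# double matrix with nrow(x) rows and nrow(y) columns. It is Gower-style:
# numeric features contribute |x - y| scaled by their range in `data`, other
# features contribute 1 if the values differ and 0 otherwise; the total is
# averaged over the features. Assumes x and y share the same columns.
gower_dist <- function(x, y, data) {
  d <- matrix(0, nrow = nrow(x), ncol = nrow(y))
  for (col in names(x)) {
    if (is.numeric(x[[col]])) {
      rng <- diff(range(data[[col]]))
      if (rng == 0) rng <- 1  # constant feature: avoid dividing by zero
      d <- d + abs(outer(x[[col]], y[[col]], "-")) / rng
    } else {
      d <- d + outer(as.character(x[[col]]), as.character(y[[col]]), "!=")
    }
  }
  d / ncol(x)
}

bounds <- resolve_bounds(iris, lower = c(Sepal.Length = 0))
bounds$lower  # Sepal.Length overridden to 0, the others are data minima
dists <- gower_dist(iris[1:2, ], iris[c(1, 51, 101), ], data = iris)
dim(dists)    # 2 rows (x) by 3 columns (y)
```

Scaling by the range in data keeps each numeric term between 0 and 1 for points inside the data range. This makes numeric and categorical features comparable before averaging.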
find_counterfactuals()
Runs the counterfactual method and returns the counterfactuals.
It searches for counterfactuals whose predicted probability for the desired_class lies in the interval desired_prob.
CounterfactualMethodClassif$find_counterfactuals(
  x_interest,
  desired_class = NULL,
  desired_prob = c(0.5, 1)
)
x_interest
(data.table(1) | data.frame(1))
A single row with the observation of interest.
desired_class
(character(1) | NULL)
The desired class. If NULL (default), predictor$class is used.
desired_prob
(numeric(1) | numeric(2))
The desired predicted probability of the desired_class. It can be a numeric scalar or a vector of two numeric values that specify a probability interval.
For hard classification tasks, this can be set to 0 or 1.
A scalar is internally converted to an interval.
Returns: A Counterfactuals object containing the results.
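Two documented rules govern what find_counterfactuals() searches for. A NULL desired_class falls back to predictor$class, and desired_prob is treated as an interval. The sketch below makes both rules concrete in base R. normalize_desired_prob(), resolve_desired_class(), and in_desired_range() are hypothetical helpers, not the package's internal code. Two details are assumptions: a scalar p becomes the degenerate interval c(p, p), and the interval is closed at both ends.

```r
# Hypothetical helpers, not the package's internal code. ASSUMPTIONS: a
# scalar desired_prob p becomes the interval c(p, p), and the interval is
# closed on both ends.
normalize_desired_prob <- function(desired_prob) {
  stopifnot(is.numeric(desired_prob), length(desired_prob) %in% 1:2,
            all(desired_prob >= 0 & desired_prob <= 1))
  if (length(desired_prob) == 1L) desired_prob <- rep(desired_prob, 2L)
  if (desired_prob[1L] > desired_prob[2L]) {
    stop("the lower bound of desired_prob must not exceed the upper bound")
  }
  desired_prob
}

# Documented default: fall back to predictor$class when desired_class is NULL.
resolve_desired_class <- function(desired_class, predictor_class) {
  if (is.null(desired_class)) predictor_class else desired_class
}

# TRUE for each candidate whose predicted probability of the desired class
# lies inside the desired interval.
in_desired_range <- function(prob, desired_prob) {
  rng <- normalize_desired_prob(desired_prob)
  prob >= rng[1L] & prob <= rng[2L]
}

probs <- c(0.2, 0.5, 0.8, 1)
in_desired_range(probs, c(0.5, 1))  # FALSE  TRUE  TRUE  TRUE
in_desired_range(probs, 1)          # FALSE FALSE FALSE  TRUE
resolve_desired_class(NULL, "versicolor")  # "versicolor"
```

With desired_prob = 1, which suits hard classification, only candidates predicted with probability exactly 1 qualify. This is why a scalar is more restrictive than the default interval c(0.5, 1).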
clone()
The objects of this class are cloneable with this method.
CounterfactualMethodClassif$clone(deep = FALSE)
deep
Whether to make a deep clone.