iml (version 0.6.0)

TreeSurrogate: Decision tree surrogate model

Description

TreeSurrogate fits a decision tree on the predictions of a prediction model.

Format

R6Class object.

Usage

tree = TreeSurrogate$new(predictor, maxdepth = 2, tree.args = NULL, run = TRUE)

plot(tree)
predict(tree, newdata)
tree$results
print(tree)

Arguments

For TreeSurrogate$new():

predictor:

(Predictor) The object (created with Predictor$new()) holding the machine learning model and the data.

maxdepth:

(`numeric(1)`) The maximum depth of the tree. Default is 2.

run:

(`logical(1)`) Should the interpretation method be run? Default is TRUE.

tree.args:

(named list) Further arguments for ctree.
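As a small illustration of how these arguments fit together (a sketch only; `mod` stands for a Predictor object such as the one created in the Examples below):

# Sketch: build the surrogate without fitting it immediately, then run it later.
tree = TreeSurrogate$new(mod, maxdepth = 3, run = FALSE)
tree$run()         # fit the surrogate tree
tree$r.squared     # how closely the tree approximates the black box model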

Fields

maxdepth:

(`numeric(1)`) The maximum tree depth.

predictor:

(Predictor) The prediction model that was analysed.

r.squared:

(`numeric(1|n.classes)`) R-squared measures how well the surrogate decision tree approximates the underlying model. It is calculated as 1 - (variance of the differences between the tree's and the black box model's predictions) / (variance of the black box model's predictions). For the multi-class case, r.squared contains one measure per class.
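Writing \(\hat{y}_{tree}\) for the surrogate tree's predictions and \(\hat{y}_{bb}\) for the black box model's predictions, this is:

\(R^2 = 1 - Var(\hat{y}_{tree} - \hat{y}_{bb}) / Var(\hat{y}_{bb})\)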

results:

(data.frame) A data.frame with the sampled features X together with the leaf node information (columns .node and .path) and the predicted \(\hat{y}\) for both the tree and the machine learning model (columns starting with .y.hat).

tree:

(party) The fitted tree. See also ctree.

Methods

plot()

method to plot the leaf nodes of the surrogate decision tree. See plot.TreeSurrogate.

predict()

method to predict new data with the tree. See also predict.TreeSurrogate.

run()

[internal] method to run the interpretability method. Use obj$run(force = TRUE) to force a rerun.

clone()

[internal] method to clone the R6 object.

initialize()

[internal] method to initialize the R6 object.

Details

A conditional inference tree is fitted on the predicted \(\hat{y}\) of the machine learning model, using the original data features. The ctree function from the partykit package is used to fit the tree. By default, a tree of maximum depth 2 is fitted to improve interpretability.
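To make the idea concrete, here is a rough sketch of that fitting step for the regression case (this is not iml's internal code; it assumes `mod` is a Predictor object and `X` is the feature data it holds):

# Rough sketch, not the package's actual implementation:
library("partykit")
y.hat = mod$predict(X)[[1]]                 # predictions of the black box model
dat = cbind(X, y.hat = y.hat)               # features plus predictions
surrogate = ctree(y.hat ~ ., data = dat,
  control = ctree_control(maxdepth = 2))    # shallow tree for interpretability
plot(surrogate)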

References

Craven, M., & Shavlik, J. W. (1996). Extracting tree-structured representations of trained networks. In Advances in neural information processing systems (pp. 24-30).

See Also

predict.TreeSurrogate, plot.TreeSurrogate

For the tree implementation, see ctree.

Examples

if (require("randomForest")) {
# Fit a Random Forest on the Boston housing data set
data("Boston", package  = "MASS")
rf = randomForest(medv ~ ., data = Boston, ntree = 50)
# Create a model object
mod = Predictor$new(rf, data = Boston[-which(names(Boston) == "medv")]) 

# Fit a decision tree as a surrogate for the whole random forest
dt = TreeSurrogate$new(mod)

# Plot the resulting leaf nodes
plot(dt) 

# Use the tree to predict new data
predict(dt, Boston[1:10,])

# Extract the results
dat = dt$results
head(dat)
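  # Optional check: the r.squared field (see Fields above) measures how well
  # the surrogate tree approximates the random forest
  dt$r.squared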


  # It also works for classification
  rf = randomForest(Species ~ ., data = iris, ntree = 50)
  X = iris[-which(names(iris) == "Species")]
  mod = Predictor$new(rf, data = X, type = "prob")

  # Fit a decision tree as a surrogate for the whole random forest
  dt = TreeSurrogate$new(mod, maxdepth = 2)

  # Plot the resulting leaf nodes
  plot(dt)

  # If you want to visualise the tree directly:
  plot(dt$tree)

  # Use the tree to predict new data
  set.seed(42)
  iris.sample = X[sample(1:nrow(X), 10), ]
  predict(dt, iris.sample)
  predict(dt, iris.sample, type = "class")

  # Extract the dataset
  dat = dt$results
  head(dat)
}
