policytree (version 1.1.1)

predict.policy_tree: Predict method for policy_tree

Description

Predict values based on a fitted policy_tree object.

Usage

# S3 method for policy_tree
predict(object, newdata, type = c("action.id", "node.id"), ...)

Arguments

object

A fitted policy_tree object.

newdata

Points at which predictions should be made. Note that this matrix should have the same number of columns as the training matrix, and that the columns must appear in the same order.

type

The type of prediction required: "action.id" returns the action id, and "node.id" returns the integer id of the leaf node the sample falls into. Default is "action.id".

...

Additional arguments (currently ignored).
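
Because newdata is matched to the training matrix by column position, it can help to align columns by name before predicting. The following is a minimal base R sketch of that idea, not part of the package; the matrices X.train and X.new and their column names are hypothetical.

```r
# Hypothetical training matrix and new data whose columns arrive in a different order.
X.train <- matrix(1:6, nrow = 2,
                  dimnames = list(NULL, c("age", "income", "tenure")))
X.new <- matrix(c(10, 20, 30), nrow = 1,
                dimnames = list(NULL, c("tenure", "age", "income")))

# Fail early if a training column is missing from the new data.
stopifnot(all(colnames(X.train) %in% colnames(X.new)))

# Reorder by name so column j of the new data matches column j of the training data.
X.new.aligned <- X.new[, colnames(X.train), drop = FALSE]
```

The aligned matrix could then be passed as newdata. Using drop = FALSE keeps the result a matrix even when there is a single row.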

Value

A vector of predictions. For type = "action.id" each element is an integer from 1 to d where d is the number of columns in the reward matrix. For type = "node.id" each element is an integer corresponding to the node the sample falls into (level-ordered).
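
Since each action id indexes a column of the reward matrix, the predictions can be mapped back to arm names or used to look up per-sample rewards. The sketch below uses base R only; the reward matrix, its arm names, and the action.id vector are made-up stand-ins for a real reward matrix and the output of predict.

```r
# Hypothetical reward matrix with d = 3 named arms (made-up values).
rewards <- matrix(c(0.1, 0.4, 0.2,
                    0.3, 0.1, 0.5),
                  nrow = 2, byrow = TRUE,
                  dimnames = list(NULL, c("control", "low.dose", "high.dose")))

# Stand-in for the integer vector returned by predict(tree, newdata).
action.id <- c(2L, 3L)

# Recover the arm name for each prediction.
action.name <- colnames(rewards)[action.id]

# Look up each sample's reward under its assigned action via matrix indexing.
assigned.reward <- rewards[cbind(seq_along(action.id), action.id)]
```

Averaging assigned.reward over samples gives a simple estimate of the value of the assigned policy, assuming the reward matrix holds unbiased reward estimates.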

Examples

# NOT RUN {
library(policytree)

# Fit a depth two tree on doubly robust treatment effect estimates from a causal forest.
n <- 10000
p <- 10
# Rounding continuous covariates (here to 2 decimals) leaves fewer distinct
# split points to search over, which decreases runtime.
X <- round(matrix(rnorm(n * p), n, p), 2)
colnames(X) <- make.names(1:p)
W <- rbinom(n, 1, 1 / (1 + exp(X[, 3])))
tau <- 1 / (1 + exp((X[, 1] + X[, 2]) / 2)) - 0.5
Y <- X[, 3] + W * tau + rnorm(n)
c.forest <- grf::causal_forest(X, Y, W)
dr.scores <- double_robust_scores(c.forest)

tree <- policy_tree(X, dr.scores, 2)
tree

# Predict treatment assignment.
predicted <- predict(tree, X)

plot(X[, 1], X[, 2], col = predicted)
legend("topright", c("control", "treat"), col = c(1, 2), pch = 19)
abline(0, -1, lty = 2)

# Predict the leaf assigned to each sample.
node.id <- predict(tree, X, type = "node.id")
# Can be reshaped to a list of samples per leaf node with `split`.
samples.per.leaf <- split(1:n, node.id)

# The value of all arms (along with SEs) by each leaf node.
values <- aggregate(dr.scores, by = list(leaf.node = node.id),
                    FUN = function(x) c(mean = mean(x), se = sd(x) / sqrt(length(x))))
print(values, digits = 2)

# Take cost of treatment into account by offsetting the objective
# with an estimate of the average treatment effect.
# See section 5.1 in Athey and Wager (2021) for more details, including
# suggestions on using cross-validation to assess the accuracy of the learned policy.
ate <- grf::average_treatment_effect(c.forest)
cost.offset <- ate[["estimate"]]
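# Subtract the offset from the treated arm only (column 2 of dr.scores);
# subtracting it from every column would leave the optimal policy unchanged.
dr.scores.cost <- dr.scores
dr.scores.cost[, 2] <- dr.scores.cost[, 2] - cost.offset
tree.cost <- policy_tree(X, dr.scores.cost, 2)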

# If there are too many covariates to make tree search computationally feasible,
# one can consider for example only the top 5 features according to GRF's variable importance.
var.imp <- grf::variable_importance(c.forest)
top.5 <- order(var.imp, decreasing = TRUE)[1:5]
tree.top5 <- policy_tree(X[, top.5], dr.scores, 2, split.step = 50)
# }