Fits a varying intercept/random effect BART model.
rbart_vi(
formula, data, test, subset, weights, offset, offset.test = offset,
group.by, group.by.test, prior = cauchy,
sigest = NA_real_, sigdf = 3.0, sigquant = 0.90,
k = 2.0,
power = 2.0, base = 0.95,
n.trees = 75L,
n.samples = 1500L, n.burn = 1500L,
n.chains = 4L, n.threads = min(dbarts::guessNumCores(), n.chains),
combineChains = FALSE,
n.cuts = 100L, useQuantiles = FALSE,
n.thin = 5L, keepTrainingFits = TRUE,
printEvery = 100L, printCutoffs = 0L,
verbose = TRUE,
keepTrees = TRUE, keepCall = TRUE,
seed = NA_integer_,
keepSampler = keepTrees,
keepTestFits = TRUE,
callback = NULL,
...)
# S3 method for rbart
plot(
x, plquants = c(0.05, 0.95), cols = c('blue', 'black'), ...)
# S3 method for rbart
fitted(
object,
type = c("ev", "ppd", "bart", "ranef"),
sample = c("train", "test"),
...)
# S3 method for rbart
extract(
object,
type = c("ev", "ppd", "bart", "ranef", "trees"),
sample = c("train", "test"),
combineChains = TRUE,
...)
# S3 method for rbart
predict(
object, newdata, group.by, offset,
type = c("ev", "ppd", "bart", "ranef"),
combineChains = TRUE,
...)
# S3 method for rbart
residuals(object, ...)
An object of class rbart. Contains all of the same elements as an object of class bart, as well as the following elements:
ranef
Samples from the posterior of the random effects. An array/matrix of posterior samples. The \((k, l, j)\) value is the \(l\)th draw of the posterior of the random effect for group \(j\) (i.e. \(\alpha_j\)) corresponding to chain \(k\). When n.chains is one or combineChains is TRUE, the result is collapsed down to a matrix.
ranef.mean
Posterior mean of the random effects, computed for each group by averaging that group's posterior samples.
tau
Matrix of posterior samples of tau, the standard deviation of the random effects. Its dimensions are the number of chains by the number of samples, unless n.chains is one or combineChains is TRUE.
first.tau
Burn-in draws of tau.
callback
Optional results of the callback function.
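The following base-R sketch shows how the ranef array relates to ranef.mean. It uses simulated numbers in place of a real fit. The (chain, draw, group) layout follows the description above. The row order used when chains are combined is an assumption of this sketch and may not match what dbarts does.

```r
# Made-up draws shaped like ranef: dimension (chain k, draw l, group j)
set.seed(1)
n.chains <- 2L; n.samples <- 3L; n.groups <- 4L
ranef <- array(rnorm(n.chains * n.samples * n.groups),
               dim = c(n.chains, n.samples, n.groups))

# Collapse chains into a (n.chains * n.samples) x n.groups matrix.
# aperm() puts draws first, so chain 1's draws come first, then chain 2's
# (this ordering is an assumption, not necessarily dbarts' ordering).
ranef.combined <- matrix(aperm(ranef, c(2L, 1L, 3L)),
                         nrow = n.chains * n.samples, ncol = n.groups)

# Per-group posterior mean, analogous to ranef.mean
ranef.mean <- colMeans(ranef.combined)
```

Row 4 of ranef.combined is the first draw of chain 2. The per-group means are unaffected by the row order.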
group.by
Grouping factor. Can be an integer vector/factor, or a reference to such in data.
group.by.test
Grouping factor for test data, of the same type as group.by. Can be missing.
prior
A function or symbolic reference to a built-in prior, determining the prior over the standard deviation of the random effects. Supplied functions take two arguments: x, the standard deviation, and rel.scale, the standard deviation of the response variable before random effects are fit. The built-in priors are cauchy, with a scale of 2.5 times the relative scale, and gamma, with a shape of 2.5 and a scale of 2.5 times the relative scale.
n.thin
The number of tree jumps taken for every stored sample. It is also the number of draws from the posterior of the random-effect standard deviation taken before one is kept.
keepTestFits
Logical; if FALSE, test fits are computed while sampling but are not returned. Useful with callback.
callback
Optional function of trainFits, testFits, ranef, sigma, and tau. It is called after every post-burn-in iteration, and its results are collected and stored in the final object.
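formula, data, test, subset, weights, offset, offset.test, sigest, sigdf, sigquant, k, power, base, n.trees, n.samples, n.burn, n.chains, n.threads, combineChains, n.cuts, useQuantiles, keepTrainingFits, printEvery, printCutoffs, verbose, keepTrees, keepCall, seed, keepSampler, ...
Same as in bart2.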
object
A fitted rbart model.
newdata
Same as test, but named to match the predict generic.
type
One of "ev", "ppd", "bart", "ranef", or "trees" (the last only for extract), giving the posterior of the expected value, the posterior predictive distribution, the non-parametric/BART component, the random effects, or the saved trees, respectively. The expected value is the sum of the BART component and the random effects, while the posterior predictive distribution is a response sampled with that mean. To match predict.glm, "response" can be used as a synonym for "ev" and "link" as a synonym for "bart". For additional details on tree extraction, see the corresponding subsection in bart.
sample
One of "train" or "test", referring to the training or test samples, respectively.
x, plquants, cols
Same as in plot.bart.
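Below is a hypothetical callback. It assumes the five documented quantities are passed as arguments named trainFits, testFits, ranef, sigma, and tau, and that any return value is accepted and collected. The name summarize_state is illustrative only and is not part of dbarts.

```r
# Hypothetical callback: summarizes the sampler state after each
# post-burn-in iteration. testFits is ignored because it may be absent
# when no test data are supplied (an assumption of this sketch).
summarize_state <- function(trainFits, testFits, ranef, sigma, tau) {
  c(sigma = sigma,
    tau = tau,
    max.abs.ranef = max(abs(ranef)),
    mean.train.fit = mean(trainFits))
}

# Calling it by hand with made-up values to check the output shape
out <- summarize_state(trainFits = c(1, 2, 3), testFits = NULL,
                       ranef = c(-0.5, 2), sigma = 1.2, tau = 0.8)
```

In a fit, this would be supplied as callback = summarize_state to rbart_vi. The collected results would then appear in the callback element of the returned object.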
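The relationship between the "bart", "ranef", "ev", and "ppd" types can be shown with simulated draws in base R. This sketches the arithmetic for a continuous response only and is not the dbarts implementation. The matrices (rows are posterior draws), the residual standard deviations, and the group vector g are all made up.

```r
set.seed(2)
n.draws <- 5L; n.obs <- 6L; n.groups <- 3L
g <- c(1L, 1L, 2L, 2L, 3L, 3L)                                       # group of each observation
bart.draws  <- matrix(rnorm(n.draws * n.obs), n.draws, n.obs)        # like "bart": f(x_i)
ranef.draws <- matrix(rnorm(n.draws * n.groups), n.draws, n.groups)  # like "ranef": alpha_j
sigma.draws <- runif(n.draws, 0.8, 1.2)                              # residual sd per draw

# "ev": BART component plus the intercept of each observation's group
ev.draws <- bart.draws + ranef.draws[, g]

# "ppd": a response drawn around that mean; sigma.draws recycles down
# the columns, so row s is scaled by the s-th draw of sigma
ppd.draws <- ev.draws +
  matrix(rnorm(n.draws * n.obs), n.draws, n.obs) * sigma.draws
```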
Vincent Dorie: vdorie@gmail.com
Fits a BART model with additive random intercepts, one for each factor level of group.by. For continuous responses:
\(y_i \sim N(f(x_i) + \alpha_{g[i]}, \sigma^2)\)
\(\alpha_j \sim N(0, \tau^2)\).
For binary outcomes, the response model is changed to \(P(Y_i = 1) = \Phi(f(x_i) + \alpha_{g[i]})\), where \(\Phi\) is the standard normal CDF. Here \(i\) indexes observations, \(g[i]\) is the group index of observation \(i\), \(f(x)\) and \(\sigma\) come from a BART model, and the \(\alpha_j\) are independent and identically distributed random intercepts. Draws from the posterior of \(\tau\) are made using a slice sampler whose width is determined dynamically by assessing the curvature of the posterior distribution at its mode.
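A generic slice-sampling update for \(\tau\) given the current random intercepts can be sketched in base R. It uses Neal's stepping-out and shrinkage procedure with a fixed width w. That fixed width is an assumption: dbarts instead sets the width from the curvature of the posterior at its mode. The prior follows the documented (x, rel.scale) form and mirrors the built-in cauchy description. Returning a log density, and the value of rel.scale, are further assumptions. The names my_cauchy_prior, log_post_tau, and slice_step are hypothetical.

```r
# Prior in the documented (x, rel.scale) form; log density is an assumption
my_cauchy_prior <- function(x, rel.scale)
  dcauchy(x, location = 0, scale = 2.5 * rel.scale, log = TRUE)

# Conditional log posterior of tau given intercepts alpha ~ N(0, tau^2)
log_post_tau <- function(tau, alpha, rel.scale) {
  if (tau <= 0) return(-Inf)
  sum(dnorm(alpha, mean = 0, sd = tau, log = TRUE)) + my_cauchy_prior(tau, rel.scale)
}

# One slice-sampling update: stepping out, then shrinkage (Neal, 2003)
slice_step <- function(x0, log_f, w = 1, max.steps = 50L) {
  log.y <- log_f(x0) - rexp(1)          # log height of the auxiliary slice
  left <- x0 - runif(1) * w             # randomly positioned initial interval
  right <- left + w
  j <- floor(runif(1) * max.steps)      # split the step budget between sides
  k <- (max.steps - 1) - j
  while (j > 0 && log_f(left) > log.y) { left <- left - w; j <- j - 1 }
  while (k > 0 && log_f(right) > log.y) { right <- right + w; k <- k - 1 }
  repeat {                              # shrink toward x0 until accepted
    x1 <- runif(1, left, right)
    if (log_f(x1) > log.y) return(x1)
    if (x1 < x0) left <- x1 else right <- x1
  }
}

set.seed(3)
alpha <- rnorm(10, mean = 0, sd = 1.5)  # simulated random intercepts
rel.scale <- 2                          # made-up response scale
tau <- 1
tau.draws <- numeric(200)
for (s in seq_along(tau.draws)) {
  tau <- slice_step(tau, function(t) log_post_tau(t, alpha, rel.scale))
  tau.draws[s] <- tau
}
```

Within a full sampler, this update would alternate with draws of the BART component and the intercepts. Here alpha is held fixed.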
Predicting random effects for groups not in the training sample is supported by sampling from their posterior predictive distribution. That is, a draw is taken from \(p(\alpha \mid y) = \int p(\alpha \mid \tau)p(\tau \mid y)\,d\tau\). For out-of-sample groups in the test data, these random effect draws can be kept with the saved object. For groups supplied only to predict, they cannot be kept and may change between calls.
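The integral above can be approximated by simulation: for each posterior draw of \(\tau\), draw a new intercept from \(N(0, \tau^2)\). Below is a base-R sketch in which the tau draws are made up and stand in for those of a fitted model.

```r
set.seed(4)
# Made-up posterior draws of tau standing in for a fitted model's
tau.draws <- abs(rnorm(1000, mean = 1.5, sd = 0.2))

# One intercept for an unseen group per draw of tau:
# alpha^(s) ~ N(0, (tau^(s))^2); pooling over s integrates tau out
alpha.new <- rnorm(length(tau.draws), mean = 0, sd = tau.draws)
```

Rerunning the last line without resetting the seed gives different values. This illustrates how intercepts drawn for groups seen only by predict can vary from call to call.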
See the generics section of bart.
bart, dbarts
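library(dbarts)

## true mean function (Friedman's test function)
f <- function(x) {
  10 * sin(pi * x[,1] * x[,2]) + 20 * (x[,3] - 0.5)^2 +
    10 * x[,4] + 5 * x[,5]
}
set.seed(99)
## simulate continuous responses around the true mean
sigma <- 1.0
n <- 100
x <- matrix(runif(n * 10), n, 10)
Ey <- f(x)
y <- rnorm(n, Ey, sigma)
## add a random intercept for each of n.g groups
n.g <- 10
g <- sample(n.g, length(y), replace = TRUE)
sigma.b <- 1.5
b <- rnorm(n.g, 0, sigma.b)
y <- y + b[g]
## data frame with predictors x_1, ..., x_10, response y, and group g
df <- as.data.frame(x)
colnames(df) <- paste0("x_", seq_len(ncol(x)))
df$y <- y
df$g <- g
## low numbers to reduce run time; "- g" keeps the grouping
## variable out of the BART predictors
rbartFit <- rbart_vi(y ~ . - g, df, group.by = g,
                     n.samples = 40L, n.burn = 10L, n.thin = 2L,
                     n.chains = 1L,
                     n.trees = 25L, n.threads = 1L)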