BART (version 2.9.9)

mc.crisk.pwbart: Predicting new observations with a previously fitted BART model

Description

BART is a Bayesian “sum-of-trees” model.
For a numeric response \(y\), we have \(y = f(x) + \epsilon\), where \(\epsilon \sim N(0,\sigma^2)\).

\(f\) is the sum of many tree models. The goal is to have very flexible inference for the unknown function \(f\).

In the spirit of “ensemble models”, each tree is constrained by a prior to be a weak learner so that it contributes a small amount to the overall fit.
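In the notation commonly used for BART, the sum-of-trees form can be written out explicitly. The symbols \(m\), \(T_j\), \(M_j\) and \(g\) are not defined on this page and are introduced here only for illustration:

\[ f(x) = \sum_{j=1}^{m} g(x; T_j, M_j) \]

Here \(m\) is the number of trees, \(T_j\) is the structure of the \(j^{th}\) tree, \(M_j\) is its set of terminal-node values, and \(g(x; T_j, M_j)\) is the value that tree assigns to \(x\). The weak-learner prior keeps each term small relative to the total.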

Usage

mc.crisk.pwbart( x.test, x.test2,
                 treedraws, treedraws2,
                 binaryOffset=0, binaryOffset2=0,
                 mc.cores=2L, type='pbart',
                 transposed=FALSE, nice=19L
               )

Value

Returns an object of type criskbart which is essentially a list with components:

yhat.test

A matrix with ndpost rows and nrow(x.test) columns. Each row corresponds to a draw \(f^*\) from the posterior of \(f\) and each column corresponds to a row of x.test. The \((i,j)\) value is \(f^*(x)\) for the \(i^{th}\) kept draw of \(f\) and the \(j^{th}\) row of x.test.
Burn-in is dropped.

surv.test

test data fits for survival probability.

surv.test.mean

mean of surv.test over the posterior samples.

prob.test

The probability of suffering cause 1, which is occasionally useful, e.g., in calculating the concordance.

prob.test2

The probability of suffering cause 2, which is occasionally useful, e.g., in calculating the concordance.

cif.test

The cumulative incidence function of cause 1, \(F_1(t, x)\), where x's are the rows of the test data.

cif.test2

The cumulative incidence function of cause 2, \(F_2(t, x)\), where x's are the rows of the test data.

yhat.test.mean

test data fits = mean of yhat.test columns.

cif.test.mean

mean of cif.test columns for cause 1.

cif.test2.mean

mean of cif.test2 columns for cause 2.
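
To make the relationship between the event probabilities, the survival probability and the two cumulative incidence functions concrete, here is a minimal base R sketch of a discrete-time competing risks calculation. It assumes a construction in which the cause 2 probability at each time point is conditional on no cause 1 event at that time point; that assumption, the vectors p1 and p2, and all other names below are illustrative and are not taken from this page or from the package output.

```r
## Toy conditional probabilities on a grid of K = 4 time points
## (made-up numbers, not output from a fitted model)
p1 <- c(0.10, 0.20, 0.15, 0.05)  # P(cause 1 event at t_j | at risk at t_j)
p2 <- c(0.05, 0.10, 0.10, 0.20)  # P(cause 2 event at t_j | at risk, no cause 1 event at t_j)

K <- length(p1)

## Per-interval probabilities, conditional on being at risk at t_j
step.surv <- (1 - p1) * (1 - p2)  # no event of either cause
inc1 <- p1                        # cause 1 event
inc2 <- (1 - p1) * p2             # cause 2 event (assumed construction)

## Survival through t_j, and survival up to the previous time point
surv <- cumprod(step.surv)
surv.prev <- c(1, surv[-K])       # S(t_0) = 1

## Cumulative incidence: accumulate the chance of reaching t_j event-free
## and then experiencing the given cause at t_j
cif1 <- cumsum(surv.prev * inc1)
cif2 <- cumsum(surv.prev * inc2)

round(cbind(surv, cif1, cif2, total = surv + cif1 + cif2), 4)
```

Under this construction the three quantities partition the probability space at every time point, so the total column equals 1 throughout. That is a convenient sanity check to apply to any one posterior draw.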

Arguments

x.test

Matrix of covariates to predict \(y\) for cause 1.

x.test2

Matrix of covariates to predict \(y\) for cause 2.

treedraws

$treedraws for cause 1.

treedraws2

$treedraws for cause 2.

binaryOffset

Mean to add on to \(y\) prediction for cause 1.

binaryOffset2

Mean to add on to \(y\) prediction for cause 2.

mc.cores

Number of threads to utilize.

type

Whether to employ Albert-Chib, 'pbart', or Holmes-Held, 'lbart'.

transposed

When running pwbart or mc.pwbart in parallel, it is more memory-efficient to transpose x.test prior to calling the internal versions of these functions.

nice

Set the job niceness. The default niceness is 19: niceness goes from 0 (highest priority) to 19 (lowest priority).
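
The type argument selects between a probit model (Albert-Chib) and a logistic model (Holmes-Held). As a rough illustration of what that choice means for predicted probabilities, the base R sketch below maps the same latent values to probabilities under each link. The matrix f.draws and the variable offset are hypothetical stand-ins, and whether a given returned component already includes the offset is not stated on this page.

```r
## Hypothetical latent draws f*(x): 3 posterior draws (rows) by 2 test rows (columns)
f.draws <- matrix(c(-1.0, 0.0,  0.5,
                     0.2, 1.0, -0.3), nrow = 3)

offset <- 0  # plays the role of binaryOffset in this sketch

## Probit link: standard normal CDF (the link used by a probit model)
prob.probit <- pnorm(offset + f.draws)

## Logit link: logistic CDF (the link used by a logistic model)
prob.logit <- plogis(offset + f.draws)

round(prob.probit, 3)
round(prob.logit, 3)
```

The two links agree at a latent value of 0 and otherwise differ in scale, so latent values from one type are not interchangeable with the other.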

Details

BART is a Bayesian MCMC method. At each MCMC iteration, we produce a draw from the joint posterior \((f,\sigma) | (x,y)\) in the numeric \(y\) case and just \(f\) in the binary \(y\) case.

Thus, unlike a lot of other modelling methods in R, we do not produce a single model object from which fits and summaries may be extracted. The output consists of values \(f^*(x)\) (and \(\sigma^*\) in the numeric case) where * denotes a particular draw. The \(x\) is either a row from the training data (x.train) or the test data (x.test).
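
Because the output is a set of draws rather than a single fitted object, summaries are computed directly from the draws. The base R sketch below uses a simulated matrix as a stand-in for a component such as yhat.test (rows are draws, columns are test rows). The simulated values and the names draws, post.mean and post.ci are illustrative only.

```r
set.seed(1)
ndpost <- 1000  # number of kept posterior draws
n.test <- 3     # number of test rows

## Simulated stand-in for a draws matrix: column j is centered at -1, 0, 2
draws <- matrix(rnorm(ndpost * n.test,
                      mean = rep(c(-1, 0, 2), each = ndpost)),
                nrow = ndpost, ncol = n.test)

## Posterior mean for each test row: average over draws (down each column)
post.mean <- colMeans(draws)

## Equal-tailed 95% posterior interval for each test row
post.ci <- apply(draws, 2, quantile, probs = c(0.025, 0.975))

rbind(mean = post.mean, post.ci)
```

The same pattern applies to any of the draw matrices: column means give point estimates and column quantiles give posterior intervals.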

See Also

pwbart, crisk.bart, mc.crisk.bart

Examples

library(BART)

data(transplant)

delta <- (as.numeric(transplant$event)-1)
## recode so that delta=1 is cause of interest; delta=2 otherwise
delta[delta==1] <- 4
delta[delta==2] <- 1
delta[delta>1] <- 2
table(delta, transplant$event)

times <- pmax(1, ceiling(transplant$futime/7)) ## weeks
##times <- pmax(1, ceiling(transplant$futime/30.5)) ## months
table(times)

typeO <- 1*(transplant$abo=='O')
typeA <- 1*(transplant$abo=='A')
typeB <- 1*(transplant$abo=='B')
typeAB <- 1*(transplant$abo=='AB')
table(typeA, typeO)

x.train <- cbind(typeO, typeA, typeB, typeAB)

x.test <- cbind(1, 0, 0, 0)
dimnames(x.test)[[2]] <- dimnames(x.train)[[2]]

## parallel::mcparallel/mccollect do not exist on windows
if(.Platform$OS.type=='unix') {
##test BART with token run to ensure installation works
        post <- mc.crisk.bart(x.train=x.train, times=times, delta=delta,
                               seed=99, mc.cores=2, nskip=5, ndpost=5,
                               keepevery=1)

        pre <- surv.pre.bart(x.train=x.train, x.test=x.test,
                             times=times, delta=delta)

        K <- post$K

        pred <- mc.crisk.pwbart(pre$tx.test, pre$tx.test,
                                post$treedraws, post$treedraws2,
                                post$binaryOffset, post$binaryOffset2)
}

if (FALSE) {

## run one long MCMC chain in one process
## set.seed(99)
## post <- crisk.bart(x.train=x.train, times=times, delta=delta, x.test=x.test)

## in the interest of time, consider speeding it up by parallel processing
## run "mc.cores" number of shorter MCMC chains in parallel processes
post <- mc.crisk.bart(x.train=x.train,
                       times=times, delta=delta,
                       x.test=x.test, seed=99, mc.cores=8)

check <- mc.crisk.pwbart(post$tx.test, post$tx.test,
                          post$treedraws, post$treedraws2,
                          post$binaryOffset,
                          post$binaryOffset2, mc.cores=8)
## check <- predict(post, newdata=post$tx.test, newdata2=post$tx.test2,
##                  mc.cores=8)

print(c(post$surv.test.mean[1], check$surv.test.mean[1],
        post$surv.test.mean[1]-check$surv.test.mean[1]), digits=22)

print(all(round(post$surv.test.mean, digits=9)==
    round(check$surv.test.mean, digits=9)))

print(c(post$cif.test.mean[1], check$cif.test.mean[1],
        post$cif.test.mean[1]-check$cif.test.mean[1]), digits=22)

print(all(round(post$cif.test.mean, digits=9)==
    round(check$cif.test.mean, digits=9)))

print(c(post$cif.test2.mean[1], check$cif.test2.mean[1],
        post$cif.test2.mean[1]-check$cif.test2.mean[1]), digits=22)

print(all(round(post$cif.test2.mean, digits=9)==
    round(check$cif.test2.mean, digits=9)))

}
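
The posterior means above come back as long vectors, so it can help to reshape them into a time-by-subject layout before plotting or tabulating. The base R sketch below assumes the values are ordered with the K time points for the first test row first, then the K time points for the second test row, and so on. That ordering is an assumption and should be checked against the time column of tx.test for a real fit. The numbers and the names surv.mean, surv.mat and H are made up for illustration.

```r
K <- 4  # number of distinct time points (post$K in a real fit)
H <- 2  # number of rows in the original x.test

## Made-up posterior mean survival, K values per test row
surv.mean <- c(0.90, 0.80, 0.70, 0.60,   # test row 1
               0.95, 0.85, 0.80, 0.75)   # test row 2

## matrix() fills column by column, so each column holds one test row
surv.mat <- matrix(surv.mean, nrow = K, ncol = H,
                   dimnames = list(paste0("t", 1:K), paste0("row", 1:H)))

surv.mat
```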