## A simple example of bagging conditional inference regression trees:
## Load the blood-brain barrier data; it is expected to supply the `bbbDescr`
## predictors and the `logBBB` outcome that the bagging call below relies on
data("BloodBrain")
## Fit a model with the default values
ctreeFit <- function(x, y, ...) {
  ## Attach party at call time so that `ctree` can be found wherever the
  ## bagging code ends up evaluating this function
  library(party)
  ## The formula interface wants predictors and outcome in one data frame
  trainingData <- as.data.frame(x)
  trainingData[["y"]] <- y
  ## Arguments arriving through `...` are accepted but not forwarded
  ctree(y ~ ., data = trainingData)
}
## Generate simple predictions of the outcome
ctreePred <- function(object, x) {
  ## predict() returns a two-dimensional result here; only its first column
  ## is returned as the prediction vector
  predicted <- predict(object, x)
  predicted[, 1]
}
## Take the median of the bagged predictions
ctreeAg <- function(x, type = NULL) {
  ## `x` is a list with one prediction vector per bagged model; bind them
  ## side by side so each row collects every model's prediction for a sample
  byModel <- do.call(cbind, x)
  ## Row-wise median across the models (`type` is not used in this function)
  rowMedians <- apply(byModel, MARGIN = 1, FUN = median)
  rowMedians
}
## Bag ten models, plugging the three helper functions defined above into
## bag() through a custom bagControl object
treebag <- bag(
  x = bbbDescr,
  y = logBBB,
  B = 10,
  bagControl = bagControl(
    fit = ctreeFit,
    predict = ctreePred,
    aggregate = ctreeAg
  )
)
## An example of pooling posterior probabilities to generate class predictions
data(mdrr)
## Remove near-zero-variance predictors, then highly correlated ones
## (pairwise correlation cutoff of .95).
## Columns to keep are selected positively with setdiff(): negating an empty
## index vector (`x[, -integer(0)]`) would drop every column instead of none,
## which is what happens whenever a filter finds nothing to remove.
## `drop = FALSE` keeps the result two-dimensional even if one column is left.
mdrrDescr <- mdrrDescr[, setdiff(seq_len(ncol(mdrrDescr)),
                                 nearZeroVar(mdrrDescr)),
                       drop = FALSE]
mdrrDescr <- mdrrDescr[, setdiff(seq_len(ncol(mdrrDescr)),
                                 findCorrelation(cor(mdrrDescr), .95)),
                       drop = FALSE]
## The fit and predict functions are straightforward:
ldaFit <- function(x, y, ...) {
  ## Attach MASS at call time so that `lda` can be found wherever the
  ## bagging code ends up evaluating this function
  library(MASS)
  ## Predictors first, class labels second; extra arguments are forwarded
  ldaModel <- lda(x, y, ...)
  ldaModel
}
ldaPred <- function(object, x) {
  ## Only the posterior class probabilities are kept from the prediction
  ## result; its other components are discarded
  prediction <- predict(object, x)
  prediction[["posterior"]]
}
## For the aggregation function, we take the median of the bagged
## posterior probabilities and pick the largest as the class
## Pool the class probabilities produced by the bagged models.
##
## Arguments:
##   x    - list with one probability matrix per bagged model; every matrix
##          has one row per sample and one (named) column per class
##   type - "class" returns the most probable class for each sample as a
##          factor; any other value returns the pooled probability matrix
##
## Note: the median is taken separately for each class, so the rows of the
## pooled matrix are not guaranteed to sum to exactly 1.
ldaAg <- function(x, type = "class") {
  ## Use the first matrix as a template so shape and dimnames carry over
  pooled <- x[[1]]
  pooled[] <- NA_real_
  classes <- colnames(pooled)
  ## For each class, stack that class's column from every model (one row per
  ## model, one column per sample) and take the per-sample median
  for (i in seq_len(ncol(pooled))) {
    tmp <- lapply(x, function(y, col) y[, col], col = i)
    tmp <- do.call("rbind", tmp)
    pooled[, i] <- apply(tmp, 2, median)
  }
  if (type == "class") {
    ## which.max() returns integer(0) when every probability in a row is NA;
    ## map that to NA so apply() still returns a plain vector rather than a
    ## list (which cannot be used as a subscript)
    best <- apply(pooled, 1, function(p) {
      idx <- which.max(p)
      if (length(idx) > 0) idx else NA_integer_
    })
    out <- factor(classes[best], levels = classes)
  } else {
    out <- pooled
  }
  out
}
## Bag ten LDA models; `vars = 10` is handed to bag() alongside the custom
## fit/predict/aggregate functions
bagLDA <- bag(
  x = mdrrDescr,
  y = mdrrClass,
  B = 10,
  vars = 10,
  bagControl = bagControl(
    fit = ldaFit,
    predict = ldaPred,
    aggregate = ldaAg
  )
)
## A single LDA model fit through train() on the same data
basicLDA <- train(x = mdrrDescr, y = mdrrClass, method = "lda")
## Tune the bagged LDA model over `vars` (10, 20, ..., 100 and all predictors).
## The control object must be passed by name: left unnamed it is matched
## positionally to whichever formal of train() follows `method`, instead of
## being forwarded through `...` to bag() as `bagControl`.
bagLDA2 <- train(mdrrDescr, mdrrClass,
                 "bag",
                 B = 10,
                 bagControl = bagControl(fit = ldaFit,
                                         predict = ldaPred,
                                         aggregate = ldaAg),
                 tuneGrid = data.frame(.vars = c((1:10) * 10,
                                                 ncol(mdrrDescr))))
## Run the code above in your browser using DataLab