Perform exact \(K\)-fold cross-validation by refitting the model \(K\) times, each time leaving out one \(K\)-th of the original data.
# S3 method for brmsfit
kfold(x, ..., compare = TRUE, K = 10, Ksub = NULL,
  folds = NULL, group = NULL, exact_loo = NULL, resp = NULL,
  model_names = NULL, save_fits = FALSE)

kfold(x, ...)
x
A fitted model object.
...
More fitted model objects or further arguments passed to the underlying post-processing functions.
compare
A flag indicating if the information criteria of the models should be compared to each other via compare_ic.
K
The number of subsets of equal (if possible) size into which the data will be partitioned for performing \(K\)-fold cross-validation. The model is refit K times, each time leaving out one of the K subsets. If K is equal to the total number of observations in the data, then \(K\)-fold cross-validation is equivalent to exact leave-one-out cross-validation.
Ksub
Optional number of subsets (of those subsets defined by K) to be evaluated. If NULL (the default), \(K\)-fold cross-validation will be performed on all subsets. If Ksub is a single integer, Ksub subsets (out of all K subsets) will be randomly chosen. If Ksub consists of multiple integers or is a one-dimensional array (created via as.array), potentially of length one, the corresponding subsets will be used. This argument is primarily useful if evaluating all subsets is infeasible for some reason.
folds
Determines how the subsets are constructed. Possible values are NULL (the default), "stratified", "balanced", or "loo". May also be a vector of length equal to the number of observations in the data. Alters the way group is handled. More information is provided in the 'Details' section.
group
Optional name of a grouping variable or factor in the model. What exactly is done with this variable depends on argument folds. More information is provided in the 'Details' section.
exact_loo
Deprecated! Please use folds = "loo" instead.
resp
Optional names of response variables. If specified, predictions are performed only for the specified response variables.
model_names
If NULL (the default), model names are derived from deparsing the call. Otherwise, the passed values are used as model names.
save_fits
If TRUE, components fits and data are added to the returned object. fits stores the cross-validated brmsfit objects and the indices of the omitted and predicted observations for each fold. data returns the full data frame used in the cross-validation. Defaults to FALSE.
kfold returns an object with a structure similar to that of the objects returned by the loo and waic methods.
The kfold function performs exact \(K\)-fold cross-validation. First, the data are partitioned into \(K\) folds (i.e. subsets) of equal (or as close to equal as possible) size by default. Then the model is refit \(K\) times, each time leaving out one of the K subsets. If \(K\) is equal to the total number of observations in the data, then \(K\)-fold cross-validation is equivalent to exact leave-one-out cross-validation (to which loo is an efficient approximation). The compare_ic function is also compatible with the objects returned by kfold.
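The refit loop described above can be sketched in base R. This is a conceptual illustration only, not how brms implements kfold: it assumes a simple linear model fitted with lm on simulated data, and it scores each held-out observation with a plug-in normal log density rather than a posterior predictive density. The names dat, fold_id, lpd and elpd_kfold are hypothetical.

```r
# Conceptual sketch (assumption: lm stands in for the Bayesian model).
set.seed(1)
N <- 100                      # number of observations (simulated)
K <- 10                       # number of folds
dat <- data.frame(x = rnorm(N))
dat$y <- 1 + 2 * dat$x + rnorm(N)

# Randomly assign every observation to one of K folds of equal size
fold_id <- sample(rep(seq_len(K), length.out = N))

# Pointwise log predictive density of each held-out observation
lpd <- numeric(N)
for (k in seq_len(K)) {
  held_out <- fold_id == k
  # Refit the model without fold k
  fit_k <- lm(y ~ x, data = dat[!held_out, ])
  # Predict the observations that were left out
  mu <- predict(fit_k, newdata = dat[held_out, ])
  lpd[held_out] <- dnorm(dat$y[held_out], mean = mu,
                         sd = sigma(fit_k), log = TRUE)
}

# Sum over all observations: a K-fold estimate of predictive accuracy
elpd_kfold <- sum(lpd)
```

Every observation is predicted exactly once, by a model that never saw it, which is what makes the procedure exact rather than an approximation.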
The subsets can be constructed in multiple different ways:
If both folds and group are NULL, the subsets are randomly chosen so that they have equal (or as close to equal as possible) size.
If folds is NULL but group is specified, the data are split up into subsets, each time omitting all observations of one of the factor levels, while ignoring argument K.
If folds = "stratified", the subsets are stratified by group using loo::kfold_split_stratified.
If folds = "balanced", the subsets are balanced by group using loo::kfold_split_balanced.
If folds = "loo", exact leave-one-out cross-validation will be performed and K will be ignored. Further, if group is specified, all observations corresponding to the factor level of the currently predicted single value are omitted. Thus, in this case, the predicted values are only a subset of the omitted ones.
If folds is a numeric vector, it must contain one element per observation in the data. Each element of the vector is an integer in 1:K indicating to which of the K folds the corresponding observation belongs. There are some convenience functions available in the loo package that create integer vectors to use for this purpose (see the Examples section below and also the kfold-helpers page).
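As a rough illustration of the numeric-vector case, the following base R sketch builds fold assignments by hand instead of using the loo helper functions. The sizes N and K, the factor patient, and the object names folds_random and folds_by_group are hypothetical; the commented call at the end assumes a brmsfit object fit1 that was fitted to exactly N observations.

```r
set.seed(123)
N <- 23   # hypothetical number of observations in the model data
K <- 5    # hypothetical number of folds

# Random folds: one integer in 1:K per observation, sizes as equal as possible
folds_random <- sample(rep(seq_len(K), length.out = N))
table(folds_random)

# Group-based folds: every level of a (hypothetical) grouping factor
# becomes its own fold, so all of its observations are left out together
patient <- factor(rep(c("p1", "p2", "p3", "p4"), length.out = N))
folds_by_group <- as.integer(patient)

# Usage, assuming fit1 is a brmsfit fitted to these N observations:
# kfold(fit1, folds = folds_random)
```

Building the vector explicitly makes the partition reproducible across models, which matters when the resulting kfold objects are to be compared.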
# NOT RUN {
library(brms)

# fit a Poisson model with patient- and observation-level intercepts
fit1 <- brm(count ~ log_Age_c + log_Base4_c * Trt +
              (1|patient) + (1|obs),
            data = epilepsy, family = poisson())
# throws warning about some pareto k estimates being too high
(loo1 <- loo(fit1))
# perform 10-fold cross validation
(kfold1 <- kfold(fit1, chains = 1))
# }