
loo (version 2.7.0)

kfold-helpers: Helper functions for K-fold cross-validation

Description

These functions generate fold indexes for use with K-fold cross-validation. See the Details section for how each function assigns observations to folds.

Usage

kfold_split_random(K = 10, N = NULL)

kfold_split_stratified(K = 10, x = NULL)

kfold_split_grouped(K = 10, x = NULL)

Value

An integer vector of length N where each element is an index in 1:K.

Arguments

K

The number of folds to use.

N

The number of observations in the data.

x

A discrete variable of length N with at least K levels (unique values). Will be coerced to a factor.

Details

kfold_split_random() splits the data into K groups of equal size (or roughly equal size).
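To make "equal size (or roughly equal size)" concrete, here is a minimal base R sketch of one way such a split could be produced. The helper `random_folds_sketch()` is hypothetical and is not part of loo; it is not claimed to match how `kfold_split_random()` is implemented internally.

```r
# Hypothetical helper (not part of loo): assign N observations to K folds
# by shuffling a repeated 1:K pattern, so fold sizes differ by at most one.
random_folds_sketch <- function(K, N) {
  stopifnot(K >= 2, K <= N)
  sample(rep_len(seq_len(K), N))
}

set.seed(1)
ids <- random_folds_sketch(K = 3, N = 10)
table(ids)  # fold sizes: one fold of 4 and two folds of 3
```

Because 10 is not divisible by 3, the folds cannot all be the same size; the sketch keeps them within one observation of each other.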

For a categorical variable x, kfold_split_stratified() splits the observations into K groups while ensuring that the relative category frequencies are approximately preserved within each group.
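The idea of preserving category frequencies can be illustrated with a small base R sketch. The helper `stratified_folds_sketch()` is hypothetical and only one possible approach (sort by category, then deal fold labels round-robin); it should not be read as loo's actual algorithm.

```r
# Hypothetical helper (not part of loo): spread each category of x
# across the K folds as evenly as possible.
stratified_folds_sketch <- function(K, x) {
  x <- as.factor(x)
  n <- length(x)
  # Sort by category, breaking ties randomly within each category.
  ord <- order(x, runif(n))
  ids <- integer(n)
  # Deal fold labels 1..K round-robin along the sorted observations.
  ids[ord] <- rep_len(seq_len(K), n)
  ids
}

set.seed(1)
x <- rep(c("a", "b"), times = c(10, 90))  # 10% "a", 90% "b"
ids <- stratified_folds_sketch(K = 5, x = x)
tab <- table(ids, x)
tab  # each fold holds 2 "a" and 18 "b", matching the overall 10/90 ratio
```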

For a grouping variable x, kfold_split_grouped() places all observations in x from the same group/level together in the same fold. The selection of which groups/levels go into which fold (relevant when there are more groups than folds) is randomized.
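The grouping constraint can be sketched in base R as follows. The helper `grouped_folds_sketch()` is hypothetical, not a loo function, and is only assumed to mimic the behavior described above: folds are assigned at the level of groups, and every observation inherits the fold of its group.

```r
# Hypothetical helper (not part of loo): randomly assign whole groups
# to folds so that no group is split across folds.
grouped_folds_sketch <- function(K, x) {
  x <- droplevels(as.factor(x))
  stopifnot(nlevels(x) >= K)
  # One fold label per group, shuffled so the group-to-fold mapping is random.
  fold_of_level <- sample(rep_len(seq_len(K), nlevels(x)))
  # Each observation takes the fold of its group.
  fold_of_level[as.integer(x)]
}

set.seed(1)
grp <- rep(letters[1:6], each = 4)  # 6 groups of 4 observations
ids <- grouped_folds_sketch(K = 3, x = grp)
tab <- table(grp, ids)
rowSums(tab > 0)  # 1 for every group: each group sits in exactly one fold
colSums(tab)      # 8 observations (2 groups) per fold
```

A cross-tabulation like `table(grp, ids)` is a quick way to check the same property on the output of `kfold_split_grouped()`.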

Examples

library(loo)

# Random split: 20 observations into 5 folds
ids <- kfold_split_random(K = 5, N = 20)
print(ids)
table(ids)  # number of observations per fold


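# Stratified split: a binary variable with a rare category (about 5% zeros)
x <- sample(c(0, 1), size = 200, replace = TRUE, prob = c(0.05, 0.95))
table(x)
ids <- kfold_split_stratified(K = 5, x = x)
print(ids)
table(ids, x)  # category counts within each fold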

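# Grouped split: 50 groups (US states) with 15 observations each
grp <- gl(n = 50, k = 15, labels = state.name)
length(grp)
head(table(grp))

# K = 10: each state falls entirely within one fold
ids_10 <- kfold_split_grouped(K = 10, x = grp)
(tab_10 <- table(grp, ids_10))
colSums(tab_10)  # observations per fold

# K = 9: 50 states cannot be divided evenly among 9 folds
ids_9 <- kfold_split_grouped(K = 9, x = grp)
(tab_9 <- table(grp, ids_9))
colSums(tab_9)  # observations per fold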
