Abstract class used as a template for defining customized functions that control the computational nuances of the train function.
new()
Function used to initialize the object parameters during execution time.
TrainFunction$new(
method,
number,
savePredictions,
classProbs,
allowParallel,
verboseIter,
seed
)
method
The resampling method: "boot", "boot632", "optimism_boot", "boot_all", "cv", "repeatedcv", "LOOCV", "LGOCV" (for repeated training/test splits), "none" (only fits one model to the entire training set), "oob" (only for random forest, bagged trees, bagged earth, bagged flexible discriminant analysis, or conditional tree forest models), "timeslice", "adaptive_cv", "adaptive_boot" or "adaptive_LGOCV".
number
Either the number of folds or the number of resampling iterations.
savePredictions
An indicator of how much of the hold-out predictions for each resample should be saved. Values can be either "all", "final", or "none". A logical value can also be used, which is converted to "all" (for true) or "none" (for false). "final" saves the predictions for the optimal tuning parameters.
classProbs
A logical value. Should class probabilities be computed for classification models (along with predicted values) in each resample?
allowParallel
A logical value. If a parallel backend is loaded and available, should the function use it?
verboseIter
A logical for printing a training log.
seed
An optional integer that will be used to set the seed during the model training stage.
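The savePredictions argument accepts either a string or a logical value. The following base R sketch illustrates the conversion described above. The helper name normalize_save_predictions is hypothetical and is not part of the documented class; how the class itself performs this conversion is not stated here.

```r
# Hypothetical helper (an assumption for illustration, not part of TrainFunction).
# Maps a logical savePredictions value to its string equivalent and
# validates string values against the three documented options.
normalize_save_predictions <- function(savePredictions) {
  if (isTRUE(savePredictions)) {
    return("all")
  }
  if (isFALSE(savePredictions)) {
    return("none")
  }
  # match.arg() raises an error for anything outside the allowed values
  match.arg(savePredictions, c("all", "final", "none"))
}

normalize_save_predictions(TRUE)
normalize_save_predictions(FALSE)
normalize_save_predictions("final")
```

Passing any other value, such as "some", stops with an error from match.arg().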
create()
Creates the trainControl object required for the training stage.
TrainFunction$create(summaryFunction, search.method = "grid", class.probs)
summaryFunction
An object inherited from the SummaryFunction class.
search.method
Either "grid" or "random", describing how the tuning parameter grid is determined.
class.probs
A logical indicating if class probabilities should be computed for classification models (along with predicted values) in each resample.
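As a rough illustration of what create() assembles, the sketch below builds a comparable control object directly. It assumes that trainControl refers to caret's trainControl() function and that the caret package is installed. caret's built-in twoClassSummary is used as a stand-in for a SummaryFunction object, since how that object is translated internally is not described here.

```r
library(caret)

# Assumption: these values mirror the arguments documented for new() and create().
ctrl <- trainControl(
  method = "cv",                     # resampling method
  number = 10,                       # number of folds
  savePredictions = "final",         # keep hold-out predictions for the best tune
  classProbs = TRUE,                 # compute class probabilities per resample
  summaryFunction = twoClassSummary, # stand-in for a SummaryFunction object
  search = "grid",                   # corresponds to search.method
  allowParallel = TRUE,
  verboseIter = FALSE
)

ctrl$method
ctrl$search
```

The resulting list can be passed to caret's train() through its trControl argument.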
getResamplingMethod()
Returns the resampling method used during the training stage.
TrainFunction$getResamplingMethod()
getNumberFolds()
Returns the number of folds or number of iterations used during training.
TrainFunction$getNumberFolds()
getSavePredictions()
Indicates if the predictions for each resample should be saved.
TrainFunction$getSavePredictions()
getClassProbs()
Indicates if class probabilities should be computed for classification models in each resample.
TrainFunction$getClassProbs()
A logical value.
getAllowParallel()
Determines if model training is performed in parallel.
TrainFunction$getAllowParallel()
getVerboseIter()
Determines whether the training log should be printed.
TrainFunction$getVerboseIter()
getTrFunction()
Function used to return the trainControl object.
TrainFunction$getTrFunction()
A trainControl object.
getMeasures()
Returns the measures used to optimize model hyperparameters.
TrainFunction$getMeasures()
A character vector.
getType()
Obtains the type of classification problem ("Bi-class" or "Multi-class").
TrainFunction$getType()
A character vector with length 1. Either "Bi-class" or "Multi-class".
getSeed()
Returns the seed used during the model training stage.
TrainFunction$getSeed()
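A fixed seed is what makes resampling reproducible between runs. The base R sketch below shows the general idea with set.seed(). It assumes the stored seed is applied this way before training, which is not confirmed here; the variable names are illustrative.

```r
# Illustrative seed value; in practice this could be the value returned by getSeed().
seed <- 1234

# Draw the same random sample twice, resetting the seed before each draw.
set.seed(seed)
first_draw <- sample(1:100, 10)

set.seed(seed)
second_draw <- sample(1:100, 10)

# Both draws are identical because the generator started from the same state.
identical(first_draw, second_draw)
```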
setSummaryFunction()
Function used to change the SummaryFunction used in the training stage.
TrainFunction$setSummaryFunction(summaryFunction)
summaryFunction
An object inherited from the SummaryFunction class.
setClassProbs()
Changes whether class probabilities are computed during training.
TrainFunction$setClassProbs(class.probs)
class.probs
A logical indicating if class probabilities should be computed for classification models (along with predicted values) in each resample.
clone()
The objects of this class are cloneable with this method.
TrainFunction$clone(deep = FALSE)
deep
Whether to make a deep clone.
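The clone(deep = FALSE) signature matches the method that the R6 package generates for its classes. Assuming TrainFunction is an R6 class, the sketch below shows the difference between a shallow and a deep clone. The Settings and Holder classes are hypothetical and exist only for this illustration.

```r
library(R6)

# Hypothetical classes: Holder keeps another R6 object in one of its fields.
Settings <- R6Class("Settings", public = list(number = 10))
Holder <- R6Class("Holder", public = list(
  settings = NULL,
  initialize = function() {
    self$settings <- Settings$new()
  }
))

original <- Holder$new()
shallow <- original$clone()            # nested R6 field is shared by reference
deep <- original$clone(deep = TRUE)    # nested R6 field is copied as well

# Modify the nested object through the original.
original$settings$number <- 5

shallow$settings$number  # 5: the shallow clone sees the change
deep$settings$number     # 10: the deep clone kept its own copy
```

A deep clone is the safer choice when an object holds other R6 objects that should not be shared between copies.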
TwoClass