
superml (version 0.4.0)

RFTrainer: Random Forest Trainer

Description

Trains a Random Forest model. A random forest is a meta-estimator that fits a number of decision trees on various sub-samples of the dataset and uses averaging to improve predictive accuracy and control over-fitting. This implementation uses the ranger R package, which provides faster model training.

Usage

RFTrainer

Format

R6Class object.

Usage

For usage details see Methods, Arguments and Examples sections.

bst = RFTrainer$new(n_estimators=100, max_features="auto", max_depth=5, min_node_size=1,
                    criterion, classification=1, class_weights, verbose=TRUE,
                    seed=42, always_split)
bst$fit(X_train, "target")
prediction <- bst$predict(X_test)

Methods

$new()

Initialises an instance of the random forest model.

$fit()

Fits (trains) the model on the input training data.

$predict()

Returns predictions by applying the trained model to test data.

$get_importance()

Returns feature importance from the trained model.

Arguments

n_estimators

the number of trees in the forest; default = 100

max_features

the number of features to consider when looking for the best split. Possible values are "auto" (default), which takes sqrt(num_of_features); "sqrt", which is the same as "auto"; "log", which takes log(num_of_features); and "none", which takes all features.
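As a rough illustration of the mapping described above, the sketch below converts each max_features option into a feature count. The helper name resolve_max_features is hypothetical and not part of superml, and rounding down with a floor of 1 is an assumption; the package may round differently.

```r
# Hypothetical helper (not part of superml): maps a max_features option
# to the number of features considered at each split.
resolve_max_features <- function(max_features, n_features) {
  k <- switch(max_features,
    auto = ,                      # "auto" falls through to "sqrt"
    sqrt = sqrt(n_features),
    log  = log(n_features),       # natural logarithm
    none = n_features,
    stop("unknown max_features: ", max_features)
  )
  # Assumption: round down, but never go below one feature
  max(1L, as.integer(floor(k)))
}

resolve_max_features("auto", 16)  # 4
resolve_max_features("log", 16)   # 2, since log(16) is about 2.77
resolve_max_features("none", 16)  # 16
```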

max_depth

the maximum depth of each tree

min_node_size

the minimum number of samples required to split an internal node

criterion

the function used to measure the quality of a split. For classification, the Gini index is used. For regression, the variance of the responses is used.
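To make the two split criteria concrete, here is a minimal base R sketch of Gini impurity for class labels and response variance for a numeric target. The function names are hypothetical, and this only illustrates the quantities named above; it is not how ranger computes them internally.

```r
# Gini impurity of a set of class labels: 1 - sum(p_k^2),
# where p_k is the proportion of labels in class k.
gini_impurity <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

# Variance of a numeric response within a node (population form).
response_variance <- function(y) {
  mean((y - mean(y))^2)
}

gini_impurity(iris$Species)      # 2/3: three equally sized classes
gini_impurity(c("a", "a", "a"))  # 0: a pure node
response_variance(c(1, 2, 3))    # 2/3
```

A split is better the more it reduces these quantities in the child nodes relative to the parent.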

classification

whether to train for classification (1) or regression (0)

class_weights

weights associated with the classes, used when sampling training observations

verbose

show computation status and estimated runtime

always_split

vector of feature names to be always used for splitting

seed

seed value for the random number generator, for reproducible results

importance

Variable importance mode, one of 'none', 'impurity', 'impurity_corrected', 'permutation'. The 'impurity' measure is the Gini index for classification and the variance of the responses for regression. Defaults to 'impurity'.

Examples

# NOT RUN {
data("iris")
bst <- RFTrainer$new(n_estimators=50,
                     max_depth=4,
                     classification=1,
                     seed=42,
                     verbose=TRUE)
bst$fit(iris, 'Species')
predictions <- bst$predict(iris)
bst$get_importance()
# }
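The example above predicts on the same rows it was trained on. A hold-out variant might look like the sketch below. It uses only the constructor arguments and methods documented on this page, and it assumes that superml (with ranger) is installed and that predict() returns one label per row; the 80/20 split and the variable names are illustrative choices.

```r
# Sketch: hold-out evaluation with RFTrainer (skipped if superml is missing)
if (requireNamespace("superml", quietly = TRUE)) {
  library(superml)
  data("iris")

  set.seed(42)
  # Illustrative 80/20 split of row indices
  train_idx <- sample(seq_len(nrow(iris)), size = floor(0.8 * nrow(iris)))
  train <- iris[train_idx, ]
  test  <- iris[-train_idx, ]

  rf <- RFTrainer$new(n_estimators = 50,
                      max_depth = 4,
                      classification = 1,
                      seed = 42,
                      verbose = FALSE)
  rf$fit(train, "Species")
  preds <- rf$predict(test)

  # Compare as character to avoid issues with differing factor levels
  accuracy <- mean(as.character(preds) == as.character(test$Species))
  print(accuracy)
}
```

Evaluating on rows the model has not seen gives a more honest estimate of predictive accuracy than predicting on the training data.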
