superml (version 0.5.7)

RFTrainer: Random Forest Trainer

Description

Trains a random forest model.

Arguments

Public fields

n_estimators

the number of trees in the forest, default = 100

max_features

the number of features to consider when looking for the best split. Possible values: auto (default) uses sqrt(num_of_features); sqrt is the same as auto; log uses log(num_of_features); none uses all features

max_depth

the maximum depth of each tree

min_node_size

the minimum number of samples required to split an internal node

criterion

the function used to measure the quality of a split. For classification, the Gini index is used. For regression, the variance of the responses is used.

classification

whether to train for classification (1) or regression (0)

verbose

show computation status and estimated runtime

seed

seed value

class_weights

weights associated with the classes for sampling of training observations

always_split

vector of feature names to be always used for splitting

importance

Variable importance mode, one of 'none', 'impurity', 'impurity_corrected', 'permutation'. The 'impurity' measure is the Gini index for classification and the variance of the responses for regression. Defaults to 'impurity'.
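The `max_features` options listed above can be illustrated with a small base R sketch. The helper `resolve_max_features` is hypothetical and not part of superml; the use of the natural logarithm and the rounding rule (round down, at least one feature) are assumptions made only for illustration, since the documentation does not state how fractional values are handled.

```r
# Hypothetical helper (not part of superml): translate a max_features
# option into the number of candidate features considered per split.
resolve_max_features <- function(option, num_of_features) {
  n <- switch(option,
    auto = ,                          # auto falls through to sqrt
    sqrt = sqrt(num_of_features),
    log  = log(num_of_features),      # assumption: natural logarithm
    none = num_of_features,
    stop("unknown max_features option: ", option)
  )
  # Assumption: round down, but never go below one feature.
  max(1L, as.integer(floor(n)))
}

# Compare the options for a data set with 16 features.
sapply(c("auto", "sqrt", "log", "none"),
       resolve_max_features, num_of_features = 16)
```

With few features the options differ little; the gap widens as the number of features grows, which is when restricting the candidates per split has the most effect on training time and tree diversity.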

Methods


Method new()

Usage

RFTrainer$new(
  n_estimators,
  max_depth,
  max_features,
  min_node_size,
  classification,
  class_weights,
  always_split,
  verbose,
  save_model,
  seed,
  importance
)

Arguments

n_estimators

integer, the number of trees in the forest, default = 100

max_depth

integer, the maximum depth of each tree

max_features

character, the number of features to consider when looking for the best split. Possible values: auto (default) uses sqrt(num_of_features); sqrt is the same as auto; log uses log(num_of_features); none uses all features

min_node_size

integer, the minimum number of samples required to split an internal node

classification

integer, whether to train for classification (1) or regression (0)

class_weights

weights associated with the classes for sampling of training observations

always_split

vector of feature names to be always used for splitting

verbose

logical, show computation status and estimated runtime

save_model

logical, whether to save model

seed

integer, seed value

importance

Variable importance mode, one of 'none', 'impurity', 'impurity_corrected', 'permutation'. The 'impurity' measure is the Gini index for classification and the variance of the responses for regression. Defaults to 'impurity'.

Details

Create a new `RFTrainer` object.

Returns

A `RFTrainer` object.

Examples

library(superml)
data("iris")
bst <- RFTrainer$new(n_estimators=10,
                     max_depth=4,
                     classification=1,
                     seed=42,
                     verbose=TRUE)


Method fit()

Usage

RFTrainer$fit(X, y)

Arguments

X

data.frame containing train features

y

character, name of the target variable

Details

Trains the random forest model

Returns

NULL, trains and saves the model in memory

Examples

library(superml)
data("iris")
bst <- RFTrainer$new(n_estimators=10,
                     max_depth=4,
                     classification=1,
                     seed=42,
                     verbose=TRUE)
bst$fit(iris, 'Species')


Method predict()

Usage

RFTrainer$predict(df)

Arguments

df

data.frame containing test features

Details

Return predictions from random forest model

Returns

a vector containing predictions

Examples

library(superml)
data("iris")
bst <- RFTrainer$new(n_estimators=10,
                     max_depth=4,
                     classification=1,
                     seed=42,
                     verbose=TRUE)
bst$fit(iris, 'Species')
predictions <- bst$predict(iris)
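The examples on this page only cover classification. As a minimal sketch of the regression case, the snippet below sets `classification = 0` and fits on the built-in `mtcars` data; the choice of data set, target column, and error metric are illustrative assumptions rather than anything prescribed by the package.

```r
library(superml)

data("mtcars")

# classification = 0 requests a regression forest
reg <- RFTrainer$new(n_estimators = 50,
                     max_depth = 4,
                     classification = 0,
                     seed = 42)
reg$fit(mtcars, "mpg")

# Predict on the training rows, as the classification example does
preds <- reg$predict(mtcars)

# In-sample root mean squared error; optimistic because it reuses
# the rows the model was trained on
rmse <- sqrt(mean((mtcars$mpg - preds)^2))
print(rmse)
```

For an honest error estimate, hold out some rows before calling `fit()` and pass only those rows to `predict()`.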


Method get_importance()

Usage

RFTrainer$get_importance()

Details

Returns feature importance from the model

Returns

a data frame containing feature importances

Examples

library(superml)
data("iris")
bst <- RFTrainer$new(n_estimators=50,
                     max_depth=4,
                     classification=1,
                     seed=42,
                     verbose=TRUE)
bst$fit(iris, 'Species')
predictions <- bst$predict(iris)
bst$get_importance()
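The example above relies on the default importance mode. A short sketch of requesting permutation importance instead, using only the `importance` argument documented for `new()`, might look like the following; the layout and column names of the returned data frame are not documented here, so the code inspects the result rather than assuming them.

```r
library(superml)

data("iris")

# Request permutation importance instead of the default 'impurity'
rf <- RFTrainer$new(n_estimators = 50,
                    classification = 1,
                    seed = 42,
                    importance = "permutation")
rf$fit(iris, "Species")

# Documented to return a data frame; print it to see its layout
imp <- rf$get_importance()
print(imp)
```

Permutation importance is generally slower to compute than impurity importance because it re-evaluates predictions after shuffling each feature.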


Method clone()

The objects of this class are cloneable with this method.

Usage

RFTrainer$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Details

Trains a Random Forest model. A random forest is a meta estimator that fits a number of decision trees on various sub-samples of the dataset and uses averaging to improve predictive accuracy and control over-fitting. This implementation uses the ranger R package, which provides faster model training.
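Because the trainer is built on ranger, it can help to see a roughly comparable direct call. The sketch below uses ranger's own arguments (`num.trees`, `max.depth`, `importance`, `seed`); the correspondence to `n_estimators`, `max_depth`, `importance`, and `seed` is an assumption based on the parameter descriptions, not a statement of how superml forwards them internally.

```r
library(ranger)

data("iris")

# Assumed rough equivalent of
# RFTrainer$new(n_estimators = 10, max_depth = 4, seed = 42)
model <- ranger(dependent.variable.name = "Species",
                data = iris,
                num.trees = 10,          # cf. n_estimators
                max.depth = 4,           # cf. max_depth
                importance = "impurity", # cf. importance
                seed = 42)

# Predicted classes for the training rows
preds <- predict(model, data = iris)$predictions

# Named numeric vector of per-feature importance values
importance(model)
```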

Examples

## ------------------------------------------------
## Method `RFTrainer$new`
## ------------------------------------------------

library(superml)
data("iris")
bst <- RFTrainer$new(n_estimators=10,
                     max_depth=4,
                     classification=1,
                     seed=42,
                     verbose=TRUE)

## ------------------------------------------------
## Method `RFTrainer$fit`
## ------------------------------------------------

data("iris")
bst <- RFTrainer$new(n_estimators=10,
                     max_depth=4,
                     classification=1,
                     seed=42,
                     verbose=TRUE)
bst$fit(iris, 'Species')

## ------------------------------------------------
## Method `RFTrainer$predict`
## ------------------------------------------------

data("iris")
bst <- RFTrainer$new(n_estimators=10,
                     max_depth=4,
                     classification=1,
                     seed=42,
                     verbose=TRUE)
bst$fit(iris, 'Species')
predictions <- bst$predict(iris)

## ------------------------------------------------
## Method `RFTrainer$get_importance`
## ------------------------------------------------

data("iris")
bst <- RFTrainer$new(n_estimators=50,
                     max_depth=4,
                     classification=1,
                     seed=42,
                     verbose=TRUE)
bst$fit(iris, 'Species')
predictions <- bst$predict(iris)
bst$get_importance()
