
superml (version 0.4.0)

LMTrainer: Linear Models Trainer

Description

Trains linear models such as logistic, lasso or ridge regression models. It is built on the glmnet R package. This class provides fit, predict and cross-validation functions.

Usage

LMTrainer

Format

R6Class object.

Usage

For usage details see Methods, Arguments and Examples sections.

bst = LMTrainer$new(family, weights, alpha, lambda=100, standardize.response=FALSE)
bst$fit(X_train, "target")
prediction <- bst$predict(X_test)
bst$cv_model(X_train, "target", nfolds=4, parallel=TRUE)
cv_prediction <- bst$cv_predict(X_test)

Methods

$new()

Initialises an instance of the linear models trainer.

$fit()

Fits the model to the input training data (data frame or data table).

$predict()

Returns predictions by applying the trained model to the test data.
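For a linear model, a prediction is the intercept plus the test features multiplied by the fitted coefficients; for a logistic (binomial) model that linear predictor is passed through the logistic function to give a probability. The base-R sketch below illustrates this with hypothetical coefficients; it is not superml's code, and whether $predict() returns probabilities or class labels for the binomial family is not stated here.

```r
# Hypothetical fitted coefficients; in practice these come from the trained model.
intercept <- 0.5
beta <- c(x1 = 2, x2 = -1)

# Hypothetical test data with the same feature columns
X_test <- matrix(c(1, 0,
                   0, 1,
                   1, 1), ncol = 2, byrow = TRUE,
                 dimnames = list(NULL, c("x1", "x2")))

# Linear predictor: intercept + X %*% beta (the prediction for a gaussian model)
eta <- as.vector(intercept + X_test %*% beta)

# Binomial model: map the linear predictor to a probability with the logistic function
prob <- plogis(eta)
```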

$cv_model()

Uses k-fold cross-validation to find the best value of lambda. type.measure is the loss used for cross-validation.
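The idea behind choosing lambda by k-fold cross-validation can be sketched in base R. This is a minimal illustration under stated assumptions: it uses closed-form ridge regression instead of glmnet's solver, mean squared error as the loss (the role type.measure plays), simulated data and a hand-picked lambda grid. The helpers ridge_fit() and cv_ridge() are hypothetical and not part of superml.

```r
set.seed(42)

# Simulated data (hypothetical): 100 observations, 5 features, no intercept
n <- 100; p <- 5
X <- matrix(rnorm(n * p), n, p)
y <- as.vector(X %*% c(3, -2, 0, 0, 1) + rnorm(n))

# Closed-form ridge fit: solve (X'X + lambda * I) beta = X'y
ridge_fit <- function(X, y, lambda) {
  solve(crossprod(X) + lambda * diag(ncol(X)), crossprod(X, y))
}

# k-fold CV: average held-out mean squared error for each lambda,
# using the same random fold assignment for every lambda
cv_ridge <- function(X, y, lambdas, nfolds = 5) {
  folds <- sample(rep(seq_len(nfolds), length.out = nrow(X)))
  sapply(lambdas, function(lambda) {
    fold_mse <- sapply(seq_len(nfolds), function(k) {
      test <- folds == k
      beta <- ridge_fit(X[!test, , drop = FALSE], y[!test], lambda)
      mean((y[test] - X[test, , drop = FALSE] %*% beta)^2)
    })
    mean(fold_mse)
  })
}

lambdas <- c(0.01, 0.1, 1, 10, 100, 1000)
cv_error <- cv_ridge(X, y, lambdas, nfolds = 5)

# The "best" lambda is the one with the lowest cross-validated error
best_lambda <- lambdas[which.min(cv_error)]
```

Reusing the same folds for every candidate lambda keeps the comparison fair: each lambda is scored on identical train/test splits.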

$cv_predict()

Makes predictions on the test data using the best value of lambda found by $cv_model().

$get_importance()

Returns a matrix of feature coefficients, as generated by the lasso model.
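Lasso coefficients work as a feature-importance measure partly because the L1 penalty sets weak coefficients exactly to zero. In the special case of an orthonormal design and the objective (1/2)||y - X beta||^2 + t ||beta||_1, the lasso estimate is the soft-thresholded least-squares estimate. The sketch below shows that operator on hypothetical coefficients; it is an illustration of the principle, not how glmnet or superml computes its coefficients.

```r
# Soft-thresholding operator: S(z, t) = sign(z) * max(|z| - t, 0)
soft_threshold <- function(z, t) sign(z) * pmax(abs(z) - t, 0)

# Hypothetical least-squares coefficients for four features
ols_coef <- c(x1 = 0.05, x2 = 4.2, x3 = -0.01, x4 = -0.6)

# Orthonormal-design lasso: shrink every coefficient towards zero by t,
# and set it exactly to zero when |coefficient| <= t
lasso_coef <- soft_threshold(ols_coef, t = 0.1)
```

Here x1 and x3 drop out entirely, while x2 and x4 survive with their magnitudes reduced by 0.1.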

Arguments

family

Type of regression to perform; one of "gaussian", "binomial", "multinomial" or "mgaussian".

weights

Observation weights. These can be total counts if the responses are proportion matrices. The default is 1 for each observation.
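An observation weight scales that observation's contribution to the loss, so an integer weight of 2 acts like including the row twice. The sketch below checks this with base R's unpenalised lm() on hypothetical toy data; it illustrates the meaning of a weight and is not a claim about glmnet's internals, which also involve the penalty term.

```r
# Hypothetical toy data and integer observation weights
df <- data.frame(x = c(1, 2, 3, 4), y = c(1.1, 1.9, 3.2, 3.8))
w  <- c(1, 2, 1, 3)

# Weighted least-squares fit
fit_w <- lm(y ~ x, data = df, weights = w)

# Unweighted fit on data where each row is repeated w times
df_rep <- df[rep(seq_len(nrow(df)), times = w), ]
fit_rep <- lm(y ~ x, data = df_rep)
```

Both fits give the same coefficients, which is a quick way to reason about what a weight vector does.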

alpha

The elastic-net mixing parameter: alpha=1 gives the lasso penalty and alpha=0 the ridge penalty.
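glmnet documents its elastic-net penalty as lambda * ((1 - alpha)/2 * ||beta||_2^2 + alpha * ||beta||_1), so alpha interpolates between the two penalties. The small function below (the name enet_penalty() is hypothetical, not part of superml or glmnet) evaluates that formula for hypothetical coefficients.

```r
# Elastic-net penalty in the form documented by glmnet:
#   lambda * ((1 - alpha) / 2 * sum(beta^2) + alpha * sum(abs(beta)))
enet_penalty <- function(beta, alpha, lambda = 1) {
  lambda * ((1 - alpha) / 2 * sum(beta^2) + alpha * sum(abs(beta)))
}

beta <- c(1, -2, 0.5)  # hypothetical coefficients

lasso_pen <- enet_penalty(beta, alpha = 1)    # sum(|beta|) = 3.5
ridge_pen <- enet_penalty(beta, alpha = 0)    # sum(beta^2) / 2 = 2.625
mixed_pen <- enet_penalty(beta, alpha = 0.5)  # 0.25 * 5.25 + 0.5 * 3.5 = 3.0625
```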

nlambda

The number of lambda values; the default is 100.
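glmnet fits the model along a decreasing sequence of nlambda values spaced evenly on the log scale, from a largest value lambda_max down to lambda.min.ratio * lambda_max. glmnet derives lambda_max from the data and picks lambda.min.ratio from the data dimensions; the sketch below simply fixes both by hand (hypothetical values) to show the shape of the grid.

```r
# Hypothetical settings: in glmnet, lambda_max and the ratio are data-dependent
nlambda <- 100
lambda_max <- 10
lambda_min_ratio <- 1e-4

# Log-spaced, decreasing sequence of lambda values
lambdas <- exp(seq(log(lambda_max), log(lambda_max * lambda_min_ratio),
                   length.out = nlambda))
```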

standardize.response

Whether to normalise the dependent variable between 0 and 1; the default is FALSE.
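Normalising a variable to the range 0 to 1 is usually done with min-max scaling. The sketch below shows that transformation on hypothetical response values; the helper normalise01() is hypothetical and is not taken from superml's or glmnet's source, so the exact transformation the option applies may differ.

```r
# Min-max normalisation: rescale a numeric vector to the range [0, 1]
normalise01 <- function(y) (y - min(y)) / (max(y) - min(y))

y <- c(24.0, 21.6, 34.7, 33.4, 36.2)  # hypothetical response values
y01 <- normalise01(y)
```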

Examples

library(superml)

# Boston housing data from the UCI repository
LINK <- "http://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data"
housing <- read.table(LINK)
col_names <- c("CRIM","ZN","INDUS","CHAS","NOX","RM","AGE","DIS",
               "RAD","TAX","PTRATIO","B","LSTAT","MEDV")
names(housing) <- col_names

# lasso regression (alpha = 1) on the median house value MEDV
lf <- LMTrainer$new(family = 'gaussian', alpha = 1)
lf$fit(X = housing, y = 'MEDV')
predictions <- lf$predict(df = housing)

# cross-validated model
lf$cv_model(X = housing, y = 'MEDV', nfolds = 5, parallel = FALSE)
predictions <- lf$cv_predict(df = housing)
coefs <- lf$get_importance()
