
iml (version 0.6.0)

Predictor: Predictor object

Description

A Predictor object holds any machine learning model (mlr, caret, randomForest, ...) and the data to be used for analysing the model. The interpretation methods in the iml package need the machine learning model to be wrapped in a Predictor object.

Format

R6Class object.

Usage

model = Predictor$new(model = NULL, data, y = NULL, class = NULL,
  predict.fun = NULL, type = NULL)

model$predict(newdata)

Arguments

model:

(any) The machine learning model. Models from mlr and caret are recommended. Other machine learning models with an S3 predict method work as well, but less robustly (e.g. randomForest).

data:

(data.frame) The data to be used for analysing the prediction model.

y:

((`character(1)`) | numeric | factor) The target vector or (preferably) the name of the target column in the data argument.
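Because y accepts either a column name or the target vector itself, both forms have to end up as the same vector. The following is a minimal sketch of that resolution step in base R; the helper resolve_target is hypothetical and not part of iml.

```r
# Hypothetical helper (not part of iml): turn the `y` argument into a target
# vector, accepting either the name of a column in `data` or the vector itself.
resolve_target <- function(y, data) {
  if (is.character(y) && length(y) == 1 && y %in% names(data)) {
    # y names a column of `data`: extract that column
    data[[y]]
  } else {
    # otherwise y is taken to be the target vector itself
    y
  }
}

# Both forms yield the same target vector for iris
by_name   <- resolve_target("Species", iris)
by_vector <- resolve_target(iris$Species, iris)
identical(by_name, by_vector)
```

Passing the column name is the preferred form because the target then stays tied to the rows of data.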

class:

(`character(1)`) The class column to be returned in case of multiclass output.
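With multiclass output there is one probability column per class, and class narrows that down to a single column. The sketch below illustrates the idea on a small, made-up probability table; select_class is a hypothetical helper, not iml's internal code.

```r
# Hypothetical illustration (not iml's internals): multiclass output holds one
# probability column per class; `class` picks out a single column.
select_class <- function(prediction, class = NULL) {
  if (is.null(class)) return(prediction)  # keep all class columns
  prediction[, class, drop = FALSE]       # keep only the requested class
}

# Made-up probabilities for two observations
probs <- data.frame(
  setosa     = c(0.90, 0.05),
  versicolor = c(0.08, 0.80),
  virginica  = c(0.02, 0.15)
)
select_class(probs, "setosa")
```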

predict.fun:

(function) The function used to predict newdata. Only needed if the model is not from the mlr or caret package.
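As an example of a model outside mlr and caret, the sketch below fits a base R logistic regression and writes a matching prediction function. It assumes that iml calls predict.fun as predict.fun(model, newdata); the name predict_am and the column names automatic and manual are hypothetical choices for this illustration.

```r
# A base R logistic regression (stats::glm): neither an mlr nor a caret model.
# In mtcars, am is 0 for automatic and 1 for manual transmission.
fit <- glm(am ~ wt, data = mtcars, family = binomial())

# Sketch of a predict.fun. Assumption: it is called as predict.fun(model, newdata).
# It returns one probability column per class, as expected for classification.
predict_am <- function(model, newdata) {
  p <- predict(model, newdata = newdata, type = "response")  # P(am == 1)
  data.frame(automatic = 1 - p, manual = p)
}

head(predict_am(fit, mtcars), 3)
```

Under that assumption, the function would be supplied as predict.fun = predict_am when creating the Predictor for fit.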

type:

(`character(1)`) This argument is passed to the prediction function of the model. The typical use case is type = "prob" for classification models, for example with caret models or most S3 predict methods. If both predict.fun and type are used, type is passed as an argument to predict.fun.
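The accepted values of type depend on the predict method being called: randomForest and caret use "prob", whereas stats::predict.glm uses "response" for probabilities and defaults to "link" (log-odds). The sketch below shows one way of forwarding type only when it is set; call_predict is a hypothetical helper, not iml's implementation.

```r
# Hypothetical sketch of forwarding `type`: it is added to the call only when
# non-NULL, otherwise the predict method's own default applies.
call_predict <- function(predict.fun, model, newdata, type = NULL) {
  args <- list(model, newdata = newdata)
  if (!is.null(type)) args$type <- type
  do.call(predict.fun, args)
}

fit <- glm(am ~ wt, data = mtcars, family = binomial())

# Default for predict.glm is type = "link" (log-odds)
link <- call_predict(predict, fit, mtcars[1:3, ])
# type = "response" gives probabilities instead
prob <- call_predict(predict, fit, mtcars[1:3, ], type = "response")
```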

batch.size:

(`numeric(1)`) The maximum number of rows passed to the model for prediction at once. Currently only respected by FeatureImp, Partial and Interaction.
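Batching limits how many rows reach the model in one call while leaving the combined result unchanged. A minimal sketch in base R follows; predict_in_batches and pred_fun are hypothetical names, and this is not how iml implements batch.size internally.

```r
# Hypothetical sketch of batched prediction: pass at most `batch.size` rows to
# the prediction function at a time and stack the results.
predict_in_batches <- function(predict.fun, newdata, batch.size) {
  n <- nrow(newdata)
  if (n == 0) return(predict.fun(newdata))  # nothing to split
  starts <- seq(1, n, by = batch.size)
  pieces <- lapply(starts, function(s) {
    rows <- s:min(s + batch.size - 1, n)
    predict.fun(newdata[rows, , drop = FALSE])
  })
  do.call(rbind, pieces)
}

fit <- lm(mpg ~ wt + hp, data = mtcars)
pred_fun <- function(newdata) data.frame(pred = predict(fit, newdata = newdata))

# 32 rows in batches of 10: rows 1-10, 11-20, 21-30, 31-32
batched <- predict_in_batches(pred_fun, mtcars, batch.size = 10)
all.equal(batched, pred_fun(mtcars))
```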

Fields

class:

(`character(1)`) The class column to be returned.

data:

(data.frame) data object with the data for the model interpretation.

prediction.colnames:

(character) The column names of the predictions.

task:

(`character(1)`) The inferred prediction task: "classification" or "regression".

Methods

predict(newdata)

method to predict new data with the machine learning model.

clone()

[internal] method to clone the R6 object.

initialize()

[internal] method to initialize the R6 object.

Details

A Predictor object is a container for the prediction model and the data. This ensures that the machine learning model can be analysed robustly.
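The container idea can be illustrated with a much simplified stand-in: a list that keeps the model, the data and a prediction function together, and exposes predict(newdata) returning a data.frame. This is a hypothetical sketch (make_predictor and the column name prediction are made up), not iml's R6 implementation.

```r
# Hypothetical, much simplified container (not iml's R6 class): it stores the
# model, the data and a prediction function, and exposes predict(newdata).
make_predictor <- function(model, data, predict.fun) {
  list(
    model = model,
    data = data,
    predict = function(newdata) {
      out <- predict.fun(model, newdata)
      # wrap a plain vector into a one-column data.frame
      if (is.null(dim(out))) out <- data.frame(prediction = out)
      as.data.frame(out)
    }
  )
}

fit <- lm(Sepal.Length ~ Petal.Length, data = iris)
pred <- make_predictor(fit, iris, function(m, nd) predict(m, newdata = nd))
pred$predict(iris[1:5, ])
```

Returning a data.frame for every model type gives the interpretation methods one uniform output format to work with.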

Note: For classification, the model should return one column per class, containing that class's probability.
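The expected classification output can be checked mechanically: one numeric column per class, with each row summing to 1. The sketch below defines a hypothetical check, is_prob_output, and shows how hard class labels could be turned into that format as 0/1 probabilities; neither is part of iml.

```r
# Hypothetical check (not part of iml): one numeric column per class and
# each row summing to 1.
is_prob_output <- function(pred, classes, tol = 1e-8) {
  is.data.frame(pred) &&
    setequal(names(pred), classes) &&
    all(vapply(pred, is.numeric, logical(1))) &&
    all(abs(rowSums(pred) - 1) < tol)
}

# Hard labels turned into one column per class with 0/1 "probabilities"
labels <- factor(c("setosa", "virginica", "setosa"),
                 levels = levels(iris$Species))
one_hot <- as.data.frame(sapply(levels(labels),
                                function(l) as.numeric(labels == l)))
is_prob_output(one_hot, levels(iris$Species))
```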

Examples

library(iml)

if (require("mlr")) {
  task = makeClassifTask(data = iris, target = "Species")
  learner = makeLearner("classif.rpart", minsplit = 7, predict.type = "prob")
  mod.mlr = train(learner, task)
  mod = Predictor$new(mod.mlr, data = iris)
  print(mod$predict(iris[1:5, ]))

  # Return only the probability of class "setosa"
  mod = Predictor$new(mod.mlr, data = iris, class = "setosa")
  print(mod$predict(iris[1:5, ]))
}

if (require("randomForest")) {
  rf = randomForest(Species ~ ., data = iris, ntree = 20)

  mod = Predictor$new(rf, data = iris, type = "prob")
  print(mod$predict(iris[50:55, ]))

  # Feature importance needs the target vector, which must be supplied:
  mod = Predictor$new(rf, data = iris, y = "Species", type = "prob")
}
