sparklyr (version 1.7.1)

ml_evaluator: Spark ML - Evaluators

Description

A set of functions to calculate performance metrics for prediction models. See also the Spark ML documentation: https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.ml.evaluation.package

Usage

ml_binary_classification_evaluator(
  x,
  label_col = "label",
  raw_prediction_col = "rawPrediction",
  metric_name = "areaUnderROC",
  uid = random_string("binary_classification_evaluator_"),
  ...
)

ml_binary_classification_eval(
  x,
  label_col = "label",
  prediction_col = "prediction",
  metric_name = "areaUnderROC"
)

ml_multiclass_classification_evaluator(
  x,
  label_col = "label",
  prediction_col = "prediction",
  metric_name = "f1",
  uid = random_string("multiclass_classification_evaluator_"),
  ...
)

ml_classification_eval(
  x,
  label_col = "label",
  prediction_col = "prediction",
  metric_name = "f1"
)

ml_regression_evaluator(
  x,
  label_col = "label",
  prediction_col = "prediction",
  metric_name = "rmse",
  uid = random_string("regression_evaluator_"),
  ...
)

Arguments

x

A spark_connection object or a tbl_spark containing label and prediction columns. The latter should be the output of a prediction function such as ml_predict (used in the examples) or sdf_predict.

label_col

Name of the column (a string) that contains the true labels or values.

raw_prediction_col

Raw prediction (a.k.a. confidence) column name.

metric_name

The performance metric. See details.

uid

A character string used to uniquely identify the ML estimator.

...

Optional arguments; currently unused.

prediction_col

Name of the column that contains the predicted label or value, NOT the scored probability. The column should be of type Double.

Value

The calculated performance metric.

Details

The following metrics are supported:

  • Binary Classification: areaUnderROC (default) or areaUnderPR (not available in Spark 2.X).

  • Multiclass Classification: f1 (default), precision, recall, weightedPrecision, weightedRecall or accuracy; for Spark 2.X: f1 (default), weightedPrecision, weightedRecall or accuracy.

  • Regression: rmse (root mean squared error, default), mse (mean squared error), r2, or mae (mean absolute error).
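
To make the regression metric names concrete, here is a minimal base R sketch of their textbook definitions. It does not call Spark or sparklyr; the vectors `actual` and `predicted` are hypothetical toy data, and the sketch assumes the usual formulas rather than reproducing Spark's implementation.

```r
# Hypothetical toy data (not from mtcars): true values and model predictions
actual    <- c(3.0, 5.0, 2.5, 7.0)
predicted <- c(2.5, 5.0, 4.0, 8.0)

# Prediction errors
errors <- actual - predicted

# mse: mean of the squared errors
mse <- mean(errors^2)

# rmse: square root of mse, on the same scale as the label
rmse <- sqrt(mse)

# mae: mean of the absolute errors
mae <- mean(abs(errors))

# r2: 1 minus the ratio of residual to total sum of squares
r2 <- 1 - sum(errors^2) / sum((actual - mean(actual))^2)

metrics <- c(rmse = rmse, mse = mse, mae = mae, r2 = r2)
print(metrics)
```

Lower is better for rmse, mse, and mae, while r2 is better the closer it is to 1.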

ml_binary_classification_eval() is an alias for ml_binary_classification_evaluator() for backwards compatibility.

ml_classification_eval() is an alias for ml_multiclass_classification_evaluator() for backwards compatibility.
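
The multiclass metrics listed above can be illustrated in the same way. The base R sketch below assumes that f1, weightedPrecision, and weightedRecall are per-class values averaged with weights equal to each class's share of the true labels; that weighting is an assumption here and should be checked against the Spark documentation. The `label` and `prediction` vectors are hypothetical, and no Spark connection is used.

```r
# Hypothetical true labels and predicted labels for three classes
label      <- c(0, 0, 0, 1, 1, 2, 2, 2, 2, 2)
prediction <- c(0, 0, 1, 1, 1, 2, 2, 0, 2, 2)

classes <- sort(unique(label))
n <- length(label)

# Per-class precision, recall, F1, and class weight (share of true labels)
per_class <- sapply(classes, function(k) {
  tp          <- sum(prediction == k & label == k)
  predicted_k <- sum(prediction == k)
  actual_k    <- sum(label == k)
  precision   <- if (predicted_k == 0) 0 else tp / predicted_k
  recall      <- tp / actual_k
  f1 <- if (precision + recall == 0) 0 else
    2 * precision * recall / (precision + recall)
  c(precision = precision, recall = recall, f1 = f1, weight = actual_k / n)
})

# accuracy: fraction of rows where the prediction matches the label
accuracy <- mean(prediction == label)

# Weighted averages of the per-class values (assumed weighting, see above)
weighted <- drop(per_class[c("precision", "recall", "f1"), ] %*% per_class["weight", ])
names(weighted) <- c("weightedPrecision", "weightedRecall", "f1")

print(c(weighted, accuracy = accuracy))
```

Under this weighting, weightedRecall always equals accuracy, which is a useful sanity check when comparing metrics.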

Examples

# NOT RUN {
library(sparklyr)

sc <- spark_connect(master = "local")
mtcars_tbl <- sdf_copy_to(sc, mtcars, name = "mtcars_tbl", overwrite = TRUE)

partitions <- mtcars_tbl %>%
  sdf_random_split(training = 0.7, test = 0.3, seed = 1111)

mtcars_training <- partitions$training
mtcars_test <- partitions$test

# for multiclass classification
rf_model <- mtcars_training %>%
  ml_random_forest(cyl ~ ., type = "classification")

pred <- ml_predict(rf_model, mtcars_test)

ml_multiclass_classification_evaluator(pred)

# for regression
rf_model <- mtcars_training %>%
  ml_random_forest(cyl ~ ., type = "regression")

pred <- ml_predict(rf_model, mtcars_test)

ml_regression_evaluator(pred, label_col = "cyl")

# for binary classification
rf_model <- mtcars_training %>%
  ml_random_forest(am ~ gear + carb, type = "classification")

pred <- ml_predict(rf_model, mtcars_test)

ml_binary_classification_evaluator(pred)
# }