if (FALSE) {
  # Example: tune a random forest pipeline on `iris` with k-fold
  # cross-validation. Wrapped in `if (FALSE)` so it never runs
  # automatically (presumably because it needs a local Spark install).
  library(sparklyr)

  # Connect to a local Spark instance and copy `iris` into it.
  sc <- spark_connect(master = "local")
  iris_tbl <- sdf_copy_to(sc, iris, name = "iris_tbl", overwrite = TRUE)

  # Create a pipeline: R-formula featurization (Species ~ all other
  # columns) followed by a random forest classifier.
  pipeline <- ml_pipeline(sc) %>%
    ft_r_formula(Species ~ .) %>%
    ml_random_forest_classifier()

  # Specify hyperparameter grid. The `random_forest` entry targets the
  # random forest stage of the pipeline; 2 x 2 x 2 = 8 combinations.
  grid <- list(
    random_forest = list(
      num_trees = c(5, 10),
      max_depth = c(5, 10),
      impurity = c("entropy", "gini")
    )
  )

  # Create the cross validator object: each of the 8 combinations is
  # evaluated over 3 folds, using up to 4 threads for fitting.
  cv <- ml_cross_validator(
    sc,
    estimator = pipeline,
    estimator_param_maps = grid,
    evaluator = ml_multiclass_classification_evaluator(sc),
    num_folds = 3,
    parallelism = 4
  )

  # Train the models
  cv_model <- ml_fit(cv, iris_tbl)

  # Print the metrics explicitly: they are no longer the last expression
  # of the block, so they would not be autoprinted.
  metrics <- ml_validation_metrics(cv_model)
  print(metrics)

  # Release the Spark connection (and its local Spark process).
  spark_disconnect(sc)
}
# Run the code above in your browser using DataLab