Generate predicted labels or values for input data provided by input_fn().
# S3 method for tf_estimator
predict(
  object,
  input_fn,
  checkpoint_path = NULL,
  predict_keys = c("predictions", "classes", "class_ids", "logistic", "logits",
    "probabilities"),
  hooks = NULL,
  as_iterable = FALSE,
  simplify = TRUE,
  yield_single_examples = TRUE,
  ...
)
object: A TensorFlow estimator.
input_fn: An input function, typically generated by the input_fn() helper function.
checkpoint_path: The path to a specific model checkpoint to use for prediction. If NULL (the default), the latest checkpoint in model_dir is used.
predict_keys: The types of predictions to produce, as an R list. When this argument is not specified (the default), all possible predicted values are returned.
hooks: A list of R functions to be used as callbacks inside the training loop. If none are provided, hook_history_saver(every_n_step = 10) and hook_progress_bar() are attached by default to save the metrics history and display a progress bar.
as_iterable: Boolean; should a raw Python generator be returned? When FALSE (the default), the predicted values are consumed from the generator and returned as an R object.
simplify: Whether to simplify prediction results into a tibble, as opposed to a list. Defaults to TRUE.
yield_single_examples: (Available since TensorFlow v1.7) If FALSE, yields the whole batch as returned by the model_fn instead of decomposing the batch into individual elements. This is useful if model_fn returns some tensors whose first dimension is not equal to the batch size.
...: Optional arguments passed on to the estimator's predict() method.
Evaluated values of the predictions tensors.
ValueError: Could not find a trained model in model_dir.
ValueError: If the batch lengths of the predictions are not the same.
ValueError: If there is a conflict between predict_keys and predictions, for example if predict_keys is not NULL but EstimatorSpec.predictions is not a dict.
Other custom estimator methods: estimator_spec(), estimator(), evaluate.tf_estimator(), export_savedmodel.tf_estimator(), train.tf_estimator()
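A minimal end-to-end sketch of the predict() call described above, assuming the tfestimators package with a TensorFlow 1.x backend and the built-in mtcars data. The helper name mtcars_input_fn and the choice of disp and cyl as features are illustrative only, not part of the API. The model is trained first so that a checkpoint exists in model_dir; without one, predict() raises the "Could not find a trained model" error listed above. The default call (simplify = TRUE) is expected to return a tibble, while simplify = FALSE returns a list.

```r
# Sketch only: assumes tfestimators with a TensorFlow 1.x backend.
library(tfestimators)

# Hypothetical helper: builds an input function over a data frame,
# using disp and cyl as features and mpg as the response.
mtcars_input_fn <- function(data) {
  input_fn(data, features = c("disp", "cyl"), response = "mpg")
}

# Numeric feature columns matching the features above.
cols <- feature_columns(
  column_numeric("disp"),
  column_numeric("cyl")
)

# A canned linear regressor; checkpoints go to a temporary model_dir.
model <- linear_regressor(feature_columns = cols)

# Train so that a checkpoint exists for predict() to restore.
train(model, mtcars_input_fn(mtcars))

# Predict for the first three cars.
obs <- mtcars[1:3, ]

# Default (simplify = TRUE): results are simplified into a tibble.
preds <- predict(model, mtcars_input_fn(obs))
print(preds)

# simplify = FALSE: results are returned as an R list instead.
preds_list <- predict(model, mtcars_input_fn(obs), simplify = FALSE)
str(preds_list)
```

Passing checkpoint_path would restore a specific saved checkpoint instead of the latest one in model_dir.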