Train an estimator on a set of input data provided by the input_fn().
# S3 method for tf_estimator
train(
  object,
  input_fn,
  steps = NULL,
  hooks = NULL,
  max_steps = NULL,
  saving_listeners = NULL,
  ...
)
object: A TensorFlow estimator.
input_fn: An input function, typically generated by the input_fn() helper function.
steps: The number of steps for which the model should be trained in this particular train() invocation. If NULL (the default), this function will either train forever or until the supplied input_fn() has provided all available data.
hooks: A list of R functions to be used as callbacks inside the training loop. Unless they are already provided, hook_history_saver(every_n_step = 10) and hook_progress_bar() are attached by default to save the metrics history and display a progress bar.
max_steps: The total number of steps for which the model should be trained. If set, steps must be NULL. If the estimator has already been trained for a total of max_steps steps, no further training is performed.
saving_listeners: (Available since TensorFlow v1.4) A list of CheckpointSaverListener objects used as callbacks that run immediately before or after each checkpoint save.
...: Optional arguments, passed on to the estimator's train() method.
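The interaction between steps and max_steps can be illustrated with a small base-R sketch. steps_to_run() is a hypothetical helper, not part of tfestimators; it only models the rules stated for these two arguments, assuming global_step is the number of steps the estimator has already been trained. The value it returns is an upper bound, since training also stops early once input_fn() runs out of data.

```r
# Hypothetical helper (NOT part of tfestimators): upper bound on the number of
# training steps a single train() call would run, given the steps already taken.
steps_to_run <- function(global_step, steps = NULL, max_steps = NULL) {
  # Only one of `steps` / `max_steps` may be supplied.
  if (!is.null(steps) && !is.null(max_steps))
    stop("`steps` and `max_steps` cannot both be set")
  # `steps` is relative: train this many more steps in this call.
  if (!is.null(steps))
    return(steps)
  # `max_steps` is absolute: stop once the total reaches max_steps.
  if (!is.null(max_steps))
    return(max(0, max_steps - global_step))
  # Neither set: train until input_fn() has provided all available data.
  Inf
}

steps_to_run(global_step = 0,   steps = 100)      # 100
steps_to_run(global_step = 100, steps = 100)      # 100 more steps
steps_to_run(global_step = 100, max_steps = 100)  # 0, already trained enough
steps_to_run(global_step = 40,  max_steps = 100)  # 60
```

Calling train() repeatedly with steps = 100 keeps adding 100 steps each time, whereas repeated calls with max_steps = 100 stop adding steps once the total reaches 100.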
A data.frame of the training loss history.
Other custom estimator methods: estimator_spec(), estimator(), evaluate.tf_estimator(), export_savedmodel.tf_estimator(), predict.tf_estimator()