keras (version 2.9.0)

learning_rate_schedule_exponential_decay: A LearningRateSchedule that uses an exponential decay schedule

Description

A LearningRateSchedule that uses an exponential decay schedule

Usage

learning_rate_schedule_exponential_decay(
  initial_learning_rate,
  decay_steps,
  decay_rate,
  staircase = FALSE,
  ...,
  name = NULL
)

Arguments

initial_learning_rate

A scalar float32 or float64 Tensor or an R number. The initial learning rate.

decay_steps

A scalar int32 or int64 Tensor or an R number. Must be positive. See the decay computation in Details.

decay_rate

A scalar float32 or float64 Tensor or an R number. The decay rate.

staircase

Boolean. If TRUE, decay the learning rate at discrete intervals.

...

For backwards and forwards compatibility.

name

String. Optional name of the operation. Defaults to 'ExponentialDecay'.

Details

When training a model, it is often useful to lower the learning rate as the training progresses. This schedule applies an exponential decay function to an optimizer step, given a provided initial learning rate.

The schedule is a 1-arg callable that produces a decayed learning rate when passed the current optimizer step. This can be useful for changing the learning rate value across different invocations of optimizer functions. It is computed as:

decayed_learning_rate <- function(step)
  initial_learning_rate * decay_rate ^ (step / decay_steps)

If the argument staircase is TRUE, then step / decay_steps is an integer division (%/%) and the decayed learning rate follows a staircase function.
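The formula can be checked without keras. The following is a minimal base-R sketch, not the library's implementation: it re-implements the computation described above so the continuous and staircase variants can be compared. The default argument values and the chosen steps are assumptions picked for illustration.

```r
# Plain-R sketch of the decay formula described above.
# This is NOT the keras implementation; defaults are illustrative assumptions.
decayed_learning_rate <- function(step,
                                  initial_learning_rate = 0.1,
                                  decay_steps = 100000,
                                  decay_rate = 0.96,
                                  staircase = FALSE) {
  stopifnot(decay_steps > 0)
  # With staircase = TRUE the exponent is the integer quotient, so the rate
  # only changes once every `decay_steps` steps.
  exponent <- if (staircase) step %/% decay_steps else step / decay_steps
  initial_learning_rate * decay_rate ^ exponent
}

steps  <- c(0, 50000, 100000, 250000)
smooth <- decayed_learning_rate(steps)
stairs <- decayed_learning_rate(steps, staircase = TRUE)

# Side-by-side comparison of the two variants at the chosen steps.
print(data.frame(step = steps, smooth = smooth, staircase = stairs))
```

At step 50000 the staircase variant still returns the initial rate, while the continuous variant has already decayed partway towards 0.096. The two agree whenever step is an exact multiple of decay_steps.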

You can pass this schedule directly into an optimizer as the learning rate.

Example: When fitting a Keras model, decay every 100000 steps with a base of 0.96:

library(keras)

# Decay the learning rate by a factor of 0.96 every 100000 optimizer steps.
initial_learning_rate <- 0.1
lr_schedule <- learning_rate_schedule_exponential_decay(
    initial_learning_rate,
    decay_steps = 100000,
    decay_rate = 0.96,
    staircase = TRUE)

# `model`, `data`, and `labels` are assumed to be defined elsewhere.
model %>% compile(
    optimizer = optimizer_sgd(learning_rate = lr_schedule),
    loss = 'sparse_categorical_crossentropy',
    metrics = 'accuracy')

model %>% fit(data, labels, epochs = 5)
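Because the schedule is driven by optimizer steps (one per batch) rather than epochs, it can help to estimate how many steps a training run will take before choosing decay_steps. The base-R sketch below assumes a hypothetical dataset of 60000 samples and a batch size of 32; neither figure comes from the example above.

```r
# Hypothetical training setup (assumptions for illustration only).
n_samples  <- 60000
batch_size <- 32
epochs     <- 5

# One optimizer step is taken per batch.
steps_per_epoch <- ceiling(n_samples / batch_size)
total_steps     <- steps_per_epoch * epochs

# Staircase decay with the settings from the example above.
decay_steps <- 100000
final_lr    <- 0.1 * 0.96 ^ (total_steps %/% decay_steps)

cat("total optimizer steps:", total_steps, "\n")
cat("learning rate at end of training:", final_lr, "\n")
```

Under these assumptions the run finishes well short of 100000 steps, so with staircase = TRUE the learning rate would never leave its initial value. A smaller decay_steps would be needed for the decay to take effect in a run of that size.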

See Also