Learn R Programming

keras (version 2.9.0)

freeze_weights: Freeze and unfreeze weights

Description

Freeze weights in a model or layer so that they are no longer trainable.

Usage

freeze_weights(object, from = NULL, to = NULL, which = NULL)

unfreeze_weights(object, from = NULL, to = NULL, which = NULL)

Arguments

object

Keras model or layer object

from

Layer instance, layer name, or layer index within model

to

Layer instance, layer name, or layer index within model

which

layer names, integer positions, layers, logical vector (of length(object$layers)), or a function returning a logical vector.

Examples

Run this code
if (FALSE) {
conv_base <- application_vgg16(
  weights = "imagenet",
  include_top = FALSE,
  input_shape = c(150, 150, 3)
)

# freeze it's weights
freeze_weights(conv_base)

conv_base

# create a composite model that includes the base + more layers
model <- keras_model_sequential() %>%
  conv_base() %>%
  layer_flatten() %>%
  layer_dense(units = 256, activation = "relu") %>%
  layer_dense(units = 1, activation = "sigmoid")

# compile
model %>% compile(
  loss = "binary_crossentropy",
  optimizer = optimizer_rmsprop(lr = 2e-5),
  metrics = c("accuracy")
)

model
print(model, expand_nested = TRUE)



# unfreeze weights from "block5_conv1" on
unfreeze_weights(conv_base, from = "block5_conv1")

# compile again since we froze or unfroze weights
model %>% compile(
  loss = "binary_crossentropy",
  optimizer = optimizer_rmsprop(lr = 2e-5),
  metrics = c("accuracy")
)

conv_base
print(model, expand_nested = TRUE)

# freeze only the last 5 layers
freeze_weights(conv_base, from = -5)
conv_base
# equivalently, also freeze only the last 5 layers
unfreeze_weights(conv_base, to = -6)
conv_base

# Freeze only layers of a certain type, e.g, BatchNorm layers
batch_norm_layer_class_name <- class(layer_batch_normalization())[1]
is_batch_norm_layer <- function(x) inherits(x, batch_norm_layer_class_name)

model <- application_efficientnet_b0()
freeze_weights(model, which = is_batch_norm_layer)
model
# equivalent to:
for(layer in model$layers) {
  if(is_batch_norm_layer(layer))
    layer$trainable <- FALSE
  else
    layer$trainable <- TRUE
}
}

Run the code above in your browser using DataLab