Learn R Programming

keras3 (version 1.3.0)

freeze_weights: Freeze and unfreeze weights

Description

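Freeze weights in a model or layer so that they are no longer trainable, or unfreeze them so that they can be trained again. Selecting layers with `from`, `to`, or `which` also affects the layers outside the selection. As the examples below show, `freeze_weights(conv_base, from = -5)` freezes the last five layers and makes all earlier layers trainable.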

Usage

freeze_weights(object, from = NULL, to = NULL, which = NULL)

unfreeze_weights(object, from = NULL, to = NULL, which = NULL)

Value

The input `object` is returned invisibly. Note that `object` is modified in place; the return value is provided only to make usage with the pipe convenient.

Arguments

object

Keras model or layer object

from

Layer instance, layer name, or layer index within the model (negative indices count from the end, e.g. `-5`).

to

Layer instance, layer name, or layer index within the model (negative indices count from the end, e.g. `-6`).

which

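Layer names, integer positions, layers, a logical vector (of length `length(object$layers)`), or a predicate function applied to each layer that returns `TRUE` for the layers to select (see the `is_batch_norm_layer()` example below).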

Examples

# instantiate a VGG16 model
conv_base <- application_vgg16(
  weights = "imagenet",
  include_top = FALSE,
  input_shape = c(150, 150, 3)
)

# freeze its weights (no `from`/`to`/`which`: every layer is frozen)
freeze_weights(conv_base)

# Note the "Trainable" column
conv_base

## Model: "vgg16"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type)                | Output Shape          |    Param # | Trai… |
## +=============================+=======================+============+=======+
## | input_layer (InputLayer)    | (None, 150, 150, 3)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv1 (Conv2D)       | (None, 150, 150, 64)  |      1,792 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv2 (Conv2D)       | (None, 150, 150, 64)  |     36,928 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_pool (MaxPooling2D)  | (None, 75, 75, 64)    |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv1 (Conv2D)       | (None, 75, 75, 128)   |     73,856 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv2 (Conv2D)       | (None, 75, 75, 128)   |    147,584 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_pool (MaxPooling2D)  | (None, 37, 37, 128)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv1 (Conv2D)       | (None, 37, 37, 256)   |    295,168 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv2 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv3 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_pool (MaxPooling2D)  | (None, 18, 18, 256)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv1 (Conv2D)       | (None, 18, 18, 512)   |  1,180,160 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv2 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv3 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_pool (MaxPooling2D)  | (None, 9, 9, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv1 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv2 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv3 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_pool (MaxPooling2D)  | (None, 4, 4, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
##  Total params: 14,714,688 (56.13 MB)
##  Trainable params: 0 (0.00 B)
##  Non-trainable params: 14,714,688 (56.13 MB)


# create a composite model that includes the base + more layers
model <- keras_model_sequential(input_batch_shape = shape(conv_base$input)) |>
  conv_base() |>
  layer_flatten() |>
  layer_dense(units = 256, activation = "relu") |>
  layer_dense(units = 1, activation = "sigmoid")

# compile
model |> compile(
  loss = "binary_crossentropy",
  optimizer = optimizer_rmsprop(learning_rate = 2e-5),
  metrics = c("accuracy")
)

model

## Model: "sequential"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type)                | Output Shape          |    Param # | Trai… |
## +=============================+=======================+============+=======+
## | vgg16 (Functional)          | (None, 4, 4, 512)     | 14,714,688 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | flatten (Flatten)           | (None, 8192)          |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | dense (Dense)               | (None, 256)           |  2,097,408 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | dense_1 (Dense)             | (None, 1)             |        257 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
##  Total params: 16,812,353 (64.13 MB)
##  Trainable params: 2,097,665 (8.00 MB)
##  Non-trainable params: 14,714,688 (56.13 MB)

# show the summary with the layers of the nested vgg16 model expanded
model |> print(expand_nested = TRUE)

## Model: "sequential"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type)                | Output Shape          |    Param # | Trai… |
## +=============================+=======================+============+=======+
## | vgg16 (Functional)          | (None, 4, 4, 512)     | 14,714,688 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > input_layer            | (None, 150, 150, 3)   |          0 |   -   |
## | (InputLayer)                |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block1_conv1 (Conv2D)  | (None, 150, 150, 64)  |      1,792 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block1_conv2 (Conv2D)  | (None, 150, 150, 64)  |     36,928 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block1_pool            | (None, 75, 75, 64)    |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block2_conv1 (Conv2D)  | (None, 75, 75, 128)   |     73,856 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block2_conv2 (Conv2D)  | (None, 75, 75, 128)   |    147,584 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block2_pool            | (None, 37, 37, 128)   |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_conv1 (Conv2D)  | (None, 37, 37, 256)   |    295,168 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_conv2 (Conv2D)  | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_conv3 (Conv2D)  | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_pool            | (None, 18, 18, 256)   |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_conv1 (Conv2D)  | (None, 18, 18, 512)   |  1,180,160 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_conv2 (Conv2D)  | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_conv3 (Conv2D)  | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_pool            | (None, 9, 9, 512)     |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_conv1 (Conv2D)  | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_conv2 (Conv2D)  | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_conv3 (Conv2D)  | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_pool            | (None, 4, 4, 512)     |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## | flatten (Flatten)           | (None, 8192)          |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | dense (Dense)               | (None, 256)           |  2,097,408 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | dense_1 (Dense)             | (None, 1)             |        257 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
##  Total params: 16,812,353 (64.13 MB)
##  Trainable params: 2,097,665 (8.00 MB)
##  Non-trainable params: 14,714,688 (56.13 MB)

# unfreeze weights from "block5_conv1" on
unfreeze_weights(conv_base, from = "block5_conv1")

# compile again since we froze or unfroze weights
model |> compile(
  loss = "binary_crossentropy",
  optimizer = optimizer_rmsprop(learning_rate = 2e-5),
  metrics = c("accuracy")
)

conv_base

## Model: "vgg16"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type)                | Output Shape          |    Param # | Trai… |
## +=============================+=======================+============+=======+
## | input_layer (InputLayer)    | (None, 150, 150, 3)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv1 (Conv2D)       | (None, 150, 150, 64)  |      1,792 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv2 (Conv2D)       | (None, 150, 150, 64)  |     36,928 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_pool (MaxPooling2D)  | (None, 75, 75, 64)    |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv1 (Conv2D)       | (None, 75, 75, 128)   |     73,856 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv2 (Conv2D)       | (None, 75, 75, 128)   |    147,584 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_pool (MaxPooling2D)  | (None, 37, 37, 128)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv1 (Conv2D)       | (None, 37, 37, 256)   |    295,168 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv2 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv3 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_pool (MaxPooling2D)  | (None, 18, 18, 256)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv1 (Conv2D)       | (None, 18, 18, 512)   |  1,180,160 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv2 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv3 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_pool (MaxPooling2D)  | (None, 9, 9, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv1 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv2 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv3 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_pool (MaxPooling2D)  | (None, 4, 4, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
##  Total params: 14,714,688 (56.13 MB)
##  Trainable params: 7,079,424 (27.01 MB)
##  Non-trainable params: 7,635,264 (29.13 MB)

# expanded view again: the nested block5_* conv layers are now trainable
model |> print(expand_nested = TRUE)

## Model: "sequential"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type)                | Output Shape          |    Param # | Trai… |
## +=============================+=======================+============+=======+
## | vgg16 (Functional)          | (None, 4, 4, 512)     | 14,714,688 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## |    > input_layer            | (None, 150, 150, 3)   |          0 |   -   |
## | (InputLayer)                |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block1_conv1 (Conv2D)  | (None, 150, 150, 64)  |      1,792 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block1_conv2 (Conv2D)  | (None, 150, 150, 64)  |     36,928 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block1_pool            | (None, 75, 75, 64)    |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block2_conv1 (Conv2D)  | (None, 75, 75, 128)   |     73,856 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block2_conv2 (Conv2D)  | (None, 75, 75, 128)   |    147,584 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block2_pool            | (None, 37, 37, 128)   |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_conv1 (Conv2D)  | (None, 37, 37, 256)   |    295,168 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_conv2 (Conv2D)  | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_conv3 (Conv2D)  | (None, 37, 37, 256)   |    590,080 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block3_pool            | (None, 18, 18, 256)   |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_conv1 (Conv2D)  | (None, 18, 18, 512)   |  1,180,160 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_conv2 (Conv2D)  | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_conv3 (Conv2D)  | (None, 18, 18, 512)   |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block4_pool            | (None, 9, 9, 512)     |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_conv1 (Conv2D)  | (None, 9, 9, 512)     |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_conv2 (Conv2D)  | (None, 9, 9, 512)     |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_conv3 (Conv2D)  | (None, 9, 9, 512)     |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## |    > block5_pool            | (None, 4, 4, 512)     |          0 |   -   |
## | (MaxPooling2D)              |                       |            |       |
## +-----------------------------+-----------------------+------------+-------+
## | flatten (Flatten)           | (None, 8192)          |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | dense (Dense)               | (None, 256)           |  2,097,408 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | dense_1 (Dense)             | (None, 1)             |        257 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
##  Total params: 16,812,353 (64.13 MB)
##  Trainable params: 9,177,089 (35.01 MB)
##  Non-trainable params: 7,635,264 (29.13 MB)


# freeze only the last 5 layers; every earlier layer is made trainable
conv_base |> freeze_weights(from = -5)
conv_base

## Model: "vgg16"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type)                | Output Shape          |    Param # | Trai… |
## +=============================+=======================+============+=======+
## | input_layer (InputLayer)    | (None, 150, 150, 3)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv1 (Conv2D)       | (None, 150, 150, 64)  |      1,792 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv2 (Conv2D)       | (None, 150, 150, 64)  |     36,928 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_pool (MaxPooling2D)  | (None, 75, 75, 64)    |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv1 (Conv2D)       | (None, 75, 75, 128)   |     73,856 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv2 (Conv2D)       | (None, 75, 75, 128)   |    147,584 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_pool (MaxPooling2D)  | (None, 37, 37, 128)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv1 (Conv2D)       | (None, 37, 37, 256)   |    295,168 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv2 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv3 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_pool (MaxPooling2D)  | (None, 18, 18, 256)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv1 (Conv2D)       | (None, 18, 18, 512)   |  1,180,160 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv2 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv3 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_pool (MaxPooling2D)  | (None, 9, 9, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv1 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv2 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv3 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_pool (MaxPooling2D)  | (None, 4, 4, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
##  Total params: 14,714,688 (56.13 MB)
##  Trainable params: 7,635,264 (29.13 MB)
##  Non-trainable params: 7,079,424 (27.01 MB)

# freeze only the last 5 layers, a different way:
# unfreeze everything up to and including the 6th layer from the end
conv_base |> unfreeze_weights(to = -6)
conv_base

## Model: "vgg16"
## +-----------------------------+-----------------------+------------+-------+
## | Layer (type)                | Output Shape          |    Param # | Trai… |
## +=============================+=======================+============+=======+
## | input_layer (InputLayer)    | (None, 150, 150, 3)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv1 (Conv2D)       | (None, 150, 150, 64)  |      1,792 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_conv2 (Conv2D)       | (None, 150, 150, 64)  |     36,928 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block1_pool (MaxPooling2D)  | (None, 75, 75, 64)    |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv1 (Conv2D)       | (None, 75, 75, 128)   |     73,856 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_conv2 (Conv2D)       | (None, 75, 75, 128)   |    147,584 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block2_pool (MaxPooling2D)  | (None, 37, 37, 128)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv1 (Conv2D)       | (None, 37, 37, 256)   |    295,168 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv2 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_conv3 (Conv2D)       | (None, 37, 37, 256)   |    590,080 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block3_pool (MaxPooling2D)  | (None, 18, 18, 256)   |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv1 (Conv2D)       | (None, 18, 18, 512)   |  1,180,160 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv2 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_conv3 (Conv2D)       | (None, 18, 18, 512)   |  2,359,808 |   Y   |
## +-----------------------------+-----------------------+------------+-------+
## | block4_pool (MaxPooling2D)  | (None, 9, 9, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv1 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv2 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_conv3 (Conv2D)       | (None, 9, 9, 512)     |  2,359,808 |   N   |
## +-----------------------------+-----------------------+------------+-------+
## | block5_pool (MaxPooling2D)  | (None, 4, 4, 512)     |          0 |   -   |
## +-----------------------------+-----------------------+------------+-------+
##  Total params: 14,714,688 (56.13 MB)
##  Trainable params: 7,635,264 (29.13 MB)
##  Non-trainable params: 7,079,424 (27.01 MB)


# Freeze only layers of a certain type, e.g., BatchNorm layers
batch_norm_layer_class_name <- class(layer_batch_normalization())[1]
is_batch_norm_layer <- function(x) inherits(x, batch_norm_layer_class_name)

model <- application_efficientnet_b0()
freeze_weights(model, which = is_batch_norm_layer)
# print(model)

# equivalent to:
# for (layer in model$layers) {
#   if (is_batch_norm_layer(layer)) {
#     layer$trainable <- FALSE
#   } else {
#     layer$trainable <- TRUE
#   }
# }