This layer enables the use of Flax components in the form of flax.linen.Module instances within Keras when using JAX as the backend for Keras.

The module method to use for the forward pass can be specified via the method argument; it is __call__ by default. This method must take the following arguments with these exact names:

- self, if the method is bound to the module (which is the case for the default of __call__), or module otherwise, to pass the module.
- inputs: the inputs to the model, a JAX array or a PyTree of arrays.
- training (optional): an argument specifying if we're in training mode or inference mode; TRUE is passed in training mode.
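For instance, a minimal sketch of a module whose bound __call__ accepts the optional training argument might look like this (MyDropoutModule, its layer sizes, and the dropout rate are illustrative, not part of the API):

flax <- import("flax")

MyDropoutModule(flax$linen$Module) %py_class% {
  `__call__` <- flax$linen$compact(\(self, inputs, training = FALSE) {
    # dropout is active only when training = TRUE; FlaxLayer supplies
    # the required dropout RNG automatically
    inputs |>
      (flax$linen$Dense(features = 64L))() |>
      flax$linen$relu() |>
      (flax$linen$Dropout(rate = 0.5, deterministic = !training))()
  })
}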
FlaxLayer handles the non-trainable state of your model and required RNGs automatically. Note that the mutable parameter of flax.linen.Module.apply() is set to DenyList(["params"]), therefore making the assumption that all the variables outside of the "params" collection are non-trainable weights.
This example shows how to create a FlaxLayer from a Flax Module with the default __call__ method and no training argument:
# keras3::use_backend("jax")
# py_install("flax", "r-keras")

if (config_backend() == "jax" &&
    reticulate::py_module_available("flax")) {

  flax <- import("flax")

  MyFlaxModule(flax$linen$Module) %py_class% {
    `__call__` <- flax$linen$compact(\(self, inputs) {
      inputs |>
        (flax$linen$Conv(features = 32L, kernel_size = tuple(3L, 3L)))() |>
        flax$linen$relu() |>
        flax$linen$avg_pool(window_shape = tuple(2L, 2L),
                            strides = tuple(2L, 2L)) |>
        # flatten all except the batch_size axis
        (\(x) x$reshape(tuple(x$shape[[1]], -1L)))() |>
        (flax$linen$Dense(features = 200L))() |>
        flax$linen$relu() |>
        (flax$linen$Dense(features = 10L))() |>
        flax$linen$softmax()
    })
  }

  # typical usage:
  input <- keras_input(c(28, 28, 3))
  output <- input |>
    layer_flax_module_wrapper(MyFlaxModule())

  model <- keras_model(input, output)

  # to instantiate the layer before composing:
  flax_module <- MyFlaxModule()
  keras_layer <- layer_flax_module_wrapper(module = flax_module)

  input <- keras_input(c(28, 28, 3))
  output <- input |>
    keras_layer()

  model <- keras_model(input, output)
}
This example shows how to wrap a module method to conform to the required signature. It allows multiple input arguments and a training argument with a different name and values, and additionally shows how to use a function that is not bound to the module:
flax <- import("flax")MyFlaxModule(flax$linen$Module) \%py_class\% {
forward <-
flax$linen$compact(\(self, inputs1, input2, deterministic) {
# do work ....
outputs # return
})
}
my_flax_module_wrapper <- function(module, inputs, training) {
c(input1, input2) \%<-\% inputs
module$forward(input1, input2,!training)
}
flax_module <- MyFlaxModule()
keras_layer <- layer_flax_module_wrapper(module = flax_module,
method = my_flax_module_wrapper)
layer_flax_module_wrapper(object, module, method = NULL, variables = NULL, ...)
The return value depends on the value provided for the first argument. If object is:

- a keras_model_sequential(), then the layer is added to the sequential model (which is modified in place). To enable piping, the sequential model is also returned, invisibly.
- a keras_input(), then the output tensor from calling layer(input) is returned.
- NULL or missing, then a Layer instance is returned.
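For instance, a brief sketch of the three composition modes (assuming the flax_module object from the examples above):

# NULL/missing object: returns a Layer instance
layer <- layer_flax_module_wrapper(module = flax_module)

# keras_input() object: returns the output tensor from calling layer(input)
input <- keras_input(c(28, 28, 3))
output <- layer_flax_module_wrapper(input, module = flax_module)

# sequential model object: the layer is added in place, model returned invisibly
model <- keras_model_sequential(input_shape = c(28, 28, 3)) |>
  layer_flax_module_wrapper(module = flax_module)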
object: Object to compose the layer with. A tensor, array, or sequential model.

module: An instance of flax.linen.Module or subclass.

method: The method to call the model. This is generally a method in the Module. If not provided, the __call__ method is used. method can also be a function not defined in the Module, in which case it must take the Module as the first argument. It is used for both Module.init and Module.apply. Details are documented in the method argument of flax.linen.Module.apply().

variables: A dict (named R list) containing all the variables of the module, in the same format as what is returned by flax.linen.Module.init(). It should contain a "params" key and, if applicable, other keys for collections of variables for non-trainable state. This allows passing trained parameters and learned non-trainable state or controlling the initialization, as sketched after this list. If NULL is passed, the module's init function is called at build time to initialize the variables of the model.

...: For forward/backward compatibility.
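As a sketch of passing pre-trained variables (pretrained_variables is a hypothetical named list with the same structure as returned by flax.linen.Module.init()):

# hypothetical: pretrained_variables <- list(params = ..., batch_stats = ...)
keras_layer <- layer_flax_module_wrapper(
  module = flax_module,
  variables = pretrained_variables
)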
Other wrapping layers:
layer_jax_model_wrapper()
layer_torch_module_wrapper()
Other layers:
Layer()
layer_activation()
layer_activation_elu()
layer_activation_leaky_relu()
layer_activation_parametric_relu()
layer_activation_relu()
layer_activation_softmax()
layer_activity_regularization()
layer_add()
layer_additive_attention()
layer_alpha_dropout()
layer_attention()
layer_auto_contrast()
layer_average()
layer_average_pooling_1d()
layer_average_pooling_2d()
layer_average_pooling_3d()
layer_batch_normalization()
layer_bidirectional()
layer_category_encoding()
layer_center_crop()
layer_concatenate()
layer_conv_1d()
layer_conv_1d_transpose()
layer_conv_2d()
layer_conv_2d_transpose()
layer_conv_3d()
layer_conv_3d_transpose()
layer_conv_lstm_1d()
layer_conv_lstm_2d()
layer_conv_lstm_3d()
layer_cropping_1d()
layer_cropping_2d()
layer_cropping_3d()
layer_dense()
layer_depthwise_conv_1d()
layer_depthwise_conv_2d()
layer_discretization()
layer_dot()
layer_dropout()
layer_einsum_dense()
layer_embedding()
layer_equalization()
layer_feature_space()
layer_flatten()
layer_gaussian_dropout()
layer_gaussian_noise()
layer_global_average_pooling_1d()
layer_global_average_pooling_2d()
layer_global_average_pooling_3d()
layer_global_max_pooling_1d()
layer_global_max_pooling_2d()
layer_global_max_pooling_3d()
layer_group_normalization()
layer_group_query_attention()
layer_gru()
layer_hashed_crossing()
layer_hashing()
layer_identity()
layer_integer_lookup()
layer_jax_model_wrapper()
layer_lambda()
layer_layer_normalization()
layer_lstm()
layer_masking()
layer_max_num_bounding_boxes()
layer_max_pooling_1d()
layer_max_pooling_2d()
layer_max_pooling_3d()
layer_maximum()
layer_mel_spectrogram()
layer_minimum()
layer_mix_up()
layer_multi_head_attention()
layer_multiply()
layer_normalization()
layer_permute()
layer_rand_augment()
layer_random_brightness()
layer_random_color_degeneration()
layer_random_color_jitter()
layer_random_contrast()
layer_random_crop()
layer_random_flip()
layer_random_grayscale()
layer_random_hue()
layer_random_posterization()
layer_random_rotation()
layer_random_saturation()
layer_random_sharpness()
layer_random_shear()
layer_random_translation()
layer_random_zoom()
layer_repeat_vector()
layer_rescaling()
layer_reshape()
layer_resizing()
layer_rnn()
layer_separable_conv_1d()
layer_separable_conv_2d()
layer_simple_rnn()
layer_solarization()
layer_spatial_dropout_1d()
layer_spatial_dropout_2d()
layer_spatial_dropout_3d()
layer_spectral_normalization()
layer_stft_spectrogram()
layer_string_lookup()
layer_subtract()
layer_text_vectorization()
layer_tfsm()
layer_time_distributed()
layer_torch_module_wrapper()
layer_unit_normalization()
layer_upsampling_1d()
layer_upsampling_2d()
layer_upsampling_3d()
layer_zero_padding_1d()
layer_zero_padding_2d()
layer_zero_padding_3d()
rnn_cell_gru()
rnn_cell_lstm()
rnn_cell_simple()
rnn_cells_stack()