This layer enables the use of JAX components within Keras when using JAX as the backend for Keras.
layer_jax_model_wrapper(
  object,
  call_fn,
  init_fn = NULL,
  params = NULL,
  state = NULL,
  seed = NULL,
  ...
)
The return value depends on the value provided for the first argument. If object is:

- a keras_model_sequential(), then the layer is added to the sequential model (which is modified in place). To enable piping, the sequential model is also returned, invisibly.

- a keras_input(), then the output tensor from calling layer(input) is returned.

- NULL or missing, then a Layer instance is returned.
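For illustration, a minimal sketch of the three composition patterns (my_call_fn and initial_params are placeholder names, assumed to be defined as described in the details below):

# standalone Layer instance
layer <- layer_jax_model_wrapper(call_fn = my_call_fn, params = initial_params)

# composed with a symbolic input tensor
inputs <- keras_input(shape = c(8))
outputs <- inputs |>
  layer_jax_model_wrapper(call_fn = my_call_fn, params = initial_params)

# added to a sequential model (modified in place, returned invisibly for piping)
model <- keras_model_sequential(input_shape = c(8)) |>
  layer_jax_model_wrapper(call_fn = my_call_fn, params = initial_params)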
object: Object to compose the layer with. A tensor, array, or sequential model.

call_fn: The function to call the model. See description above for the list of arguments it takes and the outputs it returns.

init_fn: The function to call to initialize the model. See description above for the list of arguments it takes and the outputs it returns. If NULL, then params and/or state must be provided.

params: A PyTree containing all the model trainable parameters. This allows passing trained parameters or controlling the initialization. If both params and state are NULL, init_fn() is called at build time to initialize the trainable parameters of the model.

state: A PyTree containing all the model non-trainable state. This allows passing learned state or controlling the initialization. If both params and state are NULL, and call_fn() takes a state argument, then init_fn() is called at build time to initialize the non-trainable state of the model.

seed: Seed for random number generator. Optional.

...: For forward/backward compatibility.
This layer accepts JAX models in the form of a function, call_fn(), which must take the following arguments with these exact names:

- params: trainable parameters of the model.

- state (optional): non-trainable state of the model. Can be omitted if the model has no non-trainable state.

- rng (optional): a jax.random.PRNGKey instance. Can be omitted if the model does not need RNGs, neither during training nor during inference.

- inputs: inputs to the model, a JAX array or a PyTree of arrays.

- training (optional): an argument specifying if we're in training mode or inference mode; TRUE is passed in training mode. Can be omitted if the model behaves the same in training mode and inference mode.

The inputs argument is mandatory. Inputs to the model must be provided via a single argument. If the JAX model takes multiple inputs as separate arguments, they must be combined into a single structure, for instance in a tuple() or a dict().
The initialization of the params and state of the model can be handled by this layer, in which case the init_fn() argument must be provided. This allows the model to be initialized dynamically with the right shape. Alternatively, and if the shape is known, the params argument and optionally the state argument can be used to create an already initialized model.
The init_fn() function, if provided, must take the following arguments with these exact names:

- rng: a jax.random.PRNGKey instance.

- inputs: a JAX array or a PyTree of arrays with placeholder values to provide the shape of the inputs.

- training (optional): an argument specifying if we're in training mode or inference mode. TRUE is always passed to init_fn(). Can be omitted regardless of whether call_fn() has a training argument.
For JAX models that have non-trainable state:

- call_fn() must have a state argument

- call_fn() must return a tuple() containing the outputs of the model and the new non-trainable state of the model

- init_fn() must return a tuple() containing the initial trainable params of the model and the initial non-trainable state of the model.
This code shows a possible combination of call_fn() and init_fn() signatures for a model with non-trainable state. In this example, the model has a training argument and an rng argument in call_fn().
stateful_call <- function(params, state, rng, inputs, training) {
  outputs <- ....
  new_state <- ....
  tuple(outputs, new_state)
}

stateful_init <- function(rng, inputs) {
  initial_params <- ....
  initial_state <- ....
  tuple(initial_params, initial_state)
}
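With these two functions, the layer can then be created so that its parameters and state are initialized at build time (a minimal sketch using the functions above):

layer <- layer_jax_model_wrapper(call_fn = stateful_call,
                                 init_fn = stateful_init)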
For JAX models with no non-trainable state:

- call_fn() must not have a state argument

- call_fn() must return only the outputs of the model

- init_fn() must return only the initial trainable params of the model.
This code shows a possible combination of call_fn() and init_fn() signatures for a model without non-trainable state. In this example, the model does not have a training argument and does not have an rng argument in call_fn().
stateless_call <- function(params, inputs) {
  outputs <- ....
  outputs
}

stateless_init <- function(rng, inputs) {
  initial_params <- ....
  initial_params
}
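Alternatively, when the parameter shapes are already known, pre-initialized parameters can be supplied instead of an init_fn() (a minimal sketch; initial_params is assumed to be a suitable PyTree):

layer <- layer_jax_model_wrapper(call_fn = stateless_call,
                                 params = initial_params)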
If a model has a different signature than the one required by JaxLayer, one can easily write a wrapper method to adapt the arguments. This example shows a model that has multiple inputs as separate arguments, expects multiple RNGs in a dict, and has a deterministic argument with the opposite meaning of training. To conform, the inputs are combined in a single structure using a tuple, the RNG is split and used to populate the expected dict, and the Boolean flag is negated:
jax <- import("jax")

my_model_fn <- function(params, rngs, input1, input2, deterministic) {
  ....
  if (!deterministic) {
    dropout_rng <- rngs$dropout
    keep <- jax$random$bernoulli(dropout_rng, dropout_rate, x$shape)
    x <- jax$numpy$where(keep, x / dropout_rate, 0)
    ....
  }
  ....
  return(outputs)
}

my_model_wrapper_fn <- function(params, rng, inputs, training) {
  c(input1, input2) %<-% inputs
  c(rng1, rng2) %<-% jax$random$split(rng)
  rngs <- list(dropout = rng1, preprocessing = rng2)
  deterministic <- !training
  my_model_fn(params, rngs, input1, input2, deterministic)
}

keras_layer <- layer_jax_model_wrapper(call_fn = my_model_wrapper_fn,
                                       params = initial_params)
JaxLayer enables the use of Haiku components in the form of haiku.Module. This is achieved by transforming the module per the Haiku pattern and then passing module.apply in the call_fn parameter and module.init in the init_fn parameter if needed.

If the model has non-trainable state, it should be transformed with haiku.transform_with_state.

If the model has no non-trainable state, it should be transformed with haiku.transform.

Additionally, and optionally, if the module does not use RNGs in "apply", it can be transformed with haiku.without_apply_rng.
The following example shows how to create a JaxLayer from a Haiku module that uses random number generators via hk.next_rng_key() and takes a training positional argument:
# reticulate::py_install("haiku", "r-keras")
hk <- import("haiku")

MyHaikuModule(hk$Module) %py_class% {

  `__call__` <- \(self, x, training) {
    x <- hk$Conv2D(32L, tuple(3L, 3L))(x)
    x <- jax$nn$relu(x)
    x <- hk$AvgPool(tuple(1L, 2L, 2L, 1L),
                    tuple(1L, 2L, 2L, 1L), "VALID")(x)
    x <- hk$Flatten()(x)
    x <- hk$Linear(200L)(x)
    if (training)
      x <- hk$dropout(rng = hk$next_rng_key(), rate = 0.3, x = x)
    x <- jax$nn$relu(x)
    x <- hk$Linear(10L)(x)
    x <- jax$nn$softmax(x)
    x
  }

}

my_haiku_module_fn <- function(inputs, training) {
  module <- MyHaikuModule()
  module(inputs, training)
}

transformed_module <- hk$transform(my_haiku_module_fn)

keras_layer <-
  layer_jax_model_wrapper(call_fn = transformed_module$apply,
                          init_fn = transformed_module$init)
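For illustration, a minimal sketch of using such a layer inside a Keras model (the input shape here is an assumption, not part of the Haiku example above):

model <- keras_model_sequential(input_shape = c(28, 28, 1)) |>
  layer_jax_model_wrapper(call_fn = transformed_module$apply,
                          init_fn = transformed_module$init)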
Other wrapping layers:
layer_flax_module_wrapper()
layer_torch_module_wrapper()
Other layers:
Layer()
layer_activation()
layer_activation_elu()
layer_activation_leaky_relu()
layer_activation_parametric_relu()
layer_activation_relu()
layer_activation_softmax()
layer_activity_regularization()
layer_add()
layer_additive_attention()
layer_alpha_dropout()
layer_attention()
layer_auto_contrast()
layer_average()
layer_average_pooling_1d()
layer_average_pooling_2d()
layer_average_pooling_3d()
layer_batch_normalization()
layer_bidirectional()
layer_category_encoding()
layer_center_crop()
layer_concatenate()
layer_conv_1d()
layer_conv_1d_transpose()
layer_conv_2d()
layer_conv_2d_transpose()
layer_conv_3d()
layer_conv_3d_transpose()
layer_conv_lstm_1d()
layer_conv_lstm_2d()
layer_conv_lstm_3d()
layer_cropping_1d()
layer_cropping_2d()
layer_cropping_3d()
layer_dense()
layer_depthwise_conv_1d()
layer_depthwise_conv_2d()
layer_discretization()
layer_dot()
layer_dropout()
layer_einsum_dense()
layer_embedding()
layer_equalization()
layer_feature_space()
layer_flatten()
layer_flax_module_wrapper()
layer_gaussian_dropout()
layer_gaussian_noise()
layer_global_average_pooling_1d()
layer_global_average_pooling_2d()
layer_global_average_pooling_3d()
layer_global_max_pooling_1d()
layer_global_max_pooling_2d()
layer_global_max_pooling_3d()
layer_group_normalization()
layer_group_query_attention()
layer_gru()
layer_hashed_crossing()
layer_hashing()
layer_identity()
layer_integer_lookup()
layer_lambda()
layer_layer_normalization()
layer_lstm()
layer_masking()
layer_max_num_bounding_boxes()
layer_max_pooling_1d()
layer_max_pooling_2d()
layer_max_pooling_3d()
layer_maximum()
layer_mel_spectrogram()
layer_minimum()
layer_mix_up()
layer_multi_head_attention()
layer_multiply()
layer_normalization()
layer_permute()
layer_rand_augment()
layer_random_brightness()
layer_random_color_degeneration()
layer_random_color_jitter()
layer_random_contrast()
layer_random_crop()
layer_random_flip()
layer_random_grayscale()
layer_random_hue()
layer_random_posterization()
layer_random_rotation()
layer_random_saturation()
layer_random_sharpness()
layer_random_shear()
layer_random_translation()
layer_random_zoom()
layer_repeat_vector()
layer_rescaling()
layer_reshape()
layer_resizing()
layer_rnn()
layer_separable_conv_1d()
layer_separable_conv_2d()
layer_simple_rnn()
layer_solarization()
layer_spatial_dropout_1d()
layer_spatial_dropout_2d()
layer_spatial_dropout_3d()
layer_spectral_normalization()
layer_stft_spectrogram()
layer_string_lookup()
layer_subtract()
layer_text_vectorization()
layer_tfsm()
layer_time_distributed()
layer_torch_module_wrapper()
layer_unit_normalization()
layer_upsampling_1d()
layer_upsampling_2d()
layer_upsampling_3d()
layer_zero_padding_1d()
layer_zero_padding_2d()
layer_zero_padding_3d()
rnn_cell_gru()
rnn_cell_lstm()
rnn_cell_simple()
rnn_cells_stack()