This layer performs crosses of categorical features using the "hashing
trick". Conceptually, the transformation can be thought of as:
hash(concatenate(features)) %% num_bins
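The formula above can be illustrated with a small base-R sketch. It is only an illustration under stated assumptions. The helpers djb2_hash() and hashed_cross_sketch() are hypothetical. The "_X_" separator used to concatenate the features is an arbitrary choice. The djb2 string hash is a stand-in, not the hash function Keras uses internally, so the bin indices produced here will not match the layer's output.

```r
# Hypothetical helper: djb2 string hash, used here only as a stand-in.
# This is NOT the hash function used by layer_hashed_crossing().
djb2_hash <- function(s) {
  h <- 5381
  for (code in utf8ToInt(s)) {
    # Reduce modulo 2^32 at each step; doubles hold these values exactly.
    h <- (h * 33 + code) %% 2^32
  }
  h
}

# Hypothetical helper: hash(concatenate(features)) %% num_bins, per row.
# The "_X_" separator is an arbitrary choice made for this sketch.
hashed_cross_sketch <- function(features, num_bins, sep = "_X_") {
  # Concatenate the i-th element of every feature into one string.
  crossed <- do.call(paste, c(features, sep = sep))
  # Hash each crossed string and fold it into the range [0, num_bins).
  vapply(crossed, function(s) djb2_hash(s) %% num_bins, numeric(1),
         USE.NAMES = FALSE)
}

feat1 <- c("A", "B", "A", "B", "A")
feat2 <- c(101L, 101L, 101L, 102L, 102L)
bins <- hashed_cross_sketch(list(feat1, feat2), num_bins = 5)
bins
```

Rows 1 and 3 have the same feature values ("A", 101), so they always land in the same bin. Different combinations can also collide in one bin, which is the accuracy trade-off of the hashing trick in exchange for a fixed output size.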
This layer currently only performs crosses of scalar inputs and batches of
scalar inputs. Valid input shapes are (batch_size, 1), (batch_size), and ().
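For reference, the two batched shapes can be built from an R vector with base R alone. This is a minimal sketch, not keras3-specific code; the example on this page uses as.array() for the 1-D case.

```r
feat <- c("A", "B", "A")

# 1-D array: shape (batch_size)
x_1d <- as.array(feat)
dim(x_1d)
## [1] 3

# Single-column matrix: shape (batch_size, 1)
x_2d <- matrix(feat, ncol = 1)
dim(x_2d)
## [1] 3 1
```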
Note: This layer wraps tf.keras.layers.HashedCrossing. It cannot be used
as part of the compiled computation graph of a model with any backend other
than TensorFlow. It can, however, be used with any backend when running
eagerly. It can also always be used as part of an input preprocessing
pipeline with any backend (outside the model itself), which is the
recommended way to use this layer.
Note: This layer is safe to use inside a tfdatasets pipeline, regardless of
which backend you are using.
layer_hashed_crossing(
object,
num_bins,
output_mode = "int",
sparse = FALSE,
name = NULL,
dtype = NULL,
...
)
The return value depends on the value provided for the first argument.
If object is:
a keras_model_sequential(), then the layer is added to the sequential model (which is modified in place). To enable piping, the sequential model is also returned, invisibly.
a keras_input(), then the output tensor from calling layer(input) is returned.
NULL or missing, then a Layer instance is returned.
object: Object to compose the layer with. A tensor, array, or sequential model.
num_bins: Number of hash bins.
output_mode: Specification for the output of the layer. Values can be
"int" or "one_hot", configuring the layer as follows:
"int": Return the integer bin indices directly.
"one_hot": Encode each individual element in the input into an array the
same size as num_bins, containing a 1 at the input's bin index.
Defaults to "int".
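To make the "one_hot" description concrete, here is a base-R sketch of the same mapping from 0-based bin indices to rows of width num_bins. The helper one_hot_bins() is hypothetical and only mirrors the documented behavior; the layer performs this step itself. Applied to the bin indices 1, 4, 1, 1, 3 from the "int" example on this page, it reproduces the 5 x 5 pattern shown in the "one_hot" example.

```r
# Hypothetical helper mirroring output_mode = "one_hot".
# bin_indices are 0-based, as returned with output_mode = "int".
one_hot_bins <- function(bin_indices, num_bins) {
  out <- matrix(0, nrow = length(bin_indices), ncol = num_bins)
  # R is 1-based, so 0-based bin i goes in column i + 1 of its row.
  out[cbind(seq_along(bin_indices), bin_indices + 1)] <- 1
  out
}

one_hot_bins(c(1, 4, 1, 1, 3), num_bins = 5)
##      [,1] [,2] [,3] [,4] [,5]
## [1,]    0    1    0    0    0
## [2,]    0    0    0    0    1
## [3,]    0    1    0    0    0
## [4,]    0    1    0    0    0
## [5,]    0    0    0    1    0
```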
sparse: Boolean. Only applicable to "one_hot" mode and only valid when
using the TensorFlow backend. If TRUE, returns a SparseTensor instead of a
dense Tensor. Defaults to FALSE.
name: String, name for the object.
dtype: Data type (e.g., "float32").
...: Additional keyword arguments used to construct the layer.
library(keras3)
feat1 <- c('A', 'B', 'A', 'B', 'A') |> as.array()
feat2 <- c(101, 101, 101, 102, 102) |> as.integer() |> as.array()
Crossing two scalar features.
layer <- layer_hashed_crossing(num_bins = 5)
layer(list(feat1, feat2))
## tf.Tensor([1 4 1 1 3], shape=(5), dtype=int64)
Crossing and one-hotting two scalar features.
layer <- layer_hashed_crossing(num_bins = 5, output_mode = 'one_hot')
layer(list(feat1, feat2))
## tf.Tensor(
## [[0. 1. 0. 0. 0.]
## [0. 0. 0. 0. 1.]
## [0. 1. 0. 0. 0.]
## [0. 1. 0. 0. 0.]
## [0. 0. 0. 1. 0.]], shape=(5, 5), dtype=float32)
Other categorical features preprocessing layers:
layer_category_encoding()
layer_hashing()
layer_integer_lookup()
layer_string_lookup()
Other preprocessing layers:
layer_auto_contrast()
layer_category_encoding()
layer_center_crop()
layer_discretization()
layer_equalization()
layer_feature_space()
layer_hashing()
layer_integer_lookup()
layer_max_num_bounding_boxes()
layer_mel_spectrogram()
layer_mix_up()
layer_normalization()
layer_rand_augment()
layer_random_brightness()
layer_random_color_degeneration()
layer_random_color_jitter()
layer_random_contrast()
layer_random_crop()
layer_random_flip()
layer_random_grayscale()
layer_random_hue()
layer_random_posterization()
layer_random_rotation()
layer_random_saturation()
layer_random_sharpness()
layer_random_shear()
layer_random_translation()
layer_random_zoom()
layer_rescaling()
layer_resizing()
layer_solarization()
layer_stft_spectrogram()
layer_string_lookup()
layer_text_vectorization()
Other layers:
Layer()
layer_activation()
layer_activation_elu()
layer_activation_leaky_relu()
layer_activation_parametric_relu()
layer_activation_relu()
layer_activation_softmax()
layer_activity_regularization()
layer_add()
layer_additive_attention()
layer_alpha_dropout()
layer_attention()
layer_auto_contrast()
layer_average()
layer_average_pooling_1d()
layer_average_pooling_2d()
layer_average_pooling_3d()
layer_batch_normalization()
layer_bidirectional()
layer_category_encoding()
layer_center_crop()
layer_concatenate()
layer_conv_1d()
layer_conv_1d_transpose()
layer_conv_2d()
layer_conv_2d_transpose()
layer_conv_3d()
layer_conv_3d_transpose()
layer_conv_lstm_1d()
layer_conv_lstm_2d()
layer_conv_lstm_3d()
layer_cropping_1d()
layer_cropping_2d()
layer_cropping_3d()
layer_dense()
layer_depthwise_conv_1d()
layer_depthwise_conv_2d()
layer_discretization()
layer_dot()
layer_dropout()
layer_einsum_dense()
layer_embedding()
layer_equalization()
layer_feature_space()
layer_flatten()
layer_flax_module_wrapper()
layer_gaussian_dropout()
layer_gaussian_noise()
layer_global_average_pooling_1d()
layer_global_average_pooling_2d()
layer_global_average_pooling_3d()
layer_global_max_pooling_1d()
layer_global_max_pooling_2d()
layer_global_max_pooling_3d()
layer_group_normalization()
layer_group_query_attention()
layer_gru()
layer_hashing()
layer_identity()
layer_integer_lookup()
layer_jax_model_wrapper()
layer_lambda()
layer_layer_normalization()
layer_lstm()
layer_masking()
layer_max_num_bounding_boxes()
layer_max_pooling_1d()
layer_max_pooling_2d()
layer_max_pooling_3d()
layer_maximum()
layer_mel_spectrogram()
layer_minimum()
layer_mix_up()
layer_multi_head_attention()
layer_multiply()
layer_normalization()
layer_permute()
layer_rand_augment()
layer_random_brightness()
layer_random_color_degeneration()
layer_random_color_jitter()
layer_random_contrast()
layer_random_crop()
layer_random_flip()
layer_random_grayscale()
layer_random_hue()
layer_random_posterization()
layer_random_rotation()
layer_random_saturation()
layer_random_sharpness()
layer_random_shear()
layer_random_translation()
layer_random_zoom()
layer_repeat_vector()
layer_rescaling()
layer_reshape()
layer_resizing()
layer_rnn()
layer_separable_conv_1d()
layer_separable_conv_2d()
layer_simple_rnn()
layer_solarization()
layer_spatial_dropout_1d()
layer_spatial_dropout_2d()
layer_spatial_dropout_3d()
layer_spectral_normalization()
layer_stft_spectrogram()
layer_string_lookup()
layer_subtract()
layer_text_vectorization()
layer_tfsm()
layer_time_distributed()
layer_torch_module_wrapper()
layer_unit_normalization()
layer_upsampling_1d()
layer_upsampling_2d()
layer_upsampling_3d()
layer_zero_padding_1d()
layer_zero_padding_2d()
layer_zero_padding_3d()
rnn_cell_gru()
rnn_cell_lstm()
rnn_cell_simple()
rnn_cells_stack()