tfdatasets (version 2.17.0)

dataset_rejection_resample: A transformation that resamples a dataset to a target distribution.

Description

A transformation that resamples a dataset to a target distribution.

Usage

dataset_rejection_resample(
  dataset,
  class_func,
  target_dist,
  initial_dist = NULL,
  seed = NULL,
  name = NULL
)

Value

A tf.Dataset

Arguments

dataset

A tf.Dataset

class_func

A function mapping an element of the input dataset to a scalar tf.int32 tensor. Values should be in [0, num_classes).

target_dist

A floating-point tensor of shape [num_classes] giving the desired class probabilities of the output dataset.

initial_dist

(Optional.) A floating-point tensor of shape [num_classes] giving the class probabilities of the input dataset. If not provided, the true class distribution is estimated on the fly as elements stream through.

seed

(Optional.) Integer seed for the resampler.

name

(Optional.) A name for the tf.data operation.
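
To make the relationship between `target_dist` and `initial_dist` concrete, here is a minimal base-R sketch of one way a rejection resampler can turn the two vectors into per-class acceptance probabilities. It assumes a scheme in which the output mixes the untouched input stream with a filtered copy of it; that is an assumption about internals, not something this page states. `resample_plan()` is a hypothetical helper and is not part of tfdatasets.

```r
# resample_plan() is a hypothetical helper for illustration only.
# ASSUMPTION: the output is a mixture of the original stream (weight m)
# and a filtered stream that keeps class i with probability accept[i].
resample_plan <- function(initial_dist, target_dist) {
  stopifnot(length(initial_dist) == length(target_dist),
            all(initial_dist > 0),
            isTRUE(all.equal(sum(initial_dist), 1)),
            isTRUE(all.equal(sum(target_dist), 1)))
  ratio <- target_dist / initial_dist   # how much each class must be scaled
  m <- min(ratio)                       # share taken from the original stream
  if (isTRUE(all.equal(m, max(ratio)))) {
    # Distributions already match: nothing needs to be rejected.
    return(list(prob_original = 1,
                accept = rep(1, length(ratio)),
                mixed = initial_dist))
  }
  accept <- (ratio - m) / (max(ratio) - m)
  # Class distribution of the filtered stream after rejection.
  filtered <- initial_dist * accept / sum(initial_dist * accept)
  # Class distribution of the mixture of both streams.
  mixed <- m * initial_dist + (1 - m) * filtered
  list(prob_original = m, accept = accept, mixed = mixed)
}

plan <- resample_plan(initial_dist = c(.5, .5), target_dist = c(.6, .4))
plan$accept  # per-class acceptance probabilities for the filtered stream
plan$mixed   # should reproduce target_dist up to floating-point error
```

Under this assumed scheme, the further `target_dist` is from `initial_dist`, the more elements the filtered stream has to reject, so supplying an accurate `initial_dist` matters when classes are heavily imbalanced.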

Examples

if (FALSE) {
library(tfdatasets)
initial_dist <- c(.5, .5)
target_dist <- c(.6, .4)
num_classes <- length(initial_dist)
num_samples <- 100000
data <- sample.int(num_classes, num_samples, prob = initial_dist, replace = TRUE)
dataset <- tensor_slices_dataset(data)
tally <- c(0, 0)
`add<-` <- function (x, value) x + value
# tfautograph::autograph({
#   for(i in dataset)
#     add(tally[as.numeric(i)]) <- 1
# })
dataset %>%
  as_array_iterator() %>%
  iterate(function(i) {
    add(tally[i]) <<- 1
  }, simplify = FALSE)
# The value of `tally` will be close to c(50000, 50000) as
# per the `initial_dist` distribution.
tally # c(50287, 49713)

tally <- c(0, 0)
dataset %>%
  dataset_rejection_resample(
    class_func = function(x) (x - 1) %% 2,
    target_dist = target_dist,
    initial_dist = initial_dist
  ) %>%
  as_array_iterator() %>%
  iterate(function(element) {
    names(element) <- c("class_id", "i")
    add(tally[element$i]) <<- 1
  }, simplify = FALSE)
# The value of `tally` will now be close to c(75000, 50000),
# a 0.6 : 0.4 split that satisfies the `target_dist` distribution.
tally # c(74822, 49921)
}
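
The resampled tally above sums to more than `num_samples`, which can be surprising. The base-R simulation below is a rough sketch of one mechanism that would produce counts of that magnitude: the output mixes the original stream with a filtered copy of it and stops when either stream runs out. This mechanism is an assumption, not something this page documents, and `simulate_mixed_resample()` is a hypothetical helper that does not call tfdatasets.

```r
# Base-R simulation of an ASSUMED "original + filtered stream" resampler.
# simulate_mixed_resample() is hypothetical and not a tfdatasets function.
simulate_mixed_resample <- function(data, initial_dist, target_dist) {
  ratio <- target_dist / initial_dist
  m <- min(ratio)                            # share drawn from the original stream
  stopifnot(max(ratio) > m)                  # sketch assumes the distributions differ
  accept <- (ratio - m) / (max(ratio) - m)   # per-class acceptance probability
  # Filtered stream: keep each element with its class's acceptance probability.
  filtered <- data[runif(length(data)) < accept[data]]
  # Choose a source for each output slot: TRUE = original, FALSE = filtered.
  n_slots <- ceiling(2 * length(data) / m)
  from_original <- runif(n_slots) < m
  # Stop at the first slot that asks for an element its stream no longer has.
  exhausted <- cumsum(from_original) > length(data) |
    cumsum(!from_original) > length(filtered)
  n_out <- if (any(exhausted)) which(exhausted)[1] - 1 else n_slots
  src <- from_original[seq_len(n_out)]
  out <- c(data[seq_len(sum(src))], filtered[seq_len(sum(!src))])
  tabulate(out, nbins = length(initial_dist))
}

set.seed(42)
initial_dist <- c(.5, .5)
target_dist <- c(.6, .4)
data <- sample.int(2, 100000, prob = initial_dist, replace = TRUE)
counts <- simulate_mixed_resample(data, initial_dist, target_dist)
counts / sum(counts)  # proportions should land near target_dist
```

In this sketch every element of the original stream is emitted once and extra class-1 elements come from the filtered stream, which is why the total can exceed the input size while the proportions still match `target_dist`.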
