# \donttest{
# Example data (20 points): predictor values, observed responses, and the
# per-observation scale used by the Student-t likelihood. All are cast to
# float32 tensors.
predictors <- tf$cast(
  c(201, 244, 47, 287, 203, 58, 210, 202, 198, 158,
    165, 201, 157, 131, 166, 160, 186, 125, 218, 146),
  dtype = tf$float32
)
obs <- tf$cast(
  c(592, 401, 583, 402, 495, 173, 479, 504, 510, 416,
    393, 442, 317, 311, 400, 337, 423, 334, 533, 344),
  dtype = tf$float32
)
y_sigma <- tf$cast(
  c(61, 25, 38, 15, 21, 15, 27, 14, 30, 16,
    14, 25, 52, 16, 34, 31, 42, 26, 16, 22),
  dtype = tf$float32
)
# Robust linear regression model: standard-normal priors on the intercept b0
# and slope b1, a half-normal(5) prior on the Student-t degrees of freedom
# `df`, and a Student-t likelihood over the 20 observations whose scale is
# the fixed vector `y_sigma`.
robust_lm <- tfd_joint_distribution_sequential(
  list(
    tfd_normal(loc = 0, scale = 1, name = "b0"),
    tfd_normal(loc = 0, scale = 1, name = "b1"),
    tfd_half_normal(scale = 5, name = "df"),
    # Likelihood. Note the parameter order (df, b1, b0) is the reverse of the
    # order in which those components are listed above.
    function(df, b1, b0) {
      # Add a trailing axis to each parameter so it broadcasts against the
      # observation axis of `predictors[tf$newaxis, ]`.
      df_col <- tf$expand_dims(df, axis = -1L)
      b0_col <- tf$expand_dims(b0, axis = -1L)
      b1_col <- tf$expand_dims(b1, axis = -1L)
      tfd_independent(
        tfd_student_t(
          df = df_col,
          loc = b0_col + b1_col * predictors[tf$newaxis, ],
          scale = y_sigma,
          name = "st"
        ),
        name = "ind"
      )
    }
  ),
  validate_args = TRUE
)
# Target density for MCMC: joint log-probability of `robust_lm` at the given
# parameter values (b0, b1, df), with the observed responses `obs` plugged in
# as the likelihood component.
log_prob <- function(b0, b1, df) {
  tfd_log_prob(robust_lm, list(b0, b1, df, obs))
}
# Initial NUTS step size for each state part (b0, b1, df), as float32 tensors.
step_size0 <- lapply(c(1, .2, .5), function(x) tf$cast(x, tf$float32))
# Sampler settings: retained draws per chain, burn-in / adaptation length,
# and number of parallel chains.
number_of_steps <- 10
burnin <- 5
nchain <- 50
run_chain <- function() {
  # Random starting positions: draw `nchain` joint samples from the model and
  # keep the first three parts (b0, b1, df); the likelihood draw is unused.
  prior_draws <- tfd_sample(robust_lm, nchain)
  initial_state <- list(prior_draws[[1]], prior_draws[[2]], prior_draws[[3]])

  # Bijectors mapping each state part to the real line: b0 and b1 are
  # unconstrained already (identity), df is positive (exp).
  to_unconstrained <- list(tfb_identity(), tfb_identity(), tfb_exp())

  # Record the NUTS step size and log acceptance ratio at every step. The NUTS
  # results sit two levels deep: inside the adaptation results, then inside
  # the transformed-kernel results.
  trace_fn <- function(x, pkr) {
    nuts_results <- pkr$inner_results$inner_results
    list(nuts_results$step_size, nuts_results$log_accept_ratio)
  }

  base_kernel <- mcmc_no_u_turn_sampler(
    target_log_prob_fn = log_prob,
    step_size = step_size0
  )
  transformed_kernel <- mcmc_transformed_transition_kernel(
    base_kernel,
    bijector = to_unconstrained
  )
  # The adaptation wraps the transformed kernel, so the step-size accessors
  # below must reach one level further (`inner_results`) to the NUTS results.
  adaptive_kernel <- mcmc_dual_averaging_step_size_adaptation(
    transformed_kernel,
    num_adaptation_steps = burnin,
    step_size_setter_fn = function(pkr, new_step_size) {
      pkr$`_replace`(
        inner_results = pkr$inner_results$`_replace`(step_size = new_step_size)
      )
    },
    step_size_getter_fn = function(pkr) pkr$inner_results$step_size,
    log_accept_prob_getter_fn = function(pkr) pkr$inner_results$log_accept_ratio
  )

  mcmc_sample_chain(
    adaptive_kernel,
    num_results = number_of_steps,
    num_burnin_steps = burnin,
    current_state = initial_state,
    trace_fn = trace_fn
  )
}
# Replace run_chain with a tf_function-wrapped version, then execute it once;
# `res` holds whatever mcmc_sample_chain returned.
run_chain <- tensorflow::tf_function(f = run_chain)
res <- run_chain()
# }
# Run the code above in your browser using DataLab