if (FALSE) {
  ### Demonstrating usage after fitting models with RStan
  #
  # Fits three Stan models that differ only in a fixed mean, then combines
  # them with stacking, pseudo-BMA+ and pseudo-BMA weights. The block is
  # wrapped in `if (FALSE)` because compiling and sampling is slow; run the
  # body interactively.
  library(rstan)
  # extract_log_lik(), relative_eff(), loo_model_weights(), stacking_weights()
  # and pseudobma_weights() live in loo, which library(rstan) does not attach.
  library(loo)

  # generate fake data from N(0,1).
  N <- 100
  y <- rnorm(N, 0, 1)

  # Suppose we have three models: N(-1, sigma), N(0.5, sigma) and N(0.6,sigma).
  # sigma is a scale parameter with an exponential prior, so it has to be
  # declared with a lower bound of zero.
  stan_code <- "
data {
int N;
vector[N] y;
real mu_fixed;
}
parameters {
real<lower=0> sigma;
}
model {
sigma ~ exponential(1);
y ~ normal(mu_fixed, sigma);
}
generated quantities {
vector[N] log_lik;
for (n in 1:N) log_lik[n] = normal_lpdf(y[n]| mu_fixed, sigma);
}"

  mod <- stan_model(model_code = stan_code)
  fit1 <- sampling(mod, data = list(N = N, y = y, mu_fixed = -1))
  fit2 <- sampling(mod, data = list(N = N, y = y, mu_fixed = 0.5))
  fit3 <- sampling(mod, data = list(N = N, y = y, mu_fixed = 0.6))
  model_list <- list(fit1, fit2, fit3)
  log_lik_list <- lapply(model_list, extract_log_lik)

  # optional but recommended
  r_eff_list <- lapply(model_list, function(x) {
    # relative_eff() needs the per-chain structure, hence merge_chains = FALSE
    ll_array <- extract_log_lik(x, merge_chains = FALSE)
    relative_eff(exp(ll_array))
  })

  # stacking method:
  wts1 <- loo_model_weights(
    log_lik_list,
    method = "stacking",
    r_eff_list = r_eff_list,
    optim_control = list(reltol = 1e-10)
  )
  print(wts1)

  # can also pass a list of psis_loo objects to avoid recomputing loo
  loo_list <- lapply(seq_along(log_lik_list), function(j) {
    loo(log_lik_list[[j]], r_eff = r_eff_list[[j]])
  })

  wts2 <- loo_model_weights(
    loo_list,
    method = "stacking",
    optim_control = list(reltol = 1e-10)
  )
  all.equal(wts1, wts2)

  # can provide names to be used in the results
  loo_model_weights(setNames(loo_list, c("A", "B", "C")))

  # pseudo-BMA+ method:
  set.seed(1414)
  loo_model_weights(loo_list, method = "pseudobma")

  # pseudo-BMA method (set BB = FALSE):
  loo_model_weights(loo_list, method = "pseudobma", BB = FALSE)

  # calling stacking_weights or pseudobma_weights directly
  # Reuse the psis_loo objects computed above instead of running loo() again;
  # the pointwise elpd_loo column is selected by name rather than by position.
  lpd1 <- loo_list[[1]]$pointwise[, "elpd_loo"]
  lpd2 <- loo_list[[2]]$pointwise[, "elpd_loo"]
  lpd3 <- loo_list[[3]]$pointwise[, "elpd_loo"]
  lpd_point <- cbind(lpd1, lpd2, lpd3)
  stacking_weights(lpd_point)
  pseudobma_weights(lpd_point)
  pseudobma_weights(lpd_point, BB = FALSE)
}
# Run the code above in your browser using DataLab