# \donttest{
# Simulate latent count data for 500 spatial locations and 10 species.
# All random draws below happen in a fixed order after set.seed(), so the
# statements must stay in this order for the example to be reproducible.
set.seed(0)
N_points <- 500
N_species <- 10

# Species-level intercepts (on the log scale)
alphas <- runif(n = N_species, min = 2, max = 2.25)

# One covariate shared by all locations, plus a species-level slope on it
temperature <- rnorm(n = N_points)
betas <- runif(n = N_species, min = -0.5, max = 0.5)

# Locations drawn uniformly within a lon/lat rectangle
lon <- runif(n = N_points, min = 150, max = 155)
lat <- runif(n = N_points, min = -20, max = -19)

# Spatial basis functions: a tensor product smooth of lon and lat, built
# with mgcv::smoothCon() only to obtain its basis (nothing is fitted here)
sm <- mgcv::smoothCon(
  object = mgcv::te(lon, lat, k = 5),
  data = data.frame(lon, lat),
  knots = NULL
)[[1]]

# The smooth's design matrix is stored in its 'X' element:
# one row per location, one column per basis function
des_mat <- sm[["X"]]
dim(des_mat)
# Generate a random N x N correlation matrix, i.e. a covariance matrix in
# which every variable has unit variance (all diagonal entries equal 1).
#
# A lower-triangular factor L_Omega is filled row by row. Within row i, each
# off-diagonal entry is drawn uniformly within the remaining "budget", so the
# squared entries never exceed 1 in total. The diagonal entry takes whatever
# budget is left, which gives every row unit Euclidean norm. The product
# L_Omega %*% t(L_Omega) therefore has a unit diagonal.
#
# N: number of variables (a single number >= 1).
# Returns an N x N numeric matrix.
random_Sigma <- function(N) {
  stopifnot(length(N) == 1, N >= 1)
  L_Omega <- matrix(0, N, N)
  L_Omega[1, 1] <- 1
  # seq_len(N)[-1] is empty when N == 1. A plain `2:N` would count down to
  # c(2, 1) and index outside the 1 x 1 matrix.
  for (i in seq_len(N)[-1]) {
    bound <- 1
    for (j in seq_len(i - 1)) {
      L_Omega[i, j] <- runif(1, -sqrt(bound), sqrt(bound))
      # Rounding can leave the remaining budget marginally below zero, which
      # would make sqrt() return NaN; clamp it at zero instead.
      bound <- max(bound - L_Omega[i, j]^2, 0)
    }
    L_Omega[i, i] <- sqrt(bound)
  }
  L_Omega %*% t(L_Omega)
}
# Simulate a variance-covariance matrix for the correlations among
# basis coefficients (random_Sigma() returns it with a unit diagonal)
Sigma <- random_Sigma(N = NCOL(des_mat))
# Now simulate the species-level basis coefficients hierarchically. Each
# species' coefficient correlation matrix is a convex combination
# (weights 0.7 / 0.3) of the shared base correlation matrix and a fresh,
# species-specific random correlation matrix.
n_basis <- NCOL(Sigma)
# The shared base correlation does not change across species, so it is
# computed once outside the loop
base_cor <- cov2cor(Sigma)
basis_coefs <- matrix(NA, nrow = N_species, ncol = n_basis)
# NOTE(review): base_field is not used later in this example. The draw is
# kept so that every subsequent random number (and hence the simulated data)
# stays the same for a given seed.
base_field <- mgcv::rmvn(1, mu = rep(0, n_basis), V = Sigma)
# The loop index is `sp` rather than `t` so that base::t is not shadowed
for (sp in seq_len(N_species)) {
  species_cor <- cov2cor(random_Sigma(N = n_basis))
  corOmega <- 0.7 * base_cor + 0.3 * species_cor
  basis_coefs[sp, ] <- mgcv::rmvn(1, mu = rep(0, n_basis), V = corOmega)
}
# Simulate the latent spatial processes. For each species, the process at
# every location is:
#   intercept + temperature slope * temperature + spatial field,
# where the spatial field is the basis design matrix times that species'
# basis coefficients. The per-species data frames are stacked into one.
species_frames <- lapply(seq_len(N_species), function(sp) {
  fixed_part <- alphas[sp] + betas[sp] * temperature
  spatial_part <- des_mat %*% basis_coefs[sp, ]
  data.frame(lat = lat,
             lon = lon,
             species = paste0('species_', sp),
             temperature = temperature,
             process = fixed_part + spatial_part)
})
st_process <- do.call(rbind, species_frames)
# Now take noisy observations at a random subset of the points
N_obs <- 60
obs_idx <- sample.int(N_points, size = N_obs, replace = FALSE)
obs_points <- data.frame(lat = lat[obs_idx],
                         lon = lon[obs_idx],
                         site = seq_len(N_obs))
# Keep only the process data at these points. The coordinates are exact
# copies of the values used to build st_process, so joining on them is safe.
dat <- st_process %>%
  dplyr::inner_join(obs_points, by = c('lat', 'lon')) %>%
  # Take one noisy Poisson observation of the process per row; dplyr::n()
  # avoids relying on the magrittr `.` placeholder
  dplyr::mutate(count = rpois(dplyr::n(), lambda = exp(process)),
                species = factor(species,
                                 levels = paste0('species_',
                                                 seq_len(N_species)))) %>%
  # NOTE(review): dat is left grouped by location; confirm that downstream
  # code needs the grouping, otherwise add dplyr::ungroup()
  dplyr::group_by(lat, lon)
# View the count distributions for each species
library(ggplot2)
# bins = 30 matches ggplot2's default. Setting it explicitly avoids the
# stat_bin() message asking for a better bin choice.
ggplot(dat, aes(x = count)) +
  geom_histogram(bins = 30) +
  facet_wrap(~ species, scales = 'free')
# Map log(count + 1) at each observed site, one panel per species
ggplot(dat, aes(x = lon, y = lat, col = log(count + 1))) +
  geom_point(size = 2.25) +
  facet_wrap(~ species, scales = 'free') +
  scale_color_viridis_c() +
  theme_classic()
# Inspect the default priors for a joint species distribution model with
# three latent spatial factors, before fitting anything
priors <- get_mvgam_priors(formula = count ~
# Environmental model: species-level random slopes (bs = 're' with
# by = temperature) for a linear effect of temperature
s(species, bs = 're', by = temperature),
# Each factor estimates a different nonlinear spatial process, a Gaussian
# process over (lon, lat), using 'by = trend' as in other mvgam
# State-Space models; '- 1' removes the formula intercept
factor_formula = ~ gp(lon, lat, k = 6, by = trend) - 1,
# Number of latent factors
n_lv = 3,
# The data and grouping variables: 'site' identifies the sampling unit and
# 'species' identifies which species each row belongs to
data = dat,
unit = site,
species = species,
# Poisson observations
family = poisson())
# Show the first few entries of the prior table
head(priors)
# Fit a joint species distribution model (JSDM) that estimates hierarchical
# temperature responses and uses three latent spatial factors
mod <- jsdgam(formula = count ~
# Environmental model includes random slopes for a
# linear effect of temperature
s(species, bs = 're', by = temperature),
# Each factor estimates a different nonlinear spatial process, using
# 'by = trend' as in other mvgam State-Space models
factor_formula = ~ gp(lon, lat, k = 6, by = trend) - 1,
# Number of latent factors
n_lv = 3,
# Replace the default priors with standard normal priors for the random
# effect scale parameters (class sigma_raw) and for the GP marginal
# deviations (alpha) of each of the three latent factors. The backticks are
# required because these class names contain parentheses, commas and spaces.
priors = c(prior(std_normal(),
class = sigma_raw),
prior(std_normal(),
class = `alpha_gp_trend(lon, lat):trendtrend1`),
prior(std_normal(),
class = `alpha_gp_trend(lon, lat):trendtrend2`),
prior(std_normal(),
class = `alpha_gp_trend(lon, lat):trendtrend3`)),
# The data and the grouping variables
data = dat,
unit = site,
species = species,
# Poisson observations
family = poisson(),
# Number of MCMC chains; NOTE(review): silent = 2 presumably suppresses
# most sampler output — confirm against the jsdgam() documentation
chains = 2,
silent = 2)
# Species-level intercept estimates on the link (log) scale
plot_predictions(mod,
                 condition = 'species',
                 type = 'link')
# Species' hierarchical responses to temperature: temperature on the
# x-axis, with 'species' supplied again for the second and third
# condition slots (colour and facets)
plot_predictions(mod,
                 condition = c('temperature', 'species', 'species'),
                 type = 'link')
# Posterior estimates of the latent spatial factors (trend-model smooths)
plot(mod, type = 'smooths', trend_effects = TRUE)
# The same smooths drawn with gratia, when that package is installed
has_gratia <- requireNamespace('gratia', quietly = TRUE)
if (has_gratia) {
  gratia::draw(mod, trend_effects = TRUE)
}
# Residual spatial correlations among species
post_cors <- residual_cor(mod)
names(post_cors)
# Point estimates plus lower and upper credible-interval bounds for the
# top-left 5 x 5 corner of the correlation matrix
corner <- seq_len(5)
post_cors[["cor"]][corner, corner]
post_cors[["cor_upper"]][corner, corner]
post_cors[["cor_lower"]][corner, corner]
# A quick and dirty plot of the posterior median correlations
image(post_cors[["cor"]])
# Model checking: grouped PIT-ECDF posterior predictive checks and
# approximate leave-one-out cross-validation (ELPD-LOO)
pp_check(mod,
         type = "pit_ecdf_grouped",
         group = "species",
         ndraws = 100)
loo(mod)
# Predict log(counts) for the entire simulated region. The site value itself
# doesn't matter, as long as each distinct spatial location has its own
# unique site identifier; here each (lat, lon) group is numbered in turn.
# Note this calculation takes a few minutes because draws from the
# stochastic latent factors must be computed.
species_levels <- paste0('species_', seq_len(N_species))
newdata <- st_process %>%
  dplyr::mutate(species = factor(species, levels = species_levels)) %>%
  dplyr::group_by(lat, lon) %>%
  dplyr::mutate(site = dplyr::cur_group_id()) %>%
  dplyr::ungroup()
pred_summary <- predict(mod, newdata = newdata)
# Map the predicted log(count) at every simulated location, taken from the
# first column of the prediction summary
# NOTE(review): confirm whether predict() reports the posterior mean or the
# median in that column
newdata$log_count <- pred_summary[, 1]
ggplot(newdata, aes(x = lon, y = lat, col = log_count)) +
  geom_point(size = 1.5) +
  facet_wrap(~ species, scales = 'free') +
  scale_color_viridis_c() +
  theme_classic()
# }
# Run the code above in your browser using DataLab