# Simulate the design: n points evenly spaced on [0, 1] for x, plus a
# grouping factor z whose levels "a", "b", "c" are drawn uniformly with
# replacement. The seed is fixed so the example is reproducible.
set.seed(seed = 1)
n <- 100
x <- seq(from = 0, to = 1, length.out = n)
z <- factor(sample(c("a", "b", "c"), size = n, replace = TRUE))
# True mean function, used both to simulate the response and to draw the
# "truth" curves. Each group gets an offset mu (-2, 0, 2 for groups 1, 2, 3)
# that shifts both the level and the phase of the sine wave:
#   f(x, z) = mu[z] + 3 * x + sin(2 * pi * x + mu[z] * pi / 4)
#
# x: numeric vector of inputs.
# z: group indicator -- a factor (its integer codes are used) or integer
#    codes 1, 2, 3. It is combined with x by R's usual recycling rules.
# Returns the numeric vector of mean values.
fun <- function(x, z) {
  mu <- c(-2, 0, 2)
  zi <- as.integer(z)
  # Return the expression itself rather than assigning it to a local, so the
  # result is returned visibly (and prints when called at the console).
  mu[zi] + 3 * x + sin(2 * pi * x + mu[zi] * pi / 4)
}
# Noise-free mean at the design points, then the observed response with
# i.i.d. Gaussian noise (sd = 0.5).
fx <- fun(x = x, z = z)
y <- fx + rnorm(n = n, mean = 0, sd = 0.5)
# Marginal knots: the 0%, 10%, ..., 90% sample quantiles of x for the
# smooth term, and every level of z for the factor term.
probs <- seq(from = 0, to = 0.9, by = 0.1)
knots <- list(
  x = quantile(x, probs = probs),
  z = c("a", "b", "c")
)
# Fit a smoothing model with both main effects and the x:z interaction,
# using the knots list built from x and z.
smod <- sm(y ~ x * z, knots = knots)
# Response-scale predictions at the training data; print their mean squared
# difference from the fitted values stored on the model.
fit <- predict(smod)
mean((fit - smod$fitted.values)^2)
# Term-wise predictions: one column per model term, with a constant offset
# stored separately in the "constant" attribute.
trm <- predict(smod, type = "terms")
attr(trm, which = "constant")
head(trm)
# Summing the term columns and adding the constant should reconstruct the
# fitted values; print the mean squared discrepancy.
mean((smod$fitted.values - rowSums(trm) - attr(trm, "constant"))^2)
# Predicting at newdata equal to the training data should reproduce the
# fitted values.
fit <- predict(smod, newdata = data.frame(x = x, z = z))
mean((smod$fitted.values - fit)^2)
# With se.fit = TRUE the result carries the predictions ($fit) and their
# standard errors ($se.fit); compare both with the values stored on the
# model.
fit <- predict(smod, se.fit = TRUE)
mean((smod$fitted.values - fit$fit)^2)
mean((smod$se.fit - fit$se.fit)^2)
# 99% confidence interval for the mean response. The interval type is
# spelled out in full instead of abbreviated ("c"), so the call does not
# rely on partial matching of the argument value.
fit <- predict(smod, interval = "confidence", level = 0.99)
head(fit)
# 99% prediction interval for a new observation (full name instead of "p"
# for the same reason).
fit <- predict(smod, interval = "prediction", level = 0.99)
head(fit)
# Isolate the x main effect: when only the "x" term is requested, newdata
# needs just the x column. Plot the estimate with its interval band.
fit <- predict(smod, terms = "x", se.fit = TRUE,
               newdata = data.frame(x = x))
plotci(x = x, y = fit$fit, se = fit$se.fit)
# Full-model predictions (with standard errors) along x, holding the group
# fixed at each level of z in turn.
fit.a <- predict(smod, se.fit = TRUE,
                 newdata = data.frame(x = x, z = "a"))
fit.b <- predict(smod, se.fit = TRUE,
                 newdata = data.frame(x = x, z = "b"))
fit.c <- predict(smod, se.fit = TRUE,
                 newdata = data.frame(x = x, z = "c"))
# Overlay each group's estimate and interval band on its true mean curve
# (dashed). The first plotci() call opens the plot and fixes the y-range;
# later calls add to it. Group codes 1, 2, 3 correspond to "a", "b", "c".
plotci(x = x, y = fit.a$fit, se = fit.a$se.fit, ylim = c(-6, 6),
       col = "red", col.ci = "pink")
lines(x, fun(x, rep(1, times = n)), col = "red", lty = 2)
plotci(x = x, y = fit.b$fit, se = fit.b$se.fit, add = TRUE,
       col = "blue", col.ci = "cyan")
lines(x, fun(x, rep(2, times = n)), col = "blue", lty = 2)
plotci(x = x, y = fit.c$fit, se = fit.c$se.fit, add = TRUE,
       col = "darkgreen", col.ci = "lightgreen")
lines(x, fun(x, rep(3, times = n)), col = "darkgreen", lty = 2)
# Legends: the line key on the left (dashed truth, solid estimate, filled
# square for the CI) and the group colours on the right. bty = "n" omits
# the legend box.
legend(x = "bottomleft",
       legend = c("Truth", "Estimate", "CI"),
       col = c("black", "black", "gray80"),
       lty = c(2, 1, NA),
       lwd = c(1, 2, NA),
       pch = c(NA, NA, 15),
       pt.cex = 2,
       bty = "n")
legend(x = "bottomright",
       legend = c("a", "b", "c"),
       col = c("red", "blue", "darkgreen"),
       lwd = 2,
       bty = "n")
# Run the code above in your browser using DataLab