data(alligator)
if (FALSE) {
library(nnet)
## nnet::multinom Multinomial logit model fit with neural nets
fit <- multinom(food ~ lake+size+sex, data=alligator, weights=count)
summary(fit$fitted.values)
## 1=bird, 2=fish, 3=invert, 4=other, 5=reptile
(L=length(alligator$count))
(N=sum(alligator$count))
y.train=integer(N)
x.train=matrix(nrow=N, ncol=3)
x.test=matrix(nrow=L, ncol=3)
k=1
for(i in 1:L) {
x.test[i, ]=as.integer(
c(alligator$lake[i], alligator$size[i], alligator$sex[i]))
if(alligator$count[i]>0)
for(j in 1:alligator$count[i]) {
y.train[k]=as.integer(alligator$food[i])
x.train[k, ]=as.integer(
c(alligator$lake[i], alligator$size[i], alligator$sex[i]))
k=k+1
}
}
table(y.train)
##test mbart with token run to ensure installation works
set.seed(99)
check = mbart(x.train, y.train, nskip=1, ndpost=1)
set.seed(99)
check = mbart(x.train, y.train, nskip=1, ndpost=1)
post=mbart(x.train, y.train, x.test)
##post=mc.mbart(x.train, y.train, x.test, mc.cores=8, seed=99)
##check=predict(post, x.test, mc.cores=8)
##print(cor(post$prob.test.mean, check$prob.test.mean)^2)
par(mfrow=c(3, 2))
K=5
for(j in 1:5) {
h=seq(j, L*K, K)
print(cor(fit$fitted.values[ , j], post$prob.test.mean[h])^2)
plot(fit$fitted.values[ , j], post$prob.test.mean[h],
xlim=0:1, ylim=0:1,
xlab=paste0('NN: Est. Prob. j=', j),
ylab=paste0('BART: Est. Prob. j=', j))
abline(a=0, b=1)
}
par(mfrow=c(1, 1))
L=16
x.test=matrix(nrow=L, ncol=3)
k=1
for(size in 1:2)
for(sex in 1:2)
for(lake in 1:4) {
x.test[k, ]=c(lake, size, sex)
k=k+1
}
x.test
## two sizes: 1=large: >2.3m, 2=small: <=2.3m
pred=predict(post, x.test)
##pred=predict(post, x.test, mc.cores=8)
ndpost=nrow(pred$prob.test)
size.test=matrix(nrow=ndpost, ncol=K*2)
for(i in 1:K) {
j=seq(i, L*K/2, K) ## large
size.test[ , i]=apply(pred$prob.test[ , j], 1, mean)
j=j+L*K/2 ## small
size.test[ , i+K]=apply(pred$prob.test[ , j], 1, mean)
}
size.test.mean=apply(size.test, 2, mean)
size.test.025=apply(size.test, 2, quantile, probs=0.025)
size.test.975=apply(size.test, 2, quantile, probs=0.975)
plot(factor(1:K, labels=c('bird', 'fish', 'invert', 'other', 'reptile')),
rep(1, K), col=1:K, type='n', lwd=1, lty=0,
xlim=c(1, K), ylim=c(0, 0.5), ylab='Prob.',
sub="Multinomial BART\nFriedman's partial dependence function")
points(1:K, size.test.mean[1:K+K], col=1)
lines(1:K, size.test.025[1:K+K], col=1, lty=2)
lines(1:K, size.test.975[1:K+K], col=1, lty=2)
points(1:K, size.test.mean[1:K], col=2)
lines(1:K, size.test.025[1:K], col=2, lty=2)
lines(1:K, size.test.975[1:K], col=2, lty=2)
## legend('topright', legend=c('Small', 'Large'),
## pch=1, col=1:2)
}
Run the code above in your browser using DataLab