## Lecture Slides 9 & Bayes Rules! Chapters 7 & 6 R Examples 
## Enhanced for Efficiency

# Load packages
library(bayesrules)
library(tidyverse)
library(mvtnorm)      # For multivariate normal sampling
library(matrixStats)   # For efficient matrix operations
library(coda)          # For autocorrelation and diagnostics
library(MCMCpack)     # For HPD interval calculation
library(TeachingDemos) #for HPD interval (empirical)
library(MASS)  # For generating true bivariate normal samples using mvrnorm()
library(purrr) # For more efficient iteration

#Set Working Directory to Source File Location
# Relative data paths below (e.g. "../STAT7630-Data/...") are resolved from the
# directory of this source file. rstudioapi can only report that location when
# running inside RStudio; getActiveDocumentContext() errors elsewhere (e.g. under
# Rscript), so skip with a message and keep the current working directory.
library("rstudioapi")  # Load rstudioapi package
if (rstudioapi::isAvailable()) {
  setwd(dirname(getActiveDocumentContext()$path)) # Set working directory to source file location
} else {
  message("RStudio not detected; keeping working directory: ", getwd())
}
#getwd() 

# Format the first two elements of an interval (confidence or credible) as the
# string "(lower,upper)", each bound rounded to `r` decimal places.
#   interval: numeric vector; only elements 1 and 2 are used
#   r:        number of decimal places (default 3)
print_int <- function(interval, r = 3) {
  bounds <- round(interval, digits = r)
  sprintf("(%s,%s)", bounds[1], bounds[2])
}

####################################################
##### Metropolis-Hastings Algorithm - Example 0
####################################################

# Set seed for reproducibility
set.seed(123)

# Define the target distribution: Standard normal
tar_dens <- function(theta) dnorm(theta, mean = 0, sd = 1)

# Random-walk Metropolis-Hastings sampler with a uniform proposal.
#
# Arguments:
#   N          number of draws to return, including the initial value (N >= 1)
#   pro_width  half-width w of the proposal: theta* ~ Uniform(theta - w, theta + w)
#   init_theta starting value of the chain (returned as the first draw)
#   target     function giving the (possibly unnormalized) target density;
#              defaults to tar_dens (standard normal)
# Returns: numeric vector of length N containing the chain.
#
# The uniform proposal is symmetric, so the Hastings correction cancels and the
# acceptance ratio is simply target(proposal) / target(current).
metropolis_hastings <- function(N, pro_width = 2, init_theta = 0,
                                target = tar_dens) {
  stopifnot(length(N) == 1, N >= 1)
  samples <- numeric(N)  # Pre-allocate for efficiency
  samples[1] <- init_theta  # Initial value
  theta_curr <- init_theta

  # seq_len(N)[-1] is 2..N and is empty when N = 1 (2:N would count down to 1)
  for (i in seq_len(N)[-1]) {
    # Propose a new value from Uniform(theta - pro_width, theta + pro_width)
    theta_pro <- runif(1, min = theta_curr - pro_width, max = theta_curr + pro_width)

    # Calculate acceptance ratio
    acc_ratio <- target(theta_pro) / target(theta_curr)

    # Accept with probability min(1, acc_ratio); otherwise repeat the current value
    if (runif(1) < acc_ratio) {
      theta_curr <- theta_pro
    }

    samples[i] <- theta_curr
  }
  samples
}

# Generate samples using M-H algorithm
# (default proposal half-width 2, chain started at 0)
N <- 10000
sam_mh <- metropolis_hastings(N)

# Generate samples from true standard normal for comparison
# (drawn after the M-H chain, so these depend on the RNG state it left behind)
rnorm_sam <- rnorm(N)

# Plot the density of M-H samples and true standard normal samples
# Long-format data frame: one row per draw; `source` records which sampler produced it.
# NOTE(review): the name `df` masks stats::df (F density) and is later reused
# for a helper function in the Gibbs section.
df <- data.frame(
  value = c(sam_mh, rnorm_sam),
  source = rep(c("M-H Samples", "rnorm Samples"), each = N)
)

ggplot(df, aes(x = value, color = source)) +
  geom_density(linewidth=2) +
  labs(title = "Density Comparison of M-H and rnorm Samples",
       x = "Value", y = "Density") +
  theme_minimal()

# Calculate and print acceptance rate
# A rejected proposal repeats the previous value, so the fraction of non-zero
# successive differences estimates the acceptance rate (proposals are continuous,
# so an accepted move landing exactly on the current value has probability zero).
acc_rate <- mean(diff(sam_mh) != 0)
cat("Acceptance rate:", round(acc_rate, 4), "\n")

# Autocorrelation plot for M-H samples
acf(sam_mh, xlab = "Lag", ylab = "Autocorrelation",
    main = "Autocorrelation Plot for M-H Samples")

# Thinning the chain by taking every 3rd sample (indices 1, 4, 7, ...)
sam_thin <- sam_mh[seq(1, N, by = 3)]

# Posterior summary: median and 95% equal-tailed (quantile-based) credible interval
post_med <- median(sam_thin)
cred_int <- quantile(sam_thin, probs = c(0.025, 0.975))
cat("Posterior median:", post_med, "\n")
cat("95% credible interval:", cred_int, "\n")

# Trace plot of the thinned samples
plot(sam_thin, type = 'l', xlab = "Iteration", ylab = "Sampled Values",
     main = "Trace Plot of Metropolis-Hastings Samples")

# Geweke diagnostic for checking convergence
# (compares the mean of an early window of the chain with a late window)
gew_res <- geweke.diag(sam_thin)

# Geweke diagnostic z-statistic and p-value
# gew_res[[1]] is the `z` element of the geweke.diag() result.
# Two-sided p-value; a small value suggests the chain has not converged.
z_stat <- gew_res[[1]]
p_val <- 2 * pnorm(abs(z_stat), lower.tail = FALSE)
cat("Geweke diagnostic z-statistic:", z_stat, "\n")
cat("p-value for Geweke diagnostic:", p_val, "\n")

####################################################
##### Metropolis-Hastings Algorithm - Example 1, (Sparrow data)
####################################################

# Read the sparrow data
# Path is relative to the working directory (set to this file's folder at the top).
sparrow <- read.table("../STAT7630-Data/sparrowdata.txt", header = TRUE)

# Extract the data
# y: response counts (modeled as Poisson below); x: covariate used as "age" in
# the prediction plot. NOTE(review): column order assumed from usage — confirm.
y <- sparrow[, 1]
x <- sparrow[, 2]
xsq <- x^2

# Construct the design matrix: intercept, x, x^2 (quadratic log-linear model)
X <- cbind(1, x, xsq)

# Prior parameters: independent N(0, 10^2) priors on each coefficient
p <- ncol(X)
beta_pri_mean <- rep(0, p)
beta_pri_sd <- rep(10, p)

# Proposal covariance matrix
# Heuristic scaling: (X'X)^{-1} times the sample variance of log(y + 1/2)
# (the +1/2 avoids log(0) for zero counts).
prop_cov_mat <- var(log(y + 1/2)) * solve(crossprod(X))

# MCMC settings
N <- 10000

# Initialize variables
beta_cur <- rep(0, p)      # chain starts at beta = 0
acs <- 0  # acceptance counter
beta_vals <- matrix(NA, nrow = N, ncol = p)  # one row per iteration

# Unnormalized log posterior for Poisson log-linear regression.
#   beta:      coefficient vector (or p x 1 matrix)
#   X:         design matrix (n x p)
#   y:         observed counts (length n)
#   beta_mean, beta_sd: means and SDs of independent normal priors on beta
# Returns: sum of Poisson log-likelihood terms plus normal log-prior terms.
log_post <- function(beta, X, y, beta_mean, beta_sd) {
  rate <- exp(X %*% beta)  # Poisson means from the log-linear predictor
  log_lik_terms <- dpois(y, rate, log = TRUE)
  log_pri_terms <- dnorm(beta, mean = beta_mean, sd = beta_sd, log = TRUE)
  sum(log_lik_terms) + sum(log_pri_terms)
}

# Metropolis-Hastings algorithm (random-walk proposal)
# Proposal: beta* ~ MVN(beta_cur, prop_cov_mat). It is symmetric, so the
# acceptance ratio is the posterior ratio, computed on the log scale.
# The current log posterior is cached and only recomputed on acceptance,
# halving the number of likelihood evaluations without changing any result.
log_post_cur <- log_post(beta_cur, X, y, beta_pri_mean, beta_pri_sd)
for (i in seq_len(N)) {
  # drop() turns the 1 x p draw into a plain vector so beta_cur stays a vector
  beta_pro <- drop(rmvnorm(1, mean = beta_cur, sigma = prop_cov_mat))
  log_post_pro <- log_post(beta_pro, X, y, beta_pri_mean, beta_pri_sd)

  log_acc_ratio <- log_post_pro - log_post_cur

  if (log(runif(1)) < log_acc_ratio) {
    beta_cur <- beta_pro
    log_post_cur <- log_post_pro
    acs <- acs + 1
  }

  beta_vals[i, ] <- beta_cur
}

# Acceptance rate
ac_rate <- acs / N
cat("Acceptance rate:", ac_rate, "\n")

# plot autocorrelation values for beta_0 (unthinned chain)
acf(beta_vals[,1],xlab="lag",ylab="acf values", 
    main=expression(paste("autocorrelation plot for ", beta[0])))  
# plot autocorrelation values for beta_1
acf(beta_vals[,2],xlab="lag",ylab="acf values", 
    main=expression(paste("autocorrelation plot for ", beta[1])))  
# plot autocorrelation values for beta_2
acf(beta_vals[,3],xlab="lag",ylab="acf values", 
    main=expression(paste("autocorrelation plot for ", beta[2])))  

# The ACF plots show strong serial dependence, so thin the chain by keeping
# every 10th row (rows 1, 11, 21, ...).
# NOTE(review): no burn-in is discarded before thinning even though the chain
# starts at beta = 0; consider dropping early draws.
thin <- 10
beta_vals_thin <- beta_vals[seq(1, N, by = thin), ]

# Posterior summary: Posterior medians and 95% credible intervals

# Efficiently compute medians using colMedians from matrixStats
post_meds <- colMedians(beta_vals_thin)

# 95% quantile-based intervals (2 x p matrix: row 1 = 2.5%, row 2 = 97.5%)
quant_ints <- apply(beta_vals_thin, 2, quantile, probs = c(0.025, 0.975))

# HPD intervals (empirical, via TeachingDemos::emp.hpd; one column per coefficient)
hpd_ints <- apply(beta_vals_thin, 2, function(beta_col) emp.hpd(beta_col, conf = 0.95))

# Display the results
cat("Posterior Medians:\n", post_meds, "\n")

# Column i corresponds to coefficient beta_{i-1}
cat("95% Quantile-based intervals:\n")
for (i in 1:ncol(beta_vals_thin)) {
  cat("beta", i-1, ": ", print_int(quant_ints[, i]), "\n", sep="")
}

cat("95% HPD intervals:\n")
for (i in 1:ncol(beta_vals_thin)) {
  cat("beta", i-1, ": ", print_int(hpd_ints[, i]), "\n", sep="")
}

# Plot posterior median for expected offspring for ages 1 to 6
# Design rows for ages 1..6: (intercept, age, age^2), matching the columns of X
myX <- cbind(1, 1:6, (1:6)^2)

# Posterior median-based predictions
# NOTE(review): this is a plug-in estimate exp(x' * median(beta)); it is not in
# general the posterior median of exp(x' * beta) itself.
yhat <- exp(myX %*% post_meds)

# Plot of predicted offspring based on the model
plot(1:6, yhat, type = 'b', xlab = 'age', ylab = 'expected # of offspring', 
     main = "Posterior median for expected # of offspring")

# Trace plots for the thinned beta_0, beta_1 and beta_2 draws
# Three stacked panels; the previous layout is restored afterwards.
old_par <- par(mfrow = c(3, 1))
plot(beta_vals_thin[, 1], type = 'l', 
     main = expression(paste("Trace Plot for ",beta[0])), ylab = expression(beta[0]))
plot(beta_vals_thin[, 2], type = 'l', 
     main = expression(paste("Trace Plot for ",beta[1])), ylab = expression(beta[1]))
plot(beta_vals_thin[, 3], type = 'l', 
     main = expression(paste("Trace Plot for ",beta[2])), ylab = expression(beta[2]))
par(old_par)

# Geweke diagnostic to check convergence of each coefficient's thinned chain.
# geweke.diag() returns a list whose `z` element holds the z-score comparing
# the means of an early and a late window of the chain. Its other elements
# (frac1, frac2) are the window fractions, not z-scores, so the z-score is
# extracted by name. Each column is diagnosed separately.
# Two-sided p-value: a small value suggests the chain has not converged.
for (j in seq_len(ncol(beta_vals_thin))) {
  gew_res <- geweke.diag(beta_vals_thin[, j])
  gew_stat <- gew_res$z
  gew_p_val <- 2 * pnorm(abs(gew_stat), lower.tail = FALSE)
  cat(paste0("Geweke diagnostic z-statistic for beta_", j - 1, ":\n"), gew_stat, "\n")
  cat(paste0("Geweke p-value for beta_", j - 1, " convergence check:\n"), gew_p_val, "\n")
}

####################################################
##### Gibbs Sampling Algorithm - Bivariate Normal example
####################################################
# Parameters for the bivariate normal distribution
params <- list(
  mu_x = 0,    # Mean of X
  mu_y = 0,    # Mean of Y
  sigx = 1,    # Standard deviation of X
  sigy = 1,    # Standard deviation of Y
  rho = 0.8,   # Correlation between X and Y
  N = 10000    # Number of samples
)

# Conditional standard deviations SD(X | Y) and SD(Y | X) under `params`
sigx_y <- function() params$sigx * sqrt(1 - params$rho^2)
sigy_x <- function() params$sigy * sqrt(1 - params$rho^2)

# Gibbs sampler for a bivariate normal: alternately draws
#   X | Y = y ~ N(mu_x + rho * (sigx / sigy) * (y - mu_y), sigx^2 (1 - rho^2))
#   Y | X = x ~ N(mu_y + rho * (sigy / sigx) * (x - mu_x), sigy^2 (1 - rho^2))
# Arguments:
#   N              number of draws to return, including the initial values (N >= 1)
#   init_x, init_y starting values (returned as the first draws)
#   par            list with mu_x, mu_y, sigx, sigy, rho (defaults to `params`)
# Returns: list(X = numeric vector of length N, Y = numeric vector of length N).
gibbs_sampler <- function(N, init_x = 0, init_y = 0, par = params) {
  stopifnot(length(N) == 1, N >= 1)
  # Conditional SDs do not change between iterations, so compute them once
  sd_x_given_y <- par$sigx * sqrt(1 - par$rho^2)
  sd_y_given_x <- par$sigy * sqrt(1 - par$rho^2)

  sam_x <- numeric(N)
  sam_y <- numeric(N)
  sam_x[1] <- init_x
  sam_y[1] <- init_y

  # seq_len(N)[-1] is 2..N and is empty when N = 1 (2:N would count down to 1)
  for (t in seq_len(N)[-1]) {
    # Sample X given the previous Y
    meanx_y <- par$mu_x + par$rho * (sam_y[t - 1] - par$mu_y) * (par$sigx / par$sigy)
    sam_x[t] <- rnorm(1, mean = meanx_y, sd = sd_x_given_y)

    # Sample Y given the X just drawn
    meany_x <- par$mu_y + par$rho * (sam_x[t] - par$mu_x) * (par$sigy / par$sigx)
    sam_y[t] <- rnorm(1, mean = meany_x, sd = sd_y_given_x)
  }

  list(X = sam_x, Y = sam_y)
}

# Run Gibbs Sampling
set.seed(123)
gibbs_sam <- gibbs_sampler(params$N)

# Generate true bivariate normal samples for comparison
# Covariance matrix [[sigx^2, rho*sigx*sigy], [rho*sigx*sigy, sigy^2]];
# MASS::mvrnorm draws params$N independent rows (column 1 = X, column 2 = Y).
Sigma <- matrix(c(params$sigx^2, params$rho * params$sigx * params$sigy, 
                  params$rho * params$sigx * params$sigy, params$sigy^2), nrow = 2)
true_sam <- mvrnorm(params$N, mu = c(params$mu_x, params$mu_y), Sigma = Sigma)

# Combine Gibbs samples and true samples into a single long-format data frame.
#   sam_gibbs: draws from the Gibbs sampler
#   sam_true:  reference draws from the exact distribution
#   label:     variable name stored in the `variable` column
# Returns: data.frame(value, source, variable). `source` is labelled according
# to each input's own length, so the two inputs need not have equal lengths.
# NOTE(review): the name `df` masks stats::df (the F density) in this script.
df <- function(sam_gibbs, sam_true, label) {
  data.frame(
    value = c(sam_gibbs, sam_true),
    source = rep(c("Gibbs Samples", "True Samples"),
                 times = c(length(sam_gibbs), length(sam_true))),
    variable = label
  )
}

# Long-format frames pairing the Gibbs draws with the exact draws, per margin
df_x <- df(gibbs_sam$X, true_sam[, 1], "X")
df_y <- df(gibbs_sam$Y, true_sam[, 2], "Y")

# Overlay density curves of Gibbs vs. true draws for one variable.
#   df:       long-format frame with columns value, source, variable
#   variable: which variable's rows to plot (also used in the title and x label)
# Returns: a ggplot object.
plot_dens <- function(df, variable) {
  rows_for_var <- df$variable == variable
  plot_title <- paste("Density Comparison for", variable, ": Gibbs vs True Bivariate Normal")
  ggplot(df[rows_for_var, ], aes(x = value, color = source)) +
    geom_density() +
    labs(title = plot_title, x = variable, y = "Density") +
    theme_minimal()
}

# Density plots for X and Y (Gibbs vs. exact bivariate normal draws)
plot_dens(df_x, "X")
plot_dens(df_y, "Y")

# Calculate autocorrelation for Gibbs samples (for both X and Y)
# (unthinned chains, so the serial dependence is visible)
acf(gibbs_sam$X, xlab = "Lag", ylab = "ACF", 
    main = "Autocorrelation Plot for X (Gibbs Samples)")
acf(gibbs_sam$Y, xlab = "Lag", ylab = "ACF", 
    main = "Autocorrelation Plot for Y (Gibbs Samples)")

# Thin a chain by keeping every `by`-th draw, starting with the first
# (indices 1, 1 + by, 1 + 2*by, ...). The range is based on the chain's own
# length, so any length works (an empty chain returns an empty vector).
#   samples: vector of draws
#   by:      thinning interval (default 3)
thin_sam <- function(samples, by = 3) {
  samples[seq(1, by = by, length.out = ceiling(length(samples) / by))]
}

# Keep every 3rd Gibbs draw for each margin
sam_x_thin <- thin_sam(gibbs_sam$X)
sam_y_thin <- thin_sam(gibbs_sam$Y)

# Print the median and 95% equal-tailed (2.5%/97.5% quantile) interval of a
# vector of draws, labelled with `variable`. Called for its printed output.
post_summary <- function(sam_thin, variable) {
  ci_bounds <- quantile(sam_thin, probs = c(0.025, 0.975))
  center <- median(sam_thin)
  cat("Posterior median for", variable, ":", center, "\n")
  cat("95% credible interval for", variable, ":", ci_bounds, "\n")
}

# Posterior summaries for X and Y (from the thinned chains)
post_summary(sam_x_thin, "X")
post_summary(sam_y_thin, "Y")

# Base-graphics line trace plot of draws against their index.
#   sam_thin: numeric vector of (thinned) draws
#   variable: label inserted into the y-axis label and the title
plot_trace <- function(sam_thin, variable) {
  y_label <- paste0("Sampled values (", variable, ")")
  main_title <- paste("Trace Plot of Gibbs Samples for", variable)
  plot(sam_thin, type = 'l', xlab = "Iteration", ylab = y_label, main = main_title)
}

# Trace plots for X and Y (thinned chains)
plot_trace(sam_x_thin, "X")
plot_trace(sam_y_thin, "Y")


####################################################
##### Gibbs Sampling Algorithm - (Flu shot data)
####################################################

# Data for the observed 19 individuals (1 = success, 0 = failure);
# the 20th individual's outcome X20 is missing and is imputed by the sampler.
flu_data <- c(1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1)
y_obs <- sum(flu_data)  # number of observed successes
tot_draws <- 10000      # Number of MCMC draws

# Starting values for the missing outcome and the success probability
X20_init <- 1
theta_init <- 0.5

# Chains, with the starting values in position 1 and zeros elsewhere
X20_vec <- c(X20_init, numeric(tot_draws - 1))
theta_vec <- c(theta_init, numeric(tot_draws - 1))

# Two-block Gibbs sampler over (theta, X20)
for (j in seq_len(tot_draws)[-1]) {
  # theta | X20, data ~ Beta(successes + 1, failures + 1) among all 20 people
  # (the +1 terms correspond to a uniform Beta(1, 1) prior)
  theta_vec[j] <- rbeta(1,
                        y_obs + X20_vec[j - 1] + 1,
                        20 - y_obs - X20_vec[j - 1] + 1)
  # X20 | theta ~ Bernoulli(theta)
  X20_vec[j] <- rbinom(1, size = 1, prob = theta_vec[j])
}

# Burn-in removal: remove first 2000 draws
burn_in <- 2000
theta_post <- theta_vec[-(1:burn_in)]
X20_post <- X20_vec[-(1:burn_in)]

## Posterior summary for theta

# Plot posterior density of theta (kernel density estimate of the post-burn-in draws)
plot(density(theta_post), main = "Posterior Density",
     xlab = expression(theta), ylab = "density")

# Posterior mean and median
theta_mean <- mean(theta_post)
theta_med <- median(theta_post)

# 95% quantile-based credible interval
theta_quant_int <- quantile(theta_post, probs = c(0.025, 0.975))

# 95% HPD interval (empirical, from the draws)
theta_hpd_int <- emp.hpd(theta_post, conf = 0.95)

# Display posterior summaries for theta
cat("Posterior Mean of Theta:", theta_mean, "\n")
cat("Posterior Median of Theta:", theta_med, "\n")
cat("95% Quantile-based interval for Theta: ", print_int(theta_quant_int), "\n",sep="")
cat("95% HPD interval for Theta: ", print_int(theta_hpd_int), "\n",sep="")

# Observed data MLE for theta: proportion of successes among the 19 observed
mle_theta <- mean(flu_data)
cat("Observed-data MLE for Theta:", mle_theta, "\n")

## Posterior summary for the missing data value X20

# Posterior mean and variance for X20
# X20 is 0/1, so its posterior mean estimates P(X20 = 1 | data)
X20_mean <- mean(X20_post)
X20_var <- var(X20_post)

# Display posterior summaries for X20
cat("Posterior Mean of X20:", X20_mean, "\n")
cat("Posterior Variance of X20:", X20_var, "\n")

####################################################
##### Gibbs Sampling Algorithm - (Coal mining data)
####################################################
# Data for the disasters (counts in time order; the change point k indexes this vector)
disasters <- c(4,5,4,1,0,4,3,4,0,6,3,3,4,0,2,6,3,3,5,
               4,5,3,1,4,4,1,5,5,3,4,2,5,2,2,3,4,2,1,
               3,2,2,1,1,1,1,3,0,0,1,0,1,1,0,0,3,1,0,
               3,2,2,0,1,1,1,0,1,0,1,0,0,0,2,1,0,0,0,
               1,1,0,2,3,3,1,1,2,1,1,1,1,2,4,2,0,0,0,
               1,4,0,0,0,1,0,0,0,0,0,1,0,0,1,0,1)

# Hyperparameters for the prior distributions (shape/rate, as used by bcp())
alpha_pri <- 4   # lam ~ Gamma(alpha_pri, beta_pri): rate before the change point
beta_pri <- 1
gam_pri <- 1     # phi ~ Gamma(gam_pri, del_pri): rate after the change point
del_pri <- 2

# Gibbs sampler for the Bayesian change-point (bcp) Poisson model:
#   y_i ~ Poisson(lam) for i <= k,   y_i ~ Poisson(phi) for i > k,
#   lam ~ Gamma(a, b), phi ~ Gamma(g, d) (shape/rate), k ~ Uniform{1, ..., n}.
# Full conditionals:
#   lam | k, y ~ Gamma(a + sum(y[1:k]),     b + k)
#   phi | k, y ~ Gamma(g + sum(y[(k+1):n]), d + n - k)
#   P(k = j | lam, phi, y) proportional to exp(j * (phi - lam)) * (lam / phi)^sum(y[1:j])
#
# Arguments:
#   theta_mat  matrix with columns (lam, phi, k); row 1 holds the starting
#              values and rows 2..nrow are overwritten with the draws
#   y          count vector
#   a, b, g, d prior hyperparameters as above
# Returns: theta_mat with every row after the first filled in.
bcp <- function(theta_mat, y, a, b, g, d) {
  n <- length(y)
  cum_y <- cumsum(y)   # cum_y[j] = sum(y[1:j]); loop-invariant
  tot_y <- cum_y[n]
  j_all <- seq_len(n)

  for (i in seq_len(nrow(theta_mat))[-1]) {
    k_prev <- theta_mat[i - 1, 3]

    # Sampling lam from its gamma full conditional (observations 1..k)
    lam <- rgamma(1, a + cum_y[k_prev], b + k_prev)

    # Sampling phi from its gamma full conditional: only observations strictly
    # after the change point (k+1..n) count; the sum is 0 when k = n
    phi <- rgamma(1, g + tot_y - cum_y[k_prev], d + n - k_prev)

    # Log-scale probabilities for k, shifted by their max to avoid overflow;
    # sample() normalizes the weights itself
    log_k_prob <- j_all * (phi - lam) + cum_y * log(lam / phi)
    k_prob <- exp(log_k_prob - max(log_k_prob))

    # Sample k from the computed probabilities
    k <- sample(j_all, size = 1, prob = k_prob)

    # Store the sampled parameters for this iteration
    theta_mat[i, ] <- c(lam, phi, k)
  }
  theta_mat
}

# Total number of MCMC draws
tot_draws <- 2000

# Initial parameter values (lam, phi, k)
# lam and phi start at their prior means (alpha_pri/beta_pri = 4, gam_pri/del_pri = 0.5);
# k = 55 starts the change point partway through the series.
init_par_vals <- c(4, 0.5, 55)  # Using prior means as initial values
init_theta_mat <- matrix(0, nrow = tot_draws, ncol = 3)
# Prepend the starting row: the matrix has tot_draws + 1 rows in total
init_theta_mat <- rbind(init_par_vals, init_theta_mat)

# Run the Gibbs sampler (overwrites the earlier bivariate-normal `gibbs_sam`)
gibbs_sam <- bcp(init_theta_mat, y = disasters, a = alpha_pri, 
                        b = beta_pri, g = gam_pri, d = del_pri)

# Remove the first 1000 iterations as burn-in
gibbs_post <- gibbs_sam[-(1:1000), ]

# Posterior summaries for lam, phi, and k
# post_quants and post_hpd are 3 x 2 matrices: one row per parameter, (lower, upper)
post_meds <- apply(gibbs_post, 2, median)  # Posterior medians
post_quants <- cbind(apply(gibbs_post, 2, quantile, probs = 0.025), apply(gibbs_post, 2, quantile, probs = 0.975))  # 95% quantile intervals
post_hpd <- rbind(emp.hpd(gibbs_post[, 1], conf = 0.95), 
                       emp.hpd(gibbs_post[, 2], conf = 0.95), 
                       emp.hpd(gibbs_post[, 3], conf = 0.95))  # 95% HPD intervals

# Display posterior summaries
cat("Posterior Medians:\n", post_meds, "\n")

# Parameter labels in the column order of gibbs_post: (lambda, phi, k)
param <- c("lambda", "phi", "k")

# 95% Quantile-based intervals (row i of post_quants = parameter i)
cat("95% Quantile-based intervals:\n")
for (i in seq_len(nrow(post_quants))) {
  cat(param[i], ": ", print_int(post_quants[i, ]), "\n", sep = "")
}

# 95% HPD intervals (row i of post_hpd = parameter i)
cat("95% HPD intervals:\n")
for (i in seq_len(nrow(post_hpd))) {
  cat(param[i], ": ", print_int(post_hpd[i, ]), "\n", sep = "")
}

########################################################
########################################################
# R Code for Chapter 7 Content
########################################################
########################################################

### Monte Carlo Algorithm Example
# Direct (independent) Monte Carlo: 5000 draws from the N(4, 0.6^2) posterior,
# with the exact density overlaid on the histogram.
# NOTE(review): other examples use seed 84735; 84375 may be a transposition — confirm.
set.seed(84375)
mc_tour <- data.frame(mu = rnorm(5000, mean = 4, sd = 0.6))
ggplot(mc_tour, aes(x = mu)) + 
  geom_histogram(aes(y = after_stat(density)), color = "white", bins = 15) + 
  stat_function(fun = dnorm, args = list(4, 0.6), color = "blue") +
  labs(x=expression(mu))

### Metropolis-Hastings Algorithm Example
# Model: mu ~ N(0, 1) prior; one observation y = 6.25 with y | mu ~ N(mu, 0.75^2).
# Step 1: Proposal for the next tour stop, uniform within +/- 1 of the current stop
set.seed(8)
current=3
proposal <- runif(1, min = current - 1, max = current + 1)
proposal

# Posterior plausibility calculations: prior density x likelihood (unnormalized)
pro_plaus <- dnorm(proposal, 0, 1) * dnorm(6.25, proposal, 0.75)
cur_plaus  <- dnorm(current, 0, 1) * dnorm(6.25, current, 0.75)
# Acceptance probability (symmetric proposal, so no Hastings correction)
alpha <- min(1, pro_plaus / cur_plaus)
alpha

# Coin flip to accept or reject the proposal: move with probability alpha
next_stop <- sample(c(proposal, current), size = 1, prob = c(alpha, 1-alpha))
next_stop

### Function for One Metropolis-Hastings Iteration

# One Metropolis-Hastings step for the Normal-Normal example
# (mu ~ N(0, 1) prior, single observation 6.25 ~ N(mu, 0.75^2)).
#   w:       half-width of the uniform random-walk proposal
#   current: current chain location
# Returns: one-row data frame with columns proposal, alpha, next_stop.
one_MH_iter <- function(w, current){
  # Unnormalized posterior density: prior times likelihood
  plausibility <- function(mu) dnorm(mu, 0, 1) * dnorm(6.25, mu, 0.75)

  # Propose uniformly within +/- w of the current location
  proposal <- runif(1, min = current - w, max = current + w)

  # Metropolis acceptance probability, then a weighted coin flip
  alpha <- min(1, plausibility(proposal) / plausibility(current))
  next_stop <- sample(c(proposal, current), size = 1, prob = c(alpha, 1-alpha))

  data.frame(proposal = proposal, alpha = alpha, next_stop = next_stop)
}

### Metropolis-Hastings Simulation for N Iterations

# Run an N-step Metropolis-Hastings chain with uniform proposal half-width w,
# delegating each step to one_MH_iter().
#   N:    number of steps (N = 0 gives an empty result)
#   w:    proposal half-width
#   init: starting location (defaults to 3, as in the textbook example)
# Returns: data frame with columns `iteration` (1..N) and `mu` (location after each step).
MH_tour <- function(N, w, init = 3){
  # 1. Start the chain at location `init`
  current <- init

  # 2. Preallocate storage for the chain
  mu <- numeric(N)

  # 3. Simulate N Markov chain stops (seq_len() avoids 1:0 when N = 0)
  for (i in seq_len(N)) {
    # Simulate one iteration
    sim <- one_MH_iter(w = w, current = current)

    # Record the next location and move there
    mu[i] <- sim$next_stop
    current <- sim$next_stop
  }

  # 4. Return the chain locations
  data.frame(iteration = seq_len(N), mu)
}

# Example call: 5000 steps with proposal half-width 1
set.seed(84735)
MH_sim_1 <- MH_tour(N = 5000, w = 1)

# Plot the results
ggplot(MH_sim_1, aes(x = iteration, y = mu)) + geom_line() +
  labs(y=expression(mu), title="Trace plot of the Markov Chain")

# Histogram of the chain with the exact N(4, 0.6^2) posterior overlaid
ggplot(MH_sim_1, aes(x = mu)) + 
  geom_histogram(aes(y = after_stat(density)), color = "white", bins = 20) + 
  stat_function(fun = dnorm, args = list(4,0.6), color = "blue") +
  xlab(expression(mu))

################################################# 
### Metropolis-Hastings for Beta-Binomial Model
#################################################
# Function for one iteration of the Beta-Binomial model
# One independence-sampler Metropolis-Hastings step for the Beta-Binomial model
# (pi ~ Beta(2, 3) prior, observing 1 success in 2 trials).
#   a, b:    shape parameters of the Beta(a, b) proposal, drawn independently of `current`
#   current: current chain value
# Returns: the next chain value (the proposal if accepted, otherwise `current`).
one_iter <- function(a, b, current) {
  # Independent proposal from Beta(a, b)
  candidate <- rbeta(1, a, b)

  # Ratio of unnormalized posteriors (prior x likelihood), candidate vs current
  post_ratio <- dbeta(candidate, 2, 3) * dbinom(1, 2, candidate) / dbeta(current, 2, 3) / dbinom(1, 2, current)
  # Hastings correction for the asymmetric (independence) proposal
  proposal_ratio <- dbeta(current, a, b) / dbeta(candidate, a, b)
  accept_prob <- min(1, post_ratio * proposal_ratio)

  ifelse(runif(1) < accept_prob, candidate, current)
}

# Function for the full Beta-Binomial tour using vectorization
# Simulate an N-step Beta-Binomial Metropolis-Hastings chain using one_iter()
# with an independent Beta(a, b) proposal.
#   N:       chain length including the starting value (N >= 1)
#   a, b:    proposal shape parameters
#   initial: starting value of the chain (default 0.5)
# Returns: tibble with columns `iteration` (1..N) and `pi` (chain values).
beta_bin_tour <- function(N, a, b, initial = 0.5) {
  stopifnot(length(N) == 1, N >= 1)

  # Initialize chain with first location
  chain <- numeric(N)
  chain[1] <- initial

  # Each step depends on the previous one, so this loop is inherently sequential.
  # seq_len(N)[-1] is 2..N and is empty when N = 1 (2:N would count down to 1).
  for (i in seq_len(N)[-1]) {
    chain[i] <- one_iter(a = a, b = b, current = chain[i - 1])
  }

  # Return the chain locations as a tibble for easier manipulation
  tibble(iteration = seq_len(N), pi = chain)
}

# Example call: 5000 steps with a uniform Beta(1, 1) proposal
set.seed(84735)
beta_bin_sim <- beta_bin_tour(N = 5000, a = 1, b = 1)

# Plot the results using modern ggplot2 practices
# (inside aes(), `pi` refers to the chain column, not the constant pi)
ggplot(beta_bin_sim, aes(x = iteration, y = pi)) +
  geom_line() +
  labs(title = "Trace Plot of Beta-Binomial Chain", x = "Iteration", y = expression(pi)) +
  theme_minimal()

# Histogram of the chain with the exact Beta(3, 4) posterior overlaid
# (Beta(2, 3) prior updated by 1 success and 1 failure)
ggplot(beta_bin_sim, aes(x = pi)) +
  geom_histogram(aes(y = after_stat(density)), color = "white", bins = 30) +
  stat_function(fun = dbeta, args = list(3, 4), color = "blue") +
  labs(title = expression(paste("Posterior Distribution of ", pi)), x = expression(pi), y = "density") +
  theme_minimal()


########################################################
########################################################
# R Code for Chapter 6 Content
########################################################
########################################################

# Load packages
library(tidyverse)
library(janitor)
library(rstan)
#may need to install the below packages too, not necessary to call them with library command though
#install.packages("Rcpp")
library(bayesplot)

#################################################
### 6.2 Markov chains via rstan
#################################################
#### Beta-Binomial example
#################################################

# Step 1: DEFINE the model
# Stan program: Y | pi ~ Binomial(10, pi) with a Beta(2, 2) prior on pi.
# With Y = 9 (passed in Step 2) the exact posterior is Beta(11, 3),
# which is useful for checking the simulation.
bb_mod <- "
  data {
    int<lower = 0, upper = 10> Y;
  }
  parameters {
    real<lower = 0, upper = 1> pi;
  }
  model {
    Y ~ binomial(10, pi);
    pi ~ beta(2, 2);
  }
"

# Step 2: SIMULATE the posterior
# 4 chains of 10000 iterations each (rstan uses the first half of each chain as warmup by default)
bb_sim <- stan(model_code = bb_mod, data = list(Y = 9), 
               chains = 4, iter = 5000 * 2, seed = 84735)

# Diagnostics: Trace plot for 'pi'
mcmc_trace(bb_sim, pars = "pi", size = 0.1) +
  labs(title="Trace plots of the 4 Markov Chains",
       x="iteration",y=expression(pi))

# Histogram of the Markov chain values for 'pi'
mcmc_hist(bb_sim, pars = "pi") + 
  yaxis_text(TRUE) + 
  labs(title="Histogram of the Sampled Values",
       x=expression(pi), y="count")

# Density plot of the Markov chain values for 'pi' (the y axis is a density, not a count)
mcmc_dens(bb_sim, pars = "pi") + 
  yaxis_text(TRUE)  + 
  labs(title="Density of the Sampled Values",
       x=expression(pi), y="density")

#################################################
### Gamma-Poisson Example
#################################################

# Step 1: DEFINE the model
# Stan program: two counts Y[1], Y[2] ~ Poisson(lambda) with a Gamma(3, 1) prior.
# With Y = (2, 8) the exact posterior is Gamma(3 + 10, 1 + 2) = Gamma(13, 3).
# NOTE(review): `int<lower = 0> Y[2];` is the old array syntax; newer Stan
# releases require `array[2] int<lower = 0> Y;` — confirm against the installed
# rstan/Stan version before changing it.
gam_pois_mod <- "
  data {
    int<lower = 0> Y[2];
  }
  parameters {
    real<lower = 0> lambda;
  }
  model {
    Y ~ poisson(lambda);
    lambda ~ gamma(3, 1);
  }
"

# Step 2: SIMULATE the posterior
# A single chain of 10000 iterations (rstan uses the first half as warmup by default)
gam_pois_sim <- stan(model_code = gam_pois_mod, data = list(Y = c(2, 8)), 
                     chains = 1, iter = 5000 * 2, seed = 84735)

# Diagnostics: Trace plot for 'lambda'
mcmc_trace(gam_pois_sim, pars = "lambda", size = 0.1) +
  labs(title="Trace plot of the Markov Chain",
       x="iteration",y=expression(lambda))

# Histogram of the Markov chain values for 'lambda'
mcmc_hist(gam_pois_sim, pars = "lambda") + 
  yaxis_text(TRUE) + 
  labs(title="Histogram of the sampled values",
       x=expression(lambda), y="count")

# Density plot of the Markov chain values for 'lambda' (the y axis is a density, not a count)
mcmc_dens(gam_pois_sim, pars = "lambda") + 
  yaxis_text(TRUE)  + 
  labs(title="Density of the sampled values",
       x=expression(lambda), y="density")

# Diagnostics: Effective sample size, Autocorrelation, R-hat

# Effective sample size ratio for lambda (effective draws relative to total draws)
neff_ratio(gam_pois_sim, pars = "lambda")

# Autocorrelation plot for lambda
mcmc_acf(gam_pois_sim, pars = "lambda")  + 
  labs(y="acf values", x= "lag")

# R-hat statistic for lambda
# NOTE(review): only one chain was run (chains = 1), so R-hat has little to
# compare; consider chains = 4 as in the Beta-Binomial example.
rhat(gam_pois_sim, pars = "lambda")


