library(magrittr)
library(gptools2)

## DGP and Data ----------------------------------------------------------------
# Simulated data: n independent Bernoulli(0.75) observations (0/1 values).
set.seed(123)  # fixed seed so that the whole analysis is reproducible
n <- 20
y <- rbinom(n, size = 1, prob = 0.75)


## Posterior distribution ------------------------------------------------------
# Analytical posterior mean of the Bernoulli success probability under a
# Beta(alpha, beta) prior.
#
# The posterior is Beta(alpha + successes, beta + failures), whose mean is
# shape_1 / (shape_1 + shape_2). The arithmetic is vectorised over `alpha`
# and `beta`.
#
# @param alpha, beta Prior shape parameters (scalars or equal-length vectors).
# @param y           Observed 0/1 data.
# @return The posterior mean(s).
posterior_mean <- function(alpha, beta, y) {
    successes <- sum(y)
    failures  <- length(y) - successes
    shape_1   <- alpha + successes
    shape_2   <- beta + failures
    shape_1 / (shape_1 + shape_2)
}

# MCMC estimate of the posterior mean for the Bernoulli model with a beta
# prior, along the line alpha = beta (relabelled as phi).
#
# @param phi    Common value of the prior shapes alpha = beta = phi.
# @param y      Observed 0/1 data.
# @param seed   Seed forwarded to the sampler.
# @param method "gibbs" (default) or "metropolis"; anything else is an error
#               (previously any other string silently ran Metropolis).
# @param n_iter Number of sampler iterations (default 10000, the former
#               hard-coded value).
# @return The mean of the sub-chain means from `mcmc_slicing()`, with
#         attribute "sigma" = standard deviation of those sub-chain means.
#
# NOTE(review): "sigma" is the spread of individual sub-chain means, each
# built from roughly 1/50 of the draws. It is therefore larger than the
# standard error of the returned average itself. Confirm this conservative
# scale is what the GP noise model intends.
posterior_mean_via_mcmc <- function(phi, y, seed = 1234, method = "gibbs",
                                    n_iter = 10000) {
    method <- match.arg(method, c("gibbs", "metropolis"))
    sampler <- switch(method,
                      gibbs = mcmc_gibbs,
                      metropolis = mcmc_metropolis)
    result <- sampler(phi, phi, y, n_iter, seed = seed)
    sliced_means <- mcmc_slicing(result$draws, mean)

    p_mean <- mean(sliced_means)
    attr(p_mean, "sigma") <- sd(sliced_means)
    p_mean
}

# "Gibbs sampler" for the Bernoulli success probability under a Beta(a0, b0)
# prior. With a single parameter, Gibbs sampling reduces to independent draws
# from the exact conjugate posterior Beta(a0 + successes, b0 + failures).
#
# @param a0, b0 Beta prior shape parameters.
# @param y      Observed 0/1 data.
# @param n_iter Number of draws.
# @param seed   RNG seed set before sampling (default 1234). Pass NULL to
#               leave the RNG state untouched.
# @return A list with the inputs and `draws`, a numeric vector of length
#         `n_iter`.
mcmc_gibbs <- function(a0, b0, y, n_iter, seed = 1234) {
    # `missing(seed)` is TRUE whenever the default is used, so the former
    # `if (!missing(seed))` guard silently ignored the default seed.
    if (!is.null(seed)) set.seed(seed)
    n_success <- sum(y)
    draws <- rbeta(n_iter, a0 + n_success, b0 + length(y) - n_success)
    list(data = y, alpha = a0, beta = b0, n_iter = n_iter, draws = draws)
}

# Independence Metropolis-Hastings sampler for the Bernoulli success
# probability under a Beta(a0, b0) prior.
#
# Proposals are drawn from Uniform(0, 1) regardless of the current state.
# Because that proposal density is flat, the acceptance ratio is the ratio of
# (likelihood x prior) at the proposed and current values.
#
# @param a0, b0 Beta prior shape parameters.
# @param y      Observed 0/1 data.
# @param n_iter Number of Metropolis iterations.
# @param seed   RNG seed set before sampling (default 1234). Pass NULL to
#               leave the RNG state untouched.
# @return A list with the inputs and `draws`, a numeric vector of length
#         `n_iter + 1` (the uniform starting value followed by one state per
#         iteration).
mcmc_metropolis <- function(a0, b0, y, n_iter, seed = 1234) {
    # `missing(seed)` is TRUE whenever the default is used, so the former
    # `if (!missing(seed))` guard silently ignored the default seed.
    if (!is.null(seed)) set.seed(seed)

    n_success <- sum(y)
    n_failure <- length(y) - n_success

    logprior <- function(phi) {
        # lbeta() rather than log(beta()): beta() underflows to 0 for large
        # shapes (e.g. a0 = b0 = 600), which turned every log-ratio into NaN.
        (a0 - 1) * log(phi) + (b0 - 1) * log(1 - phi) - lbeta(a0, b0)
    }

    loglikelihood <- function(phi) {
        n_success * log(phi) + n_failure * log(1 - phi)
    }

    log_likelihood_x_prior <- function(phi) {
        loglikelihood(phi) + logprior(phi)
    }

    # Preallocate instead of growing the vector with c() on every iteration.
    draws <- numeric(n_iter + 1)
    draws[1] <- runif(1)
    current_lp <- log_likelihood_x_prior(draws[1])
    # seq_len() so that n_iter = 0 runs no iterations (1:0 would run two).
    for (i in seq_len(n_iter)) {
        proposal <- runif(1)
        proposal_lp <- log_likelihood_x_prior(proposal)
        # Accept with probability min(1, exp(proposal_lp - current_lp)).
        if (log(runif(1)) < proposal_lp - current_lp) {
            draws[i + 1] <- proposal
            current_lp <- proposal_lp
        } else {
            draws[i + 1] <- draws[i]
        }
    }

    list(data = y, alpha = a0, beta = b0, n_iter = n_iter, draws = draws)
}

# Apply a summary statistic to the MCMC draws after discarding the burn-in.
#
# @param draws    Numeric vector of MCMC draws.
# @param stat_fun Summary function applied to the retained draws (e.g. `mean`).
# @param burn_ins Number of leading draws to discard. Defaults to 10% of the
#                 chain, rounded.
# @return `stat_fun()` of the retained draws.
mcmc_posterior <- function(draws, stat_fun, burn_ins) {
    if (missing(burn_ins)) burn_ins <- round(length(draws) / 10)
    # `tail(draws, -0)` returns an EMPTY vector, so a zero burn-in (including
    # the default for chains shorter than 5) would discard everything.
    kept <- if (burn_ins > 0) draws[-seq_len(burn_ins)] else draws
    stat_fun(kept)
}

# Split the post-burn-in chain into interleaved sub-chains and apply
# `stat_fun` to each one. Sub-chain k holds draws k, k + block_size,
# k + 2 * block_size, ... . The spread of the results is used elsewhere in
# this script as a measure of Monte Carlo error of the statistic.
#
# @param draws      Numeric vector of MCMC draws.
# @param stat_fun   Function applied to each sub-chain (e.g. `mean`).
# @param burn_ins   Number of leading draws to discard. Defaults to 10% of
#                   the chain, rounded.
# @param block_size Number of interleaved sub-chains, i.e. the thinning
#                   stride (despite its name, not the length of a block).
# @return `sapply()` over the sub-chains: one value per sub-chain for a scalar
#         `stat_fun`. At most `length(draws)` sub-chains are formed after
#         burn-in.
mcmc_slicing <- function(draws, stat_fun, burn_ins, block_size = 50) {
    if (missing(burn_ins)) {
        burn_ins <- round(length(draws) / 10)
    }
    # `tail(draws, -0)` returns an EMPTY vector, so drop the burn-in explicitly.
    if (burn_ins > 0) draws <- draws[-seq_len(burn_ins)]

    # A start index beyond the chain length makes seq() fail with
    # "wrong sign in 'by' argument", so cap the number of sub-chains.
    n_slices <- min(block_size, length(draws))
    sapply(seq_len(n_slices), function(start_index) {
        s <- seq(start_index, length(draws), by = block_size)
        stat_fun(draws[s])
    })
}

# Derivative of the analytical posterior mean along the line alpha = beta = phi,
# computed by automatic differentiation with ADtools.
#
# @param phi Value of alpha (= beta) at which to differentiate.
# @param y   Observed 0/1 data (held fixed).
# @return The `dx` slot of the `ADtools::auto_diff()` result.
d_posterior_mean_d_phi_AD <- function(phi, y) {
    # `y` is listed in `at` below. If ADtools forwards every entry of `at` to
    # the function, a one-argument closure would fail with "unused argument".
    # So accept `y`, defaulting to the captured data in case it is not
    # forwarded.
    y_fixed <- y
    alpha_eq_beta <- function(phi, y = y_fixed) posterior_mean(phi, phi, y)
    ADtools::auto_diff(alpha_eq_beta,
                       wrt = "phi",
                       at = list(phi = phi, y = y))@dx
}

# Derivative of `posterior_mean_via_mcmc()` (Gibbs sampler) with respect to
# phi, computed by finite differencing with ADtools.
#
# The seed is fixed at 1234, and the Gibbs sampler resets the RNG from it on
# every evaluation. The perturbed evaluations therefore share the same random
# numbers (common random numbers), which keeps the difference quotient stable.
#
# @param phi Value of alpha (= beta) at which to differentiate.
# @param y   Observed 0/1 data.
# @return A length-one list wrapping the `ADtools::finite_diff()` result.
#   NOTE(review): callers rbind() these lists, which yields a list-mode matrix
#   rather than a numeric one; confirm this is what the GP code expects.
d_posterior_mean_d_phi_FD <- function(phi, y) {
    fd_args <- list(phi = phi, y = y, seed = 1234, method = "gibbs")
    fd_result <- ADtools::finite_diff(posterior_mean_via_mcmc,
                                      wrt = "phi",
                                      at = fd_args)
    list(fd_result)
}


# Basic check
separator <- function() {
    # A row of "=" characters as wide as the console; message() pastes them
    # together into one horizontal rule.
    width <- getOption("width")
    rep("=", times = width)
}

local({
    # Compare the analytical posterior mean with Metropolis estimates for
    # alpha = 2, beta = 3 on the simulated data `y`.
    message(separator())
    message("Basic check: computing the posterior mean for alpha = 2, beta = 3 and the given data y")
    message("Analytical formula: ", round(posterior_mean(2, 3, y), 4))

    # MCMC estimation. message() pastes its arguments with no separator, so
    # each label ends in ": " to keep it from running into the value.
    result <- mcmc_metropolis(2, 3, y, 1000)
    message("MCMC estimate: ", round(mcmc_posterior(result$draws, mean), 4))

    # Mean and spread of the interleaved sub-chain means.
    sliced_means <- mcmc_slicing(result$draws, mean)
    message("MCMC sliced estimate (with sd): ", round(mean(sliced_means), 4),
            " (", round(sd(sliced_means), 4), ")")
    message(separator())
})


## Plotting functions -----------------------------------------------------------
#' Plot the analytical posterior mean along the line alpha = beta
#'
#' Evaluates `posterior_mean(x, x, y)` and draws it as a dashed line, either
#' on a new plot or on top of the current one.
#'
#' @param x Values of alpha (= beta) at which the curve is evaluated.
#' @param y Observed 0/1 data passed to `posterior_mean()`.
#' @param new If `TRUE`, start a new plot; if `FALSE`, add the curve to the
#'   current plot with `lines()`.
#' @param text_size Character expansion for axis labels, tick labels and
#'   titles (new plots only).
#' @param xlim,ylim Axis limits for a new plot. The defaults zoom in on the
#'   region of interest; previously they were hard-coded, and passing them via
#'   `...` raised an error.
#' @param ... Further graphical arguments forwarded to `plot()` or `lines()`.
#' @return `invisible(NULL)`.
plot_truth <- function(x = seq(0.1, 101, 1), y, new = TRUE,
                       text_size = 2, xlim = c(0, 100), ylim = c(0.49, 0.65),
                       ...) {
    pmean <- posterior_mean(x, x, y)
    if (!new) {
        # Leave par() alone here. Changing `mar` after a plot is drawn moves
        # its plotting region and misaligns anything added afterwards.
        lines(x, pmean, lty = 2, ...)
        return(invisible(NULL))
    }
    par(las = 1, family = "DejaVu Sans")
    par(mar = c(5.1, 7.1, 4.1, 2.1))  # wide left margin for the y label below
    plot(x, pmean, type = "l",
         xlab = "alpha (= beta)",
         ylab = "", lty = 2,
         xlim = xlim,
         ylim = ylim,
         cex.lab = text_size,
         cex.axis = text_size,
         cex.main = text_size,
         cex.sub = text_size,
         ...)
    # Drawn separately so the label can sit further from the axis (line = 5).
    title(ylab = "posterior mean", line = 5, cex.lab = text_size)
    invisible(NULL)
}

# Shade the GP predictive band, mean +/- n_sd predictive standard deviations,
# on the current plot.
#
# @param x           Points at which to predict (vector or one-column matrix).
# @param model       Fitted GP model, passed unchanged to `predict_fun`.
# @param predict_fun Function(model, X) returning a list with `mean` and
#                    `covariance` (defaults to `predict_gp`).
# @param ...         Further arguments passed to `polygon()` (e.g. `col`).
# @param n_sd        Half-width of the band in predictive standard deviations
#                    (default 2, the former hard-coded value). It must be
#                    named because it follows `...`.
# @return Invisibly, a list with the band's `x`, `lower` and `upper` bounds.
plot_confidence_band <- function(x, model, predict_fun = predict_gp, ...,
                                 n_sd = 2) {
    prediction <- predict_fun(model, as.matrix(x))
    p_mean <- prediction$mean
    # abs() keeps sqrt() defined if round-off yields small negative variances.
    p_sd <- sqrt(abs(diag(prediction$covariance)))
    lower <- p_mean - n_sd * p_sd
    upper <- p_mean + n_sd * p_sd
    polygon(c(x, rev(x)), c(lower, rev(upper)), border = NA, ...)
    invisible(list(x = as.numeric(x), lower = lower, upper = upper))
}

# Draw the GP predictive mean as a line on the current plot.
#
# @param x           Points at which to predict (vector or one-column matrix).
# @param model       Fitted GP model, passed unchanged to `predict_fun`.
# @param predict_fun Function(model, X) returning a list with a `mean` entry
#                    (defaults to `predict_gp`).
# @param ...         Further arguments passed to `lines()` (e.g. `col`, `lwd`).
plot_predictive_mean <- function(x, model, predict_fun = predict_gp, ...) {
    x_matrix <- as.matrix(x)
    lines(x, predict_fun(model, x_matrix)$mean, ...)
}

# Overlay the model's training inputs/outputs as filled points (pch = 19) on
# the current plot; `...` is forwarded to `points()`.
plot_training_points <- function(model, ...) {
    inputs <- model$train_X
    outputs <- model$train_y
    points(inputs, outputs, pch = 19, ...)
}


## GP ------------------------------------------------------------------------
# Setup
output_file <- "./output/binomial_beta_before_AL.png"
phi         <- c(0.1, 10, 40, 70, 100)
use_AL      <- TRUE

# Generate training data (Rerun from here after reaching the end).
# Run the sampler once per phi and read both the estimate and its "sigma"
# attribute from the same result, rather than sampling twice per phi.
mcmc_fits       <- lapply(phi, posterior_mean_via_mcmc, y = y)
train_theta     <- as.matrix(phi)
train_h_mcmc    <- as.matrix(vapply(mcmc_fits, as.numeric, numeric(1)))
train_h_mcmc_sd <- vapply(mcmc_fits, function(fit) attr(fit, "sigma"), numeric(1))
ground_truth    <- posterior_mean(phi, phi, y)

# Inference
model <- gp(X = train_theta,
            y = train_h_mcmc,
            kernel = squared_exponential(1, mean(phi)),
            sigma = train_h_mcmc_sd)
# Active-learning step: the next design point is the grid value with the
# largest GP predictive variance.
new_phi <- as.matrix(seq(0.1, 100, 1))
phi_max <- new_phi[which.max(diag(predict_gp(model, new_phi)$covariance)), ]


# Make plot (Active learning with Gaussian Process)
new_phi <- as.matrix(seq(0.1, 101, 1))

local({
    png(filename = output_file, width = 1000, height = 500)
    # Close the device even if a plotting call fails. Otherwise later output
    # would silently keep going into this file.
    on.exit(dev.off(), add = TRUE)

    plot_truth(y = y, text_size = 2,
               main = "(A) Active learning with Gaussian Process")

    plot_confidence_band(new_phi, model, col = "lightskyblue2")
    plot_predictive_mean(new_phi, model, col = "blue", lwd = 3)

    # Redraw the truth on top of the band, then mark the next design point.
    plot_truth(y = y, new = FALSE, lwd = 3)
    abline(v = phi_max, col = "red", lwd = 2)

    plot_training_points(model, cex = 1.5)
    legend("topright",
           legend = c("truth", "GP", "AL point"),
           col = c("black", "blue", "red"),
           lwd = c(3, 3, 2),
           lty = c(2, 1, 1),
           cex = 1.5)
})


# GP with derivative -----------------------------------------------------------
# Generate data: augment the design with the active-learning point
phi <- c(phi, phi_max)
output_file <- "./output/binomial_beta_after_AL.png"

# Source of the derivative observations. The derivative noise passed to gp_d()
# must match the method: 0 for automatic differentiation, 1e-8 for finite
# differencing. Before this change, FD was active but sigma_d was 0.
use_AD  <- FALSE
d_fun   <- if (use_AD) d_posterior_mean_d_phi_AD else d_posterior_mean_d_phi_FD
sigma_d <- if (use_AD) 0 else 1e-8

# Run the sampler once per phi and reuse the result for both the estimate and
# its "sigma" attribute.
mcmc_fits       <- lapply(phi, posterior_mean_via_mcmc, y = y)
train_theta     <- as.matrix(phi)
train_h_mcmc    <- as.matrix(vapply(mcmc_fits, as.numeric, numeric(1)))
train_h_mcmc_sd <- vapply(mcmc_fits, function(fit) attr(fit, "sigma"), numeric(1))
dy              <- do.call(rbind, lapply(phi, d_fun, y = y))

# Inference without derivative information
model <- gp(X = train_theta,
            y = train_h_mcmc,
            kernel = squared_exponential(1, mean(phi)),
            sigma = train_h_mcmc_sd)
# Inference with derivative information
model_2 <- gp_d(X = train_theta,
                y = train_h_mcmc,
                dy = dy,
                kernel = squared_exponential_d(1, mean(phi)),
                sigma = train_h_mcmc_sd,
                sigma_d = sigma_d)

# Make plot (Gaussian Process with derivative)
new_phi <- as.matrix(seq(0.1, 101, 1))

local({
    png(filename = output_file, width = 1000, height = 500)
    # Close the device even if a plotting call fails.
    on.exit(dev.off(), add = TRUE)

    plot_truth(y = y, text_size = 2,
               main = "(B) Precision gain from active learning and derivative information")

    plot_confidence_band(new_phi, model, col = "lightskyblue2")
    plot_confidence_band(new_phi, model_2, predict_gp_d, col = "palegreen1")

    plot_predictive_mean(new_phi, model, col = "blue", lwd = 3)
    plot_predictive_mean(new_phi, model_2, predict_gp_d, col = "darkgreen", lwd = 3)

    plot_truth(y = y, new = FALSE, lwd = 3)
    abline(v = phi_max, col = "red", lwd = 2)
    plot_training_points(model, cex = 1.5)
    legend("topright",
           legend = c("truth", "GP", "GP with derivative", "AL point"),
           col = c("black", "blue", "darkgreen", "red"),
           lwd = c(3, 3, 3, 2),
           lty = c(2, 1, 1, 1),
           cex = 1.5)
})
