## Data Generating Process -----------------------------------------------------
randn <- function(...) {
    # Standard-normal draws laid out with the given dimensions, returned as a
    # matrix (e.g. randn(n, p) is n x p; randn(n) is an n x 1 column).
    dims <- c(...)
    draws <- rnorm(prod(dims))
    dim(draws) <- dims
    as.matrix(draws)
}

sim_hetero_normal <- function(n, beta, a0, b0, rho) {
    # Simulate a normal linear regression with per-observation noise variance
    # sigma2_i / lambda_i, where
    #   lambda_i ~ Gamma(shape = rho, rate = rho)  (mean 1, variance 1 / rho)
    #   sigma2_i ~ inverse-gamma draws from MCMCpack::rinvgamma(n, a0, b0)
    # Smaller rho spreads lambda more, i.e. stronger heteroskedasticity.
    # The random draws below happen in a fixed order (lambda, sigma2, X, y) so
    # a given seed always reproduces the same data set.
    precision_scale <- rgamma(n, shape = rho, rate = rho)
    noise_var <- MCMCpack::rinvgamma(n, a0, b0)
    design <- randn(n, length(beta))
    mean_y <- design %*% beta
    response <- rnorm(n, mean = mean_y, sd = sqrt(noise_var / precision_scale))

    list(y = response, X = design, beta = beta, a0 = a0, b0 = b0, rho = rho)
}

sim_data <- function(rho, seed = 123) {
    # Reproducible data set: 300 observations, 5 coefficients drawn uniformly
    # on [-3, 3], inverse-gamma noise with a0 = 30, b0 = 3, and the given rho.
    set.seed(seed)
    true_beta <- runif(5, min = -3, max = 3)
    sim_hetero_normal(n = 300, beta = true_beta, a0 = 30, b0 = 3, rho = rho)
}


## Inference -------------------------------------------------------------------
# Map a two-element prior input to a posterior summary of beta[1].
#
# Args:
#   input:       input[[1]] is the prior mean of beta[1]; input[[2]] is the
#                log10 of its prior variance (the variance used is
#                10^input[[2]]).
#   data:        data list with `beta`, `y` and `X` (e.g. from sim_data()).
#   summary_fun: function applied to the Gibbs draws of beta[1], e.g.
#                mcmc_posterior for the posterior mean after burn-in.
# All other coefficients keep an independent N(0, 1) prior, and the sampler
# runs 1000 iterations with a0 = d0 = v = 3.
# Returns: summary_fun(<draws of beta[1]>).
prior_to_posterior <- function(input, data, summary_fun) {
    n_coef <- length(data$beta)

    prior_mean <- as.matrix(numeric(n_coef))
    prior_mean[1] <- input[[1]]

    prior_cov <- diag(n_coef)
    prior_cov[1, 1] <- 10^input[[2]]

    fit <- gibbs_gaussian(
        data = data, b0 = prior_mean, B0 = prior_cov,
        a0 = 3, d0 = 3, v = 3, n.iter = 1000
    )
    summary_fun(fit$beta[, 1])
}


# Gibbs sampler for Bayesian linear regression with per-observation precision
# weights (a scale mixture of normals).
#
# Model implied by the conditional updates below:
#   y_i | beta, sigma2, lambda_i ~ N(x_i' beta, sigma2 / lambda_i)
#   lambda_i ~ Gamma(shape = v, rate = v)
#   sigma2   ~ Inv-Gamma(shape = a0, scale = d0)
#   beta     ~ N(b0, B0)
#
# Args:
#   data:   list with `y` (response) and `X` (design matrix, n x p).
#   b0, B0: prior mean (length-p vector or p x 1 matrix) and covariance of beta.
#   a0, d0: prior shape and scale of sigma2.
#   v:      shape and rate of the Gamma prior on each lambda_i.
#   n.iter: number of Gibbs iterations; no burn-in is discarded here.
# Returns:
#   list(sigma2 = numeric vector of length n.iter,
#        beta   = n.iter x p matrix with one draw per row).
gibbs_gaussian <- function(data, b0, B0, a0, d0, v, n.iter = 10000) {
    y <- as.numeric(data$y)
    X <- as.matrix(data$X)
    tX <- t(X)                      # hoisted: reused every iteration
    n <- length(y)
    p <- ncol(X)

    beta <- b0
    sigma2 <- 1

    # Preallocated storage instead of growing lists and binding at the end.
    sigma2_draws <- numeric(n.iter)
    beta_draws <- matrix(NA_real_, nrow = n.iter, ncol = p)

    ap <- a0 + n / 2                # posterior shape of sigma2 (constant)
    inv_B0 <- solve(B0)
    inv_B0_times_b0 <- inv_B0 %*% b0

    # seq_len() so that n.iter = 0 yields empty draws instead of an error.
    for (iter in seq_len(n.iter)) {
        # Sample lambda | beta, sigma2
        res <- as.vector(y - X %*% beta)
        lambda <- rgamma(
            n, shape = v + 0.5,
            scale = 1 / (v + res^2 / (2 * sigma2))
        )

        # Sample sigma2 | beta, lambda
        bp <- d0 + 0.5 * sum(lambda * res^2)
        sigma2 <- 1 / rgamma(1, shape = ap, scale = 1 / bp)

        # Sample beta | sigma2, lambda
        # (lambda * X scales row i of X by lambda_i.)
        precision <- tX %*% (lambda * X) / sigma2 + inv_B0
        # chol2inv(chol()) returns an exactly symmetric covariance, which
        # mvtnorm::rmvnorm checks for; solve() can leave rounding asymmetry.
        Bp <- chol2inv(chol(precision))
        bp2 <- Bp %*% (tX %*% (lambda * y) / sigma2 + inv_B0_times_b0)
        beta <- as.vector(mvtnorm::rmvnorm(1, mean = as.vector(bp2), sigma = Bp))

        # Keep track
        sigma2_draws[iter] <- sigma2
        beta_draws[iter, ] <- beta
    }

    list(sigma2 = sigma2_draws, beta = beta_draws)
}


# Compute a posterior summary statistic from MCMC draws after burn-in.
#
# Args:
#   draws:    MCMC draws in sampling order.
#   stat_fun: summary applied to the retained draws (default `mean`).
#   burn_ins: number of leading draws to discard; defaults to 10% of the
#             chain length, rounded.
# Returns: stat_fun(<draws after burn-in>).
mcmc_posterior <- function(draws, stat_fun = mean, burn_ins) {
    if (missing(burn_ins)) burn_ins <- round(length(draws) / 10)
    # Only drop when burn_ins > 0: tail(draws, -0) is tail(draws, 0), which
    # returns NO draws and would make the summary NaN for an explicit
    # burn_ins = 0 or for short chains (< 5 draws) under the default.
    if (burn_ins > 0) draws <- tail(draws, -burn_ins)
    stat_fun(draws)
}

# Split post-burn-in draws into `block_size` interleaved slices and apply
# `stat_fun` to each. Slice k holds draws k, k + block_size, k + 2*block_size,
# ... The spread of the per-slice statistics is what
# mcmc_sliced_mean_with_sd() uses as a Monte Carlo error estimate.
#
# Args:
#   draws:      vector of MCMC draws in sampling order.
#   stat_fun:   statistic computed on each slice.
#   burn_ins:   leading draws to discard; defaults to 10% of the chain, rounded.
#   block_size: number of slices (also the thinning step within a slice).
# Returns: the per-slice statistics, simplified by sapply().
mcmc_slicing <- function(draws, stat_fun, burn_ins, block_size = 50) {
    if (missing(burn_ins)) {
        burn_ins <- round(length(draws) / 10)
    }
    # tail(draws, -0) would return no draws at all, so only drop when > 0.
    if (burn_ins > 0) {
        draws <- tail(draws, -burn_ins)
    }

    n_kept <- length(draws)
    if (n_kept < block_size) {
        stop("mcmc_slicing: need at least block_size (", block_size,
             ") draws after burn-in, got ", n_kept, call. = FALSE)
    }

    sapply(seq_len(block_size), function(start_index) {
        s <- seq(start_index, n_kept, by = block_size)
        stat_fun(draws[s])
    })
}

mcmc_sliced_mean_with_sd <- function(draws, ...) {
    # Mean of the per-slice means from mcmc_slicing(), with the standard
    # deviation across slices attached as attribute "sigma". Extra arguments
    # (burn_ins, block_size) are forwarded to mcmc_slicing().
    slice_means <- mcmc_slicing(draws, stat_fun = mean, ...)
    structure(mean(slice_means), sigma = sd(slice_means))
}


## Utility ----------------------------------------------------------------------
set_default <- function(f, ...) {
    # Partial application: capture `...` now as default arguments. The
    # returned function calls `f` with those defaults first, followed by
    # whatever arguments it is given.
    defaults <- list(...)
    function(...) do.call(f, c(defaults, list(...)))
}

grid <- function(...) {
    # Cartesian product of the supplied vectors (first argument varies
    # fastest), rebuilt as a plain data.frame.
    expand.grid(...) |>
        data.frame()
}

# Keep only the rows of `x` that lie on the boundary of its bounding box, i.e.
# rows where at least one column is at that column's minimum or maximum.
# Rows are returned in discovery order: column 1's extremes first, then any
# new rows found via column 2, and so on.
#
# Args:
#   x: matrix or data frame (coerced with as.matrix()).
# Returns: a matrix of the boundary rows. It stays a matrix even when a single
#   row or column remains.
exterior <- function(x) {
    x <- as.matrix(x)
    edge_rows <- lapply(seq_len(ncol(x)), function(j) {
        which(x[, j] %in% range(x[, j]))
    })
    # drop = FALSE: without it a one-row or one-column result collapses to a
    # plain vector.
    x[unique(unlist(edge_rows)), , drop = FALSE]
}


## Plot functions --------------------------------------------------------------
library(plotly)

plot_manifold <- function(prior, model, ...) {
    # Plot the GP-predicted posterior surface over `prior` and overlay the
    # points the model actually evaluated (model$X, model$ys).
    # Column 2 of the inputs is the log10 prior variance; it is mapped back
    # with 10^ because plot_surface() draws that axis on a log scale.
    as_xyz <- function(inputs, outputs) {
        out <- setNames(data.frame(inputs, outputs), c("x", "y", "z"))
        out[, 2] <- 10^out[, 2]
        out
    }

    surface_df <- as_xyz(prior, predict_gp(model$model, prior)$mean)
    observed_df <- as_xyz(model$X, model$ys)

    plot_surface(surface_df, FALSE, ...) |>
        wrap_add_markers(observed_df)
}

plot_surface <- function(x, add_markers = TRUE, ...) {
    # 3D surface of posterior mean vs. prior mean / prior variance of beta[1].
    # Optionally marks every grid point; the prior-variance axis is log-scaled
    # and the camera is fixed. `...` is forwarded to surface3d().
    fig <- surface3d(
        x,
        xlab = "Prior mean of beta[1]",
        ylab = "Prior variance of beta[1]",
        zlab = "Posterior mean of beta[1]",
        ...
    )
    if (add_markers) fig <- wrap_add_markers(fig, x)

    scene_opts <- list(
        yaxis = list(type = "log"),
        camera = list(eye = list(x = -1.5, y = 1, z = 0.5))
    )
    layout(fig, scene = scene_opts)
}

# A function for 3D scatterplot.
# `x`, `y`, `z` name the columns of `data` to plot (defaulting to its first
# three columns); `...` is forwarded to plotly::plot_ly().
# # Basic examples
# scatter3d(data = iris)
# scatter3d(data = mtcars, ylab = "some y label")
# scatter3d(data = iris, x = "Sepal.Length", y = "Sepal.Width", z = "Petal.Width")
#
# # Overlay
# scatter3d(data = mtcars) %>%
#   wrap_add_markers(data = iris)
scatter3d <- function(data, x = NULL, y = NULL, z = NULL,
                      xlab = NULL, ylab = NULL, zlab = NULL,
                      titlefont = 12, tickfont = 12, ...) {
    # Every axis shares the same title/tick font sizes.
    axis_spec <- function(label) {
        list(title = label,
             titlefont = list(size = titlefont),
             tickfont = list(size = tickfont))
    }

    fig <- wrap_add_markers(plot_ly(...), data, x, y, z)
    layout(fig, scene = list(
        xaxis = axis_spec(xlab),
        yaxis = axis_spec(ylab),
        zaxis = axis_spec(zlab)
    ))
}

wrap_add_markers <- function(p, data, x = NULL, y = NULL, z = NULL) {
    # Add a 3D marker layer from `data`, choosing columns by name. Any axis
    # left NULL falls back to the corresponding one of the first three
    # columns of `data`.
    col_names <- names(data)
    x_col <- if (is.null(x)) col_names[1] else x
    y_col <- if (is.null(y)) col_names[2] else y
    z_col <- if (is.null(z)) col_names[3] else z

    # plotly takes one-sided formulas (~col) that it looks up in `data`.
    x_formula <- as.formula(paste("~", x_col))
    y_formula <- as.formula(paste("~", y_col))
    z_formula <- as.formula(paste("~", z_col))

    add_markers(p, x = x_formula, y = y_formula, z = z_formula, data = data)
}

# A function for 3D surface
# **NOTE**: `wrap_add_surface` uses the unique values of cols 1 and 2 in
# `data` as the x and y grid, and reshapes col 3 into the z matrix (one row
# per y value). It is assumed `data` is obtained from some sort of
# `expand.grid(x, y)` (and in that order) with z being the corresponding
# value.
# # Basic examples
# g <- expand.grid(x = 1:10, y = 5:20)
# g$z <- g$x^2 + g$y^2
# g
# surface3d(g)
# surface3d(g, zlab = "some z label")
# # overlay
# surface3d(g) %>%
#   wrap_add_surface(2 * g)
surface3d <- function(data, xlab = NULL, ylab = NULL, zlab = NULL,
                      titlefont = 24, tickfont = 14, ...) {
    # Shared title/tick font sizes for each axis of the 3D scene.
    styled_axis <- function(label) {
        list(title = label,
             titlefont = list(size = titlefont),
             tickfont = list(size = tickfont))
    }

    base_plot <- plot_ly(...)
    with_surface <- wrap_add_surface(base_plot, data)
    layout(with_surface, scene = list(
        xaxis = styled_axis(xlab),
        yaxis = styled_axis(ylab),
        zaxis = styled_axis(zlab)
    ))
}

wrap_add_surface <- function(p, data) {
    # Column 1 varies fastest in an expand.grid layout, so filling column 3
    # row-wise gives one matrix row per unique y value, as plotly expects.
    x_vals <- unique(data[[1]])
    y_vals <- unique(data[[2]])
    z_grid <- matrix(data[[3]], nrow = length(y_vals), byrow = TRUE)
    add_surface(p, x = x_vals, y = y_vals, z = z_grid)
}


#===============================================================================
# Bayesian linear regression with heteroskedasticity
#
# Generate posterior surfaces for different levels of heteroskedasticity
# This part shows that variation in heteroskedasticity can lead to different
# geometry in the posterior manifold.
library(gptools2)

# Hyperparameters setup: rho = 2^-3, 2^-2.5, ..., 2^0 (7 values, 0.125 to 1).
# lambda ~ Gamma(rho, rho) has variance 1 / rho, so smaller rho means stronger
# heteroskedasticity in the simulated data.
rho_seq <- 2^seq(from = -3, to = 0, by = 0.5)

# Compare grid approach and active learning approach
# 1. Grid evaluation -----------------------------------------------------------
grid_approach <- function(rho, prior_grid) {
    # Ground truth by brute force: simulate data for this rho, then run the
    # Gibbs sampler at every row of `prior_grid` and record the posterior mean
    # of beta[1] (via mcmc_posterior).
    # Returns list(data = simulated data,
    #              posterior = data.frame of prior inputs and outputs).
    sim <- sim_data(rho)
    evaluate_prior <- set_default(prior_to_posterior,
                                  data = sim,
                                  summary_fun = mcmc_posterior)
    list(
        data = sim,
        posterior = data.frame(prior = prior_grid,
                               posterior = map_row(prior_grid, evaluate_prior))
    )
}

# For each rho, simulate data and generate the ground truth on a regular grid:
# 21 prior means (b0) x 9 log10 prior variances (B0) = 189 Gibbs runs per rho.
prior_grid <- grid(b0 = seq(-5, 5, 0.5),
                   B0 = seq(-2, 2, 0.5))
tictoc::tic()
experiments <- lapply(rho_seq, grid_approach, prior_grid = prior_grid)
tictoc::toc()  # 428.057 sec elapsed


# Plots
# Column 2 of each posterior table holds log10 prior variances; convert it
# back to the natural scale because plot_surface() draws that axis on a log
# scale.
posterior_surfaces <- experiments |>
    lapply(function(x) {
        x$posterior[, 2] <- 10^x$posterior[, 2]
        plot_surface(x$posterior, add_markers = FALSE,
                     titlefont = 32, tickfont = 16)
    })
print(rho_seq)
# Surfaces for the three smallest rho values (0.125, ~0.177, 0.25).
print(posterior_surfaces[[1]])
print(posterior_surfaces[[2]])
print(posterior_surfaces[[3]])


# Time for single run: one rho value (0.25) over the full 189-point prior grid.
# This is the baseline cost that the active-learning approach below is
# compared against.
tictoc::tic()
lapply(0.25, grid_approach, prior_grid = prior_grid)
tictoc::toc()  # 61.513 sec elapsed


# 2. GP with active learning ---------------------------------------------------
# Same data set as the rho = 0.25 grid run above (sim_data uses a fixed seed).
data0 <- sim_data(0.25)

# # One could start with a small initial grid
# init_X <- as.matrix(grid(b0 = seq(-5, 5, length.out = 5),
#                          B0 = seq(-2, 2, length.out = 5)))  # 25 points
#
# But to mitigate "boundary effect", it is recommended to spend the initial grid
# evaluation budget on the boundary.
init_X <- exterior(grid(b0 = seq(-5, 5, length.out = 7),
                        B0 = seq(-2, 2, length.out = 7)))  # 24 points

# Objective for active learning: maps a (b0, log10 B0) prior input to the
# sliced posterior mean of beta[1], with the standard deviation across slices
# attached as attribute "sigma" (see mcmc_sliced_mean_with_sd).
# NOTE(review): presumably active_learning() reads that attribute as the
# evaluation noise level; confirm in gptools2.
f <- function(x) {
    prior_to_posterior(x, data0, mcmc_sliced_mean_with_sd)
}

# Initial run to compute the convergence threshold if one does not know what
# tolerance to set (max_iter = 0 presumably evaluates only init_X; verify
# against gptools2).
set.seed(123)
tictoc::tic()
model <- active_learning(init_X, f, max_iter = 0, sample_n = 100)

# Continue evaluation further, restarting from the initial model.
# NOTE(review): 0.09 looks like a noise sd read off the initial run; confirm.
model <- active_learning(init_X, f, restart = model,
                         max_iter = 100, tol = c(1, 2 * 0.09),  # 2 standard deviations
                         sample_n = 100, persistence = 10, consecutive = 3)
tictoc::toc()  # 22.322 sec elapsed

# Plots: GP surface on a finer grid (0.1 steps in log10 prior variance)
prior_grid_HD <- as.matrix(grid(b0 = seq(-5, 5, 0.5),
                                B0 = seq(-2, 2, 0.1)))
plot_manifold(prior_grid_HD, model)


# Extra: Compare using different kernel functions ==============================
# Each run uses the same seed and stopping settings (max_iter, tol, sample_n,
# persistence, consecutive) as the default-kernel run above, but starts fresh
# (no `restart`); only `kernel` differs.
# - Rational quadratic kernel
set.seed(123)
tictoc::tic()
model_rq <- active_learning(init_X, f,
                            kernel = rational_quadratic_kernel(),
                            max_iter = 100, tol = c(1, 2 * 0.09),  # 2 standard deviations
                            sample_n = 100, persistence = 10, consecutive = 3)
tictoc::toc()  # 30.024 sec elapsed
plot_manifold(prior_grid_HD, model_rq)


# - Matern kernel
set.seed(123)
tictoc::tic()
model_mt <- active_learning(init_X, f,
                            kernel = matern_kernel(),
                            max_iter = 100, tol = c(1, 2 * 0.09),  # 2 standard deviations
                            sample_n = 100, persistence = 10, consecutive = 3)
tictoc::toc()  # 34.865 sec elapsed
plot_manifold(prior_grid_HD, model_mt)


# Evaluation points comparison and uncertainty reduction plot ==================
# Grid plot: the rho = 0.25 ground-truth surface, with every grid point shown
# as a marker.
experiments[3] |>  # rho = 0.25
    lapply(function(x) {
        x$posterior[, 2] <- 10^x$posterior[, 2]
        plot_surface(x$posterior, add_markers = TRUE,
                     titlefont = 24, tickfont = 16)
    })

# Active Learning plot: GP posterior-mean surface, with the points the active
# learner evaluated overlaid as markers.
plot_manifold(prior_grid_HD, model,
              titlefont = 24, tickfont = 16)

# Uncertainty reduction plot
# Build a callback that accumulates `model$measure` across calls.
# - tracker(model) appends model$measure as a new row and returns all rows.
# - tracker() returns the rows collected so far (NULL before the first call).
# State is stored in `env`, a fresh environment unless one is supplied.
new_uncertainty_tracker <- function(env = new.env()) {
    env$uncertainty <- NULL
    function(model) {
        if (missing(model)) {
            return(env$uncertainty)
        }
        env$uncertainty <- rbind(env$uncertainty, model$measure)
        env$uncertainty
    }
}
# Record model$measure through the active_learning() callback on each call,
# then plot column 2 (labelled below as the maximum uncertainty over the
# surface) against the number of evaluated points.
uncertainty_tracker <- new_uncertainty_tracker()
# Seed for reproducibility, matching the other active-learning runs above.
set.seed(123)
invisible(active_learning(init_X, f,
                          max_iter = 100, tol = c(1, 2 * 0.09),  # 2 standard deviations
                          sample_n = 100, persistence = 10, consecutive = 3,
                          callback = uncertainty_tracker))
uncertainty_tracker()

# png() errors if the target directory does not exist, so create it first.
dir.create("output", showWarnings = FALSE, recursive = TRUE)
png("output/GP_uncertainty.png", width = 650, height = 650)
par(las = 1, family = "DejaVu Sans")
par(mar = c(5.1, 6.1, 4.1, 2.1))   # wider left margin for the large y label
font_size <- 2.5
uc <- uncertainty_tracker()
plot(x = seq_len(nrow(uc)), y = uc[, 2], pch = 19,
     xlab = "Number of evaluated AL points", ylab = "",
     cex.lab = font_size, cex.axis = font_size - 0.5,
     cex = 2)
title(ylab = "Maximum uncertainty over the surface",
      line = 4, cex.lab = font_size)
dev.off()
