# Helper functions
# Simulate data from a Bayesian linear regression with a Normal-Inverse-Gamma
# prior:
#   sigma2_i          ~ IG(a0 / 2, d0 / 2)          (shape, scale)
#   beta_i | sigma2_i ~ N(b0, sigma2_i * diag(B0))
#   X_ij              ~ N(0, 1)
#   y_i               ~ N(X_i' beta_i, sigma2_i)
# A separate (beta, sigma2) pair is drawn for every observation.
# NOTE(review): the Gibbs sampler in this file assumes a single shared
# (beta, sigma2); confirm the per-observation draws are intended.
#
# n, p   : number of observations and covariates.
# b0, B0 : prior mean (length p or scalar) and prior covariance scale; only
#          diag(B0) is used, so B0 is effectively treated as diagonal.
# a0, d0 : hyperparameters of the sigma2 prior.
# seed   : optional seed; when non-NULL it is passed to set.seed(), which
#          changes the global RNG state.
# Returns a list with X (n x p), y (length n), beta (n x p), sigma2 (length n)
# and the inputs.
sim_bayesian_linear_regression <- function(n, p, b0, B0, a0, d0, seed = NULL) {
    if (!is.null(seed)) set.seed(seed)

    # Inverse-gamma draws as reciprocals of gamma draws with rate d0 / 2. This
    # is how MCMCpack::rinvgamma(n, a0 / 2, d0 / 2) samples, so the random
    # stream is unchanged while avoiding the package dependency.
    sigma2 <- 1 / rgamma(n, shape = a0 / 2, rate = d0 / 2)
    stopifnot(length(sigma2) == n)

    # rnorm() takes a standard deviation, so take the square root of the
    # prior variance sigma2 * diag(B0).
    beta <- do.call(rbind, lapply(sigma2, \(s2) rnorm(p, b0, sqrt(s2 * diag(B0)))))
    stopifnot(nrow(beta) == n && ncol(beta) == p)

    X <- matrix(rnorm(n * p), n, p)
    # Likewise, the noise standard deviation is sqrt(sigma2), not sigma2.
    y <- rnorm(n, rowSums(X * beta), sqrt(sigma2))
    list(X = X, y = y, beta = beta, sigma2 = sigma2,
         n = n, p = p, b0 = b0, B0 = B0, a0 = a0, d0 = d0, seed = seed)
}


# Gibbs sampler for the conjugate Bayesian linear regression
#   y ~ N(X beta, sigma2 * I_n),  beta | sigma2 ~ N(b0, sigma2 * B0),
#   sigma2 ~ IG(a0 / 2, d0 / 2).
# Each iteration draws beta ~ N(bn, sigma2 * Bn) using the current sigma2,
# then sigma2 ~ IG(an / 2, dn / 2). dn does not involve the current beta.
#
# b0, B0, a0, d0 : prior hyperparameters (B0 a p x p matrix).
# X, y           : design matrix (n x p) and response (length n).
# m              : number of iterations; no burn-in is discarded here.
# Returns list(data = list(X, y), beta = m x p matrix of draws,
#              sigma2 = numeric vector of m draws).
gibbs_sampler <- function(b0, B0, a0, d0, X, y, m = 100) {
    n <- length(y)
    XTX <- crossprod(X)
    XTy <- crossprod(X, y)
    B0_inv <- solve(B0)
    B0_inv_b0 <- solve(B0, b0)

    # The posterior hyperparameters do not depend on the current draw, so
    # they are computed once. No OLS estimate is formed, so a singular X'X
    # (e.g. p > n) is fine as long as B0 is invertible.
    Bn_inv <- B0_inv + XTX
    Bn <- solve(Bn_inv)
    Bn <- (Bn + t(Bn)) / 2  # remove round-off asymmetry; rmvnorm() checks it
    bn <- Bn %*% (B0_inv_b0 + XTy)
    an <- a0 + n
    dn <- drop(d0 + crossprod(y) + crossprod(b0, B0_inv_b0) -
                   crossprod(bn, Bn_inv %*% bn))

    beta_samples <- matrix(NA_real_, nrow = m, ncol = ncol(X))
    sigma2_samples <- numeric(m)
    sigma2 <- runif(1, 0.5, 1.5)  # arbitrary starting value
    for (i in seq_len(m)) {
        # beta | sigma2, y ~ N(bn, sigma2 * Bn), using the previous sigma2.
        beta_samples[i, ] <- mvtnorm::rmvnorm(1, c(bn), sigma2 * Bn)
        # sigma2 | y ~ IG(an / 2, dn / 2), drawn as 1 / Gamma(shape, rate),
        # which is how MCMCpack::rinvgamma() samples.
        sigma2 <- 1 / rgamma(1, shape = an / 2, rate = dn / 2)
        sigma2_samples[i] <- sigma2
    }

    # `data` previously referred to an undefined local (it resolved to
    # utils::data or a stray global); return the actual inputs instead.
    list(data = list(X = X, y = y), beta = beta_samples, sigma2 = sigma2_samples)
}


# Closed-form posterior mean of beta in the conjugate linear regression with
# prior beta | sigma2 ~ N(b0, sigma2 * B0):
#   bn = (B0^{-1} + X'X)^{-1} (B0^{-1} b0 + X'y)
# The linear system is solved directly rather than going through the OLS
# estimate solve(X'X, X'y). The OLS estimate does not exist when X'X is
# singular (p > n or collinear columns), even though bn does.
#
# b0 : prior mean (length p); B0 : p x p prior covariance scale.
# X  : n x p design matrix;    y  : response of length n.
# Returns bn as a p x 1 matrix.
posterior_mean_analytical <- function(b0, B0, X, y) {
    posterior_precision <- solve(B0) + crossprod(X)
    solve(posterior_precision, solve(B0, b0) + crossprod(X, y))
}


# Automatic-differentiation version of the Gibbs sampler: the sampler body is
# passed to ADtools::auto_diff() with `wrt = "b0"`, so the result presumably
# carries derivatives of the draws with respect to the prior mean b0.
# TODO confirm the exact return structure against the ADtools documentation.
#
# b0, B0, a0, d0 : prior hyperparameters; X, y : data; m : iterations.
# NOTE(review): the outer `m` has no default (the inner function's `m = 100`
# is never used because `at` always supplies m). Calling without `m` will
# therefore error when `at = list(..., m = m)` is evaluated. Confirm whether
# a default of 100 was intended.
AD_gibbs_sampler <- function(b0, B0, a0, d0, X, y, m) {
    ADtools::auto_diff(
        # Sampler body evaluated by ADtools at the values given in `at`.
        # It uses solve(B0) %*% b0 rather than solve(B0, b0) — presumably
        # because those are the forms ADtools supports; TODO confirm.
        function(b0, B0, a0, d0, X, y, m = 100) {
            n <- length(y)
            p <- ncol(X)  # NOTE(review): unused
            XTX <- t(X) %*% X
            yty <- t(y) %*% y
            # NOTE(review): bhat is unused, and this solve() fails when X'X is
            # singular even though the sampler itself would not need it.
            bhat <- solve(XTX, t(X) %*% y)

            beta_samples <- vector("list", m)
            sigma2_samples <- vector("list", m)
            sigma2 <- runif(1, 0.5, 1.5)  # starting value for sigma2
            for (i in 1:m) {
                # Bn, bn, an and dn do not depend on the loop state; they are
                # recomputed every iteration.
                Bn <- solve(solve(B0) + XTX)
                bn <- Bn %*% (solve(B0) %*% b0 + t(X) %*% y)
                # beta ~ N(bn, sigma2 * Bn) using the previous sigma2.
                beta <- ADtools::rmvnorm0(1, bn, sigma2 * Bn)

                an <- a0 + n
                dn <- as.vector(d0 + yty + t(b0) %*% solve(B0) %*% b0 -
                                    t(bn) %*% solve(Bn) %*% bn)
                # sigma2 ~ InvGamma(an / 2, dn / 2), drawn as the reciprocal of
                # a Gamma(an / 2, 2 / dn) draw (assuming rgamma0's third
                # argument is the scale).
                # sigma2 <- rinvgamma(an / 2, dn / 2)  # See Note [1]
                sigma2 <- 1 / ADtools::rgamma0(1, an / 2, 2 / dn)

                beta_samples[[i]] <- beta
                sigma2_samples[[i]] <- sigma2
            }

            beta_samples <- do.call(rbind, beta_samples)
            # Each sigma2 draw is wrapped as a 1 x 1 matrix before rbind, so
            # the result is stacked into an m-row matrix.
            sigma2_samples <- do.call(rbind, lapply(sigma2_samples, as.matrix))
            list(beta = beta_samples, sigma2 = sigma2_samples)
        },
        wrt = "b0",
        at = list(b0 = b0, B0 = B0, a0 = a0, d0 = d0, X = X, y = y, m = m)
    )
    # Note [1]  (theta is a scale parameter in both Gamma and InvGamma)
    #     X ~ Gamma(k, theta)     <=> 1 / X ~ InvGamma(k, 1 / theta)
    # 1 / X ~ Gamma(k, theta)     <=>     X ~ InvGamma(k, 1 / theta)
    # 1 / X ~ Gamma(k, 1 / theta) <=>     X ~ InvGamma(k, theta)
}




# Utility functions
# Apply `f` to each row of matrix `X`. Each row is passed as a plain vector,
# because `[` drops dimensions. Results are returned as an unnamed list with
# one element per row. A matrix with zero rows yields an empty list.
# Equivalent to this file's map(seq_len(nrow(X)), ...).
# NOTE(review): `...` is forwarded to Map() as extra vectors iterated
# alongside the row index. The per-row closure accepts only the index, so any
# non-empty `...` errors. It was presumably meant for `f`; confirm.
map_row <- function(X, f, ...) {
    # seq_len() rather than 1:nrow(X), which is c(1, 0) for zero rows.
    Map(\(i) f(X[i, ]), seq_len(nrow(X)), ...)
}

# Apply `f` to each column of matrix `X`. Each column is passed as a plain
# vector. Results are returned as an unnamed list with one element per column.
# A matrix with zero columns yields an empty list.
# Equivalent to this file's map(seq_len(ncol(X)), ...).
# NOTE(review): `...` is forwarded to Map() as extra vectors. The per-column
# closure accepts only the index, so any non-empty `...` errors; confirm
# intent.
map_col <- function(X, f, ...) {
    # seq_len() rather than 1:ncol(X), which is c(1, 0) for zero columns.
    Map(\(j) f(X[, j]), seq_len(ncol(X)), ...)
}

# Like map_col(), but `f` is called as f(column, j) so it also receives the
# column index. Returns an unnamed list with one element per column; a matrix
# with zero columns yields an empty list.
# NOTE(review): `...` is forwarded to Map() as extra vectors. The closure
# accepts only the index, so any non-empty `...` errors; confirm intent.
iter_col <- function(X, f, ...) {
    # seq_len() rather than 1:ncol(X), which is c(1, 0) for zero columns.
    Map(\(j) f(X[, j], j), seq_len(ncol(X)), ...)
}

# Data-first wrapper around Map(): applies `f` element-wise over `X`. Any
# vectors in `...` are iterated in parallel with `X` (zip semantics); they are
# not passed to `f` as constant arguments. Always returns a list.
map <- function(X, f, ...) {
    mapply(FUN = f, X, ..., SIMPLIFY = FALSE)
}

# Data-first wrapper around Reduce(). Extra arguments (init, right,
# accumulate, simplify) are passed straight through to Reduce().
reduce <- function(X, f, ...) {
    Reduce(f = f, x = X, ...)
}

# Choose an lapply-compatible function. Returns base lapply when `parallel` is
# FALSE. Otherwise returns parallel::mclapply with `mc.cores` (and any extra
# arguments in `...`) pre-filled via purrr::partial().
# Note: mclapply() relies on forking and does not support mc.cores > 1 on
# Windows.
set_lapply <- function(parallel = FALSE, mc.cores = 4, ...) {
    if (!parallel) {
        lapply
    } else {
        purrr::partial(parallel::mclapply, mc.cores = mc.cores, ...)
    }
}


# Build a binary combiner function(x, y). It first combines the objects with
# join_data(x, y). Then, for every attribute name in `keys`, it sets that
# attribute on the result to join_attr(attr(x, key), attr(y, key)). Because
# the combiner is binary, it can be used as the `f` of Reduce()/reduce().
join_data_with_attr <- function(join_data, join_attr, keys) {
    function(x, y) {
        merge_attr <- function(acc, key) {
            attr(acc, key) <- join_attr(attr(x, key), attr(y, key))
            acc
        }
        Reduce(merge_attr, keys, join_data(x, y))
    }
}

# Mean squared error between `x` and `y` (recycled by the usual R arithmetic
# rules).
mse <- function(x, y) {
    squared_error <- (x - y)^2
    mean(squared_error)
}

# (Copied from `bayesian_linear_regression.R` so that this file is standalone)
# NOTE(review): the burn_ins == 0 guard below may need porting back there.
# This function computes a posterior statistic from the MCMC draws.
#
# draws    : vector of MCMC draws.
# stat_fun : summary applied to the retained draws (default mean).
# burn_ins : number of leading draws to discard; defaults to 10% of the
#            draws, rounded.
mcmc_posterior <- function(draws, stat_fun = mean, burn_ins) {
    if (missing(burn_ins)) burn_ins <- round(length(draws) / 10)
    # tail(x, -0) is tail(x, 0), i.e. empty, so only trim when burn_ins > 0.
    kept <- if (burn_ins > 0) tail(draws, -burn_ins) else draws
    stat_fun(kept)
}

# This function is for computing the variance of the posterior mean estimator.
# After discarding burn-in, it splits the draws into `block_size` interleaved
# subsequences. Subsequence k holds draws k, k + block_size, k + 2 * block_size,
# and so on. `stat_fun` is applied to each one, and the spread of the results
# gauges the estimator's variability.
#
# draws      : vector of MCMC draws.
# stat_fun   : statistic applied to each subsequence.
# burn_ins   : leading draws to discard; defaults to 10% of the draws, rounded.
# block_size : number of interleaved subsequences (the thinning stride); must
#              not exceed the number of post-burn-in draws.
# Returns the sapply()-simplified results, one per subsequence.
mcmc_slicing <- function(draws, stat_fun, burn_ins, block_size = 50) {
    if (missing(burn_ins)) {
        burn_ins <- round(length(draws) / 10)
    }
    # tail(x, -0) returns an empty vector, so only trim when burn_ins > 0.
    if (burn_ins > 0) {
        draws <- tail(draws, -burn_ins)
    }
    n_kept <- length(draws)
    # Otherwise seq() below fails with an opaque "wrong sign in 'by'" error.
    if (n_kept < block_size) {
        stop("mcmc_slicing: need at least `block_size` (", block_size,
             ") draws after burn-in, got ", n_kept, call. = FALSE)
    }

    sapply(seq_len(block_size), function(start_index) {
        s <- seq(start_index, n_kept, by = block_size)
        stat_fun(draws[s])
    })
}

# Posterior-mean estimate from interleaved subsequences of the draws. Returns
# the mean of the per-subsequence means, with their standard deviation
# attached as attribute "sigma". `...` is passed on to mcmc_slicing()
# (e.g. burn_ins, block_size).
mcmc_sliced_mean_with_sd <- function(draws, ...) {
    slice_means <- mcmc_slicing(draws, stat_fun = mean, ...)
    structure(mean(slice_means), sigma = sd(slice_means))
}




# Check that the MCMC procedure is implemented correctly
# Sanity check for gibbs_sampler(). It simulates one dataset with a fixed seed
# (this resets the global RNG state) and prints it. Then, for 100 random prior
# settings (b0, B0), it compares the MCMC posterior mean of beta with the
# closed-form posterior mean. Returns the 100 resulting MSEs.
check_gibbs_sampler <- function() {
    message("Check that the MCMC procedure is implemented correctly")
    set.seed(1234)

    # Data-generating process:
    #   y ~ N(X * beta, sigma2 * In)
    #   beta | sigma2 ~ N(b0, sigma2 * B0)
    #   sigma2 ~ IG(a0 / 2, d0 / 2)
    p <- 2
    n <- 50
    b0 <- rnorm(p)
    B0 <- diag(runif(p, 0.5, 1.5))
    a0 <- 30
    d0 <- 3
    sim <- sim_bayesian_linear_regression(n, p, b0, B0, a0, d0)
    print(sim)

    # One comparison at a freshly drawn prior mean/covariance.
    compare_at_random_prior <- function(trial) {
        prior_mean <- rnorm(p)
        prior_cov <- diag(runif(p, 0.5, 1.5))
        draws <- gibbs_sampler(prior_mean, prior_cov, a0, d0, sim$X, sim$y)
        analytic <- c(posterior_mean_analytical(prior_mean, prior_cov, sim$X, sim$y))
        mse(colMeans(draws$beta), analytic)
    }

    message("Showing the 100 MSEs of the MCMC posterior estimates against the analytic formula")
    vapply(seq_len(100), compare_at_random_prior, numeric(1))
}
