#' Inference of the LAID model using Gibbs Sampler
#'
#' Each iteration first draws the coefficient vector from a Normal
#' distribution given the current precision matrix, then draws a new
#' precision matrix with `rWishart0()` given the residuals implied by that
#' coefficient draw.
#'
#' @param data A dataframe; the source data. It must contain the columns
#'   `w1`, `w2`, `w3` and `w4`, and it is also passed to `format_SUR_data()`.
#' @param d0 Prior location parameters (Normal prior).
#' @param D0 Prior covariance (Normal prior).
#' @param Wv Prior degrees of freedom (Wishart prior).
#' @param WEsc0 Prior scale matrix (Wishart prior).
#' @param Sig Starting value of the covariance of the normal prior. Only its
#'   inverse is used, to start the first iteration.
#' @param num_iter Number of MCMC Iterations.
#' @param method Extra argument for sampling functions, 'base'  or 'inv_tf'.
#'   It is forwarded to `rWishart0()`.
#' @return A list with two elements, `delta` and `Sigma`. Each is a list of
#'   length `num_iter` holding the coefficient draw and the covariance draw
#'   of the corresponding iteration.
#' @export
# This function removes the inner for-loops for some gain in efficiency; this
# is also needed to let rWishart0 handle dual input.
Gibbs_sampler_2 <- function(data, d0, D0, Wv, WEsc0, Sig, num_iter = 1e4, method = "inv_tf") {
    Ch <- format_SUR_data(data)
    s <- with(data, cbind(w1, w2, w3, w4))
    num_eq <- dim(Ch)[1]
    num_coef <- dim(Ch)[2]
    N <- dim(Ch)[3]

    # Extract the N slices of 'Ch' once and reuse them for both layouts.
    Ch_slices <- purrr::map(seq_len(N), ~Ch[, , .x])
    Ch_mat <- do.call(rbind, Ch_slices)
    tCh_mat <- t(do.call(cbind, Ch_slices))
    s_vec <- as.numeric(t(s))
    # The re-arrangement index is the same in every iteration, so build it once.
    blk_t_dual_ind <- block_transpose_dual_ind(N, num_coef, num_eq)

    d0 <- as.matrix(d0)
    d <- d0
    # Helper variables
    inv_D0 <- solve(D0)
    inv_D0_times_d0 <- inv_D0 %*% d0
    inv_WEsc0 <- solve(WEsc0)
    inv_Sig <- solve(Sig)
    # Storage variables
    delta <- vector("list", num_iter)
    Sigma <- vector("list", num_iter)
    tic <- Sys.time() # Time

    # 'max' must be strictly greater than 'min', so start the bar at 0 and
    # guard against 'num_iter' being 0.
    pb <- txtProgressBar(min = 0, max = max(num_iter, 1), initial = 0, style = 3)
    # Make sure the bar is closed even if an iteration fails.
    on.exit(close(pb), add = TRUE)
    for (i in seq_len(num_iter)) {
        # The block height is the number of coefficients per observation
        # (second dimension of 'Ch'), not a fixed constant.
        tCh_times_inv_Sig <- block_transpose(tCh_mat %*% inv_Sig, num_coef, blk_t_dual_ind)
        AuxSum <- tCh_times_inv_Sig %*% Ch_mat
        AuxSum1 <- tCh_times_inv_Sig %*% s_vec

        # NOTE(review): the rounding presumably removes floating-point
        # asymmetry before the matrix is passed to chol0() - confirm.
        SigDel <- round(solve(inv_D0 + AuxSum), 15)
        MeanDel <- SigDel %*% (AuxSum1 + inv_D0_times_d0)
        d <- MeanDel + chol0(SigDel) %*% rnorm(length(d0))
        # d <- as.vector(rmvnorm0(1, MeanDel, SigDel))

        # Residuals arranged with one column per observation, so that the
        # cross-product sums the outer products of the residual vectors.
        resid_mat <- matrix(s_vec - Ch_mat %*% d, nrow = ncol(s), byrow = FALSE)
        AuxSum2 <- tcrossprod(resid_mat)
        WScale <- solve(inv_WEsc0 + AuxSum2)
        vNew <- Wv + N
        inv_Sig <- rWishart0(vNew, WScale, method = method)
        Sig <- solve(inv_Sig)

        delta[[i]] <- d
        Sigma[[i]] <- Sig
        setTxtProgressBar(pb, value = i)
    }
    # End the progress-bar line before printing the timing.
    close(pb)

    toc <- Sys.time() # Time
    print(toc - tic)  # Total time
    list(delta = delta, Sigma = Sigma)
}

# One-way block transpose (long column to long row)
#
# Returns the linear (column-major) indices of the entries of a matrix with
# `block_nrow * num_block` rows and `block_ncol` columns, visited block by
# block: blocks are stacked vertically, and within a block the entries are
# listed column by column. The result is a double vector of length
# `num_block * block_nrow * block_ncol`; it is empty when any of the three
# sizes is zero.
block_transpose_dual_ind <- function(num_block, block_nrow, block_ncol) {
    total_nrow <- block_nrow * num_block

    row_ind <- seq_len(block_nrow)
    # Moving one column to the right skips a full column of the tall matrix.
    col_offset <- (seq_len(block_ncol) - 1) * total_nrow
    # Moving one block down skips the rows of the preceding blocks.
    blk_offset <- (seq_len(num_block) - 1) * block_nrow

    # outer() results flatten column-major, so the first argument varies
    # fastest: rows within a column, columns within a block, then blocks.
    within_block <- as.numeric(outer(row_ind, col_offset, `+`))
    as.numeric(outer(within_block, blk_offset, `+`))
}

# Re-arrange a matrix of vertically stacked blocks, each `block_nrow` rows
# high, into a matrix with `block_nrow` rows in which the blocks sit side by
# side.
#
# `x` is either an ordinary matrix or an object of class "dual" (with slots
# `x` and `dx`). For a dual object the value slot is re-arranged and the rows
# of the derivative slot are permuted in the same way.
# `dual_ind` is the index vector produced by `block_transpose_dual_ind()`; it
# is derived from the shape of `x` when not supplied.
block_transpose <- function(x, block_nrow, dual_ind) {
    is_dual <- inherits(x, "dual")

    if (missing(dual_ind)) {
        # The block count and width are only needed to derive the index.
        block_ncol <- if (is_dual) ncol(x@x) else ncol(x)
        num_block <- nrow(x) / block_nrow
        dual_ind <- block_transpose_dual_ind(num_block, block_nrow, block_ncol)
    }

    # Pick the entries in block order and fold them into `block_nrow` rows.
    blk_t <- function(m) {
        y <- m[dual_ind]
        dim(y) <- c(block_nrow, length(m) / block_nrow)
        y
    }

    if (is_dual) {
        x@x <- blk_t(x@x)
        # drop = FALSE keeps the derivative two-dimensional when there is only
        # one column of derivatives.
        x@dx <- x@dx[dual_ind, , drop = FALSE]
        x
    } else {
        blk_t(x)
    }
}


#' Takes the diagonal vector instead of a matrix for 'D0'
#'
#' Wrapper around `Gibbs_sampler_2()`: the prior covariance is built as a
#' diagonal matrix whose diagonal is `exp(log_D0)`, and every other argument
#' is passed through unchanged.
#'
#' @inheritParams Gibbs_sampler_2
#' @param log_D0 Log of the diagonal of the prior covariance matrix (Normal prior).
#' @return The value returned by `Gibbs_sampler_2()`.
#' @export
Gibbs_sampler_2b <- function(data, d0, log_D0, Wv, WEsc0, Sig, num_iter = 1e4, method = "inv_tf") {
    # NOTE(review): for a plain numeric of length one, base diag() builds an
    # identity matrix instead of a 1 x 1 matrix - confirm that 'log_D0' always
    # has more than one element.
    prior_cov <- diag(as.vector(exp(log_D0)))
    Gibbs_sampler_2(
        data, d0, prior_cov, Wv, WEsc0, Sig,
        num_iter, method
    )
}


#' Derivative of Gibbs sampler
#'
#' Runs `Gibbs_sampler_2b()` through `ADtools::auto_diff()`, differentiating
#' with respect to `d0` and `log_D0`.
#'
#' @param data A data frame; the source data.
#' @param d0 Prior location parameters (Normal prior).
#' @param log_D0 Log of the diagonal of the prior covariance matrix (Normal prior).
#' @param Wv Prior degrees of freedom (Wishart prior).
#' @param WEsc0 Prior scale matrix (Wishart prior).
#' @param Sig Starting value of the covariance of the normal prior.
#' @param num_iter Number of MCMC Iterations.
#' @param method Extra argument for sampling functions, 'base'  or 'inv_tf'.
#' @return The value returned by `ADtools::auto_diff()`.
#' @export
AD_Gibbs <- function(data, d0, log_D0, Wv, WEsc0, Sig, num_iter = 1e4, method = "inv_tf") {
    # Point at which the sampler is evaluated; every argument is forwarded.
    eval_point <- list(
        data = data,
        d0 = d0,
        log_D0 = log_D0,
        Wv = Wv,
        WEsc0 = WEsc0,
        Sig = Sig,
        num_iter = num_iter,
        method = method
    )
    # The diagonal-input variant is used (rather than Gibbs_sampler_2) so that
    # 'log_D0' can be one of the differentiation variables.
    ADtools::auto_diff(Gibbs_sampler_2b, wrt = c("d0", "log_D0"), at = eval_point)
}

# Collect the per-iteration draws of a sampler result into single arrays.
#
# `AD_res` is a list with elements `delta` and `Sigma`, each a list of draws.
# The `delta` draws are bound column-wise and then transposed, which gives
# one row per draw when each draw is a column vector. The `Sigma` draws are
# combined with `abind::abind(..., rev.along = 3)`.
# NOTE(review): for matrix-valued draws this presumably puts the draw index
# on a new leading dimension - confirm against the abind documentation.
extract_AD <- function(AD_res) {
    delta_by_col <- do.call(cbind, AD_res$delta)
    Sigma_args <- list(AD_res$Sigma, rev.along = 3)

    list(
        delta = t(delta_by_col),
        Sigma = do.call(abind::abind, Sigma_args)
    )
}

# Average the draws of a sampler result after discarding the burn-in.
#
# `AD_res` is a list with elements `delta` and `Sigma`, each a list of draws
# that can be added together and divided by a number.
# `burn_ins` is the number of leading draws to discard; when missing it
# defaults to one tenth of the number of `delta` draws (rounded). A value of
# zero keeps every draw.
# Returns a list with the element-wise means `delta` and `Sigma`, and stops
# with an error when no draw is left after the burn-in.
tidy_AD <- function(AD_res, burn_ins) {
    if (missing(burn_ins)) {
        burn_ins <- round(length(AD_res$delta) / 10)
    }

    # tail(x, -0) is tail(x, 0), i.e. empty, so zero needs its own branch.
    burn <- function(x) {
        if (burn_ins == 0) x else utils::tail(x, -burn_ins)
    }

    average <- function(draws, label) {
        if (length(draws) == 0L) {
            stop("No '", label, "' draws left after discarding the burn-in.",
                 call. = FALSE)
        }
        Reduce(`+`, draws) / length(draws)
    }

    list(
        delta = average(burn(AD_res$delta), "delta"),
        Sigma = average(burn(AD_res$Sigma), "Sigma")
    )
}

# Stack the derivative slots of the averaged `delta` and `Sigma` into one
# matrix with labelled rows.
#
# `tidy_res` is a list with elements `delta` and `Sigma`, each an object with
# slots `x` and `dx`. Rows from `delta` are named "delta_1", "delta_2", ...;
# rows from `Sigma` are named "Sigma_" followed by the labels produced by
# `matrix_labels()` for the number of rows of the `Sigma` value.
get_Jacobian <- function(tidy_res) {
    delta_dx <- tidy_res$delta@dx
    Sigma_dx <- tidy_res$Sigma@dx

    delta_names <- paste("delta", seq(nrow(delta_dx)), sep = "_")
    Sigma_names <- paste("Sigma", matrix_labels(nrow(tidy_res$Sigma@x)), sep = "_")

    rbind(
        set_rownames(delta_dx, delta_names),
        set_rownames(Sigma_dx, Sigma_names)
    )
}

# Labels for the entries of an n x n matrix in column-major order: each label
# is the row index immediately followed by the column index, e.g. "11", "21",
# "12", "22" for n = 2.
matrix_labels <- function(n) {
    idx <- seq(n)
    # The row index varies fastest, the column index slowest.
    row_part <- rep(idx, times = length(idx))
    col_part <- rep(idx, each = length(idx))
    paste0(row_part, col_part)
}
