#' Gaussian Process with derivative (SE kernel)
#'
#' Fits the hyperparameters of a squared-exponential kernel to function
#' values and their gradients by minimising the negative log-likelihood of
#' the stacked observation vector with \code{optim}.
#'
#' @param X A numeric matrix; the data.
#' @param y A numeric vector; the data.
#' @param dy A numeric matrix; the derivative of y with respect to X.
#' If y is n x 1, X is n x p, then dy should be n by p, having the same
#' shape as X. Each column is a partial derivative, and each row is a
#' gradient.
# #' @param kernel A kernel object, which is a named list of a list of
# #' named parameters and an uninitialised kernel function.
#' @param sigma A positive number; the noise of the data.
#' @param sigma_d A positive number; the noise of the derivative.
#' Mostly needed as a means of regularisation to combat numerical issues.
#' @param ... Optional argument to pass to \code{optim}.
#'
#' @return A named list with elements \code{train_X}, \code{train_y},
#' \code{train_dy} (the transpose of \code{dy}), \code{kernel},
#' \code{sigma}, \code{sigma_d}, \code{parameters} (the optimised kernel
#' parameters, as a named list) and \code{optim_log} (the raw output of
#' \code{optim}).
#'
#' @examples
#' \dontrun{
#' # Example 1
#' library(gptools2)
#' x <-  as.matrix(runif(10, -10, 10))
#' f <- sin
#' df <- cos
#' y <- f(x)
#' dy <- df(x)
#' model <- gp_d(x, y, dy, sigma = 1e-6, sigma_d = 1e-6)
#' pred_y <- predict_gp_d(model, x)
#' cbind(y, pred_y$mean, err = y - pred_y$mean)
#'
#' test_x <- as.matrix(runif(10, -10, 10))
#' test_y <- f(test_x)
#' pred_y <- predict_gp_d(model, test_x)
#' cbind(test_y, pred_y$mean, err = test_y - pred_y$mean)
#'
#' plot(x, y, pch = 19, xlim = c(-10, 10))
#' lines(s <- seq(-10, 10, 0.1), f(s), lty = 2)
#' points(test_x, test_y, pch = 19, col = 'blue')
#' points(test_x, pred_y$mean, col = 'green', cex = 3)
#'
#'
#' # Example 2
#' library(gptools2)
#' X <- as.matrix(expand.grid(
#'     seq(-5, 5, length.out = 6),
#'     seq(-5, 5, length.out = 6)
#' ))
#' y <- sin(X[,1]) + X[,2]
#' dy <- cbind(cos(X[,1]), 1)
#' model <- gp_d(X, y, dy, sigma = 1e-6, sigma_d = 1e-6,
#'               control = list(trace = 3))
#' pred_y <- predict_gp_d(model, X)
#' cbind(y, pred_y$mean, err = y - pred_y$mean)
#'
#' test_X <- as.matrix(expand.grid(
#'     runif(10, -4.5, 4.5),
#'     runif(10, -4.5, 4.5)  # avoid boundary effect
#' ))
#' test_y <- sin(test_X[,1]) + test_X[,2]
#' pred_y <- predict_gp_d(model, test_X)
#' cbind(test_y, pred_y$mean, err = test_y - pred_y$mean)
#'
#' # Plots
#' plot(X[,1] + X[,2], y, type = 'n',
#'      xlab = "x1 + x2", main = "f(x1, x2) = sin(x1) + x2")
#' local(for (x2 in c(test_X[,2])) {
#'     x1 <- seq(-5, 5, length.out = 40)
#'     y <- sin(x1) + x2
#'     lines(x1 + x2, y, lwd = 1, col = 1 + abs(x2))
#' })
#' points(rowSums(test_X), test_y, cex = 1, col = 1 + abs(test_X[,2]), pch = 19)
#' points(rowSums(test_X), pred_y$mean, cex = 2, col = 1 + abs(test_X[,2]))
#' }
#' @export
gp_d <- function(X, y, dy, sigma = 0, sigma_d = 0, ...) {
    kernel <- squared_exponential_d()
    # Transpose so that as.numeric() below flattens one gradient (one row of
    # the original dy) at a time. The transposed matrix is what is returned
    # as train_dy.
    dy <- t(dy)
    # Stacked targets: function values first, then the flattened gradients.
    # Independent of the kernel parameters, so built once outside NLL.
    y_dy <- c(y, as.numeric(dy))

    # Negative log-likelihood of y_dy as a function of the unlisted kernel
    # parameters (the form optim works with).
    NLL <- function(param) {
        param_named <- relist(param, kernel$param)
        # The generic closures stored in `kernel` are bypassed in favour of
        # the SE-specific helpers below. Here `k` receives whatever kcov()
        # hands it -- presumably squared distances; verify against kcov().
        k <- do.call(
            \(sigma, l) \(X) sigma^2 * exp(X / (-2 * l^2)),
            param_named
        )

        # Value-value block, with observation noise on the diagonal.
        K11_A <- kcov(X, X, k)
        K11_A_noisy <- K11_A + sigma^2 * diag(nrow(K11_A))

        # Value-derivative block and its transpose.
        l2 <- param_named$l^2
        K11_B <- kcov_se_df_dx2(X, X, K11_A, l2)
        K11_C <- t(K11_B)

        # Derivative-derivative block. The jitter is sigma_d^2 so that the
        # matrix optimised here is the same one predict_gp_d() rebuilds.
        K11_D <- kcov_se_d2f_dx1_dx2(X, X, K11_A, l2)
        K11_D_noisy <- K11_D + sigma_d^2 * diag(nrow(K11_D))

        K <- join(K11_A_noisy, K11_B, K11_C, K11_D_noisy)
        alpha <- solve(K, y_dy)
        0.5 * (t(y_dy) %*% alpha + log_det(K))  # ignore 0.5 * n * log(2 * pi)
    }

    parameters <- optim(par = unlist(kernel$param), NLL, ...)
    list(
        train_X = X, train_y = y, train_dy = dy,
        kernel = kernel, sigma = sigma, sigma_d = sigma_d,
        parameters = relist(parameters$par, kernel$param),
        optim_log = parameters
    )
}

# Assemble a 2 x 2 block matrix
#
#     [ A  B ]
#     [ C  D ]
#
# from its four blocks.
join <- function(A, B, C, D) {
    upper <- cbind(A, B)
    lower <- cbind(C, D)
    rbind(upper, lower)
}

# Bundle a kernel's default parameters with its three closure factories
# (the kernel itself, its derivative with respect to x2, and its mixed
# second derivative) into a list of class "kernel_d".
kernel_template_d <- function(param, f, df_dx2, d2f_dx1_dx2) {
    kernel <- list(
        param = param,
        kern_fun = f,
        dk_dx2 = df_dx2,
        dk2_dx1_dx2 = d2f_dx1_dx2
    )
    class(kernel) <- c("kernel_d", "list")
    kernel
}


#' @rdname se_kernel
#' @export
squared_exponential_d <- function(sigma = 1, l = 1) {
    # Each factory below takes the hyperparameters and returns a function of
    # two points (x1, x2).

    # k(x1, x2) = sigma^2 * exp(-||x1 - x2||^2 / (2 * l^2))
    make_kernel <- function(sigma, l) {
        function(x1, x2) {
            dist <- norm(x1 - x2, "2")
            sigma^2 * exp(dist^2 / (-2 * l^2))
        }
    }

    # dk/dx2 = k * (x1 - x2)' / l^2, returned as a 1 x p row.
    make_dk_dx2 <- function(sigma, l) {
        function(x1, x2) {
            make_kernel(sigma, l)(x1, x2) / l^2 * t(x1 - x2)
        }
    }

    # d2k/(dx1 dx2) = (I - (x1 - x2)(x1 - x2)' / l^2) * k / l^2, a p x p block.
    make_d2k_dx1_dx2 <- function(sigma, l) {
        function(x1, x2) {
            delta <- x1 - x2
            eye <- diag(length(x1))
            outer_delta <- delta %*% t(delta)
            (eye - outer_delta / l^2) * make_kernel(sigma, l)(x1, x2) / l^2
        }
    }

    kernel_template_d(
        param = list(sigma = sigma, l = l),
        f = make_kernel,
        df_dx2 = make_dk_dx2,
        d2f_dx1_dx2 = make_d2k_dx1_dx2
    )
}

# Pure-R covariance matrix: entry (i, j) is k(X1[i, ], X2[j, ]).
#
# X1, X2: matrices whose rows are points.
# k:      function of two points returning a single number.
#
# Returns an nrow(X1) x nrow(X2) numeric matrix. Zero-row inputs give a
# matrix with a zero dimension; seq_len() is used instead of 1:n, which
# would iterate over c(1, 0) and index out of bounds.
kcov_R <- function(X1, X2, k) {
    n1 <- nrow(X1)
    n2 <- nrow(X2)
    m0 <- matrix(0, n1, n2)
    for (i in seq_len(n1)) {
        for (j in seq_len(n2)) {
            m0[i, j] <- k(X1[i, ], X2[j, ])
        }
    }
    m0
}

# Block covariance matrix: f(X1[i, ], X2[j, ]) is evaluated for every pair of
# rows and the resulting blocks are tiled, cbind-ing across j and rbind-ing
# across i.
#
# X1, X2: matrices whose rows are points.
# f:      function of two points returning a block (all blocks must have
#         shapes that cbind/rbind can tile).
#
# Returns the tiled matrix, or NULL when either input has zero rows (there
# are no blocks to bind). seq_len() is used instead of 1:n, which would
# iterate over c(1, 0) and index out of bounds.
kcov_d <- function(X1, X2, f) {
    rows_1 <- seq_len(nrow(X1))
    rows_2 <- seq_len(nrow(X2))
    block_rows <- lapply(rows_1, function(i) {
        blocks <- lapply(rows_2, function(j) {
            f(X1[i, ], X2[j, ])
        })
        do.call(cbind, blocks)
    })
    do.call(rbind, block_rows)
}

# Data-first wrapper around base Map(): applies f to the elements of x (and,
# in parallel, to any further vectors passed through `...`) and returns the
# results as a list.
map <- function(x, f, ...) {
    results <- Map(f, x, ...)
    results
}


#' Predict using the Gaussian Process (with derivative)
#'
#' @param model Output from \code{gp_d}; the model object.
#' @param new_X A numeric matrix; the points to predict.
#'
#' @return A named list containing all the input and the
#' predictive mean vector and covariance matrix: \code{model},
#' \code{new_X}, \code{mean} (a one-column matrix with one row per row of
#' \code{new_X}) and \code{covariance}.
#' @export
predict_gp_d <- function(model, new_X) {
    X <- model$train_X
    param <- model$parameters
    sigma <- model$sigma
    sigma_d <- model$sigma_d
    # Stacked targets: function values first, then the flattened gradients.
    # train_dy is stored transposed by gp_d(), so as.numeric() yields one
    # gradient at a time.
    y_dy <- c(model$train_y, as.numeric(model$train_dy))

    # The generic closures stored in model$kernel are bypassed in favour of
    # the SE-specific helpers below. Here `k` receives whatever kcov() hands
    # it -- presumably squared distances; verify against kcov().
    k <- do.call(\(sigma, l) \(X) sigma^2 * exp(X / (-2 * l^2)), param)
    l2 <- param$l^2

    # Training covariance, block by block.
    # Value-value block, with observation noise on the diagonal.
    K11_A <- kcov(X, X, k)
    K11_A_plus_I <- K11_A + sigma^2 * diag(nrow(K11_A))

    # Value-derivative block and its transpose.
    K11_B <- kcov_se_df_dx2(X, X, K11_A, l2)
    K11_C <- t(K11_B)

    # Derivative-derivative block, with derivative noise on the diagonal.
    # NOTE(review): the training objective in gp_d() should add the same
    # term; confirm the two stay in sync.
    K11_D <- kcov_se_d2f_dx1_dx2(X, X, K11_A, l2)
    K11_D_plus_I <- K11_D + sigma_d^2 * diag(nrow(K11_D))

    K_11 <- join(K11_A_plus_I, K11_B, K11_C, K11_D_plus_I)

    # Cross-covariance between the new points and the stacked training
    # observations (values, then derivatives).
    K_21_L <- kcov(new_X, X, k)
    K_21_R <- kcov_se_df_dx2(new_X, X, K_21_L, l2)
    K_21 <- cbind(K_21_L, K_21_R)
    K_12 <- t(K_21)

    # Prior covariance among the new points (no noise term added).
    K_22 <- kcov(new_X, new_X, k)

    list(
        model = model,
        new_X = new_X,
        mean = K_21 %*% solve(K_11, y_dy),
        covariance = K_22 - K_21 %*% solve(K_11, K_12)
    )
}
