# Reviewer comment 5: GP (grid) vs GPAD (derivative)
source("supplementary_analysis.R")


# Generate data from the DGP
library(gptools2)
library(ADtools)
# Fixed sizes and hyper-parameters handed to the simulator and the samplers
p <- 30
n <- 100
a0 <- 30
d0 <- 3

# Seed before the first random draw. The draws below must keep this order
# (b0 magnitudes, b0 signs, diagonal of B0, simulator) to reproduce the stream.
SEED <- 12345
set.seed(SEED)

# b0: p values with magnitude uniform on [0.8, 1.2] and a random sign each
b0 <- local({
    magnitude <- runif(p, 0.8, 1.2)
    direction <- sample(c(-1, 1), p, replace = TRUE)
    magnitude * direction
})

# B0: p x p diagonal matrix with diagonal entries uniform on [0.5, 1.0]
B0 <- diag(runif(p, 0.5, 1.0))

# Simulate the data set and keep its X and y components
data <- sim_bayesian_linear_regression(n, p, b0, B0, a0, d0)
X <- data$X
y <- data$y
message(sprintf("GP (grid) vs GPAD (derivative): p=%s (SEED=%s)", p, SEED))


# Generate training and testing data
n_eval <- 100

# Each row of these matrices is later passed as the `b0` argument ("vary b0").
# Draw order matters for reproducibility: training first, then testing.
training_X <- matrix(rnorm(n_eval * p), nrow = n_eval, ncol = p)
testing_X <- matrix(rnorm(n_eval * p), nrow = n_eval, ncol = p)

# Reference values for the test rows: posterior_mean_analytical() is evaluated
# per row, each result is transposed, and the pieces are stacked with rbind.
testing_y_truth <- reduce(
    map_row(testing_X, function(b0) {
        t(posterior_mean_analytical(b0, B0, X, y))
    }),
    rbind
)


# Fit training data with GPAD ==================================================
set.seed(SEED + 1)
# Outer timer; it is stopped only after the GPAD models have been fitted
tictoc::tic("Training with GPAD")

# -- Getting derivative using automatic differentiation --------
tictoc::tic("Getting derivative using automatic differentiation")

# One AD_gibbs_sampler() run per row of training_X (the row is passed as b0);
# only the `beta` component of each run is kept.
# NOTE(review): the elements expose @x / @dx slots below, so they are presumably
# ADtools dual numbers (values plus derivatives) -- confirm.
training_y_d <- map_row(
    training_X,
    function(b0) AD_gibbs_sampler(b0, B0, a0, d0, X, y, m = 100)$beta
)

# Extract posterior mean and sd
# mcmc_sliced_mean_with_sd() is applied to every column of the @x slot; the
# per-column results are joined with c() and the per-run results with rbind(),
# with join_data_with_attr() handling the "sigma" attribute alongside.
training_y_d_mean_with_sd <- reduce(
    map(training_y_d, f = function(x) {
        reduce(
            map_col(x@x, mcmc_sliced_mean_with_sd),
            join_data_with_attr(c, c, "sigma")
        )
    }),
    join_data_with_attr(rbind, rbind, "sigma")
)

# Extract derivative
# Column means of every run, coerced with as.matrix();
# length is equal to number of evaluation points
posterior_means <- map(map(training_y_d, colMeans), as.matrix)

# Now structure the data by dimension for GP fitting:
# element i collects entry [i, 1] of every run's posterior mean, row-bound.
training_y_d_derivative <- map(seq_len(p), function(i) {
    reduce(
        map(posterior_means, function(x) x[i, 1]),
        rbind
    )
})

# Check that the dimensions are correct
# has_dim(x, d): TRUE iff dim(x) has the same length as `d` and equals it
# element-wise, FALSE otherwise.
# The length comparison is required: for an object without dimensions
# `dim(x) == d` is empty and all() would be vacuously TRUE, and for a rank
# mismatch the comparison would be silently recycled.
has_dim <- function(x, d) {
    dims <- dim(x)
    length(dims) == length(d) && all(dims == d)
}
# Every per-dimension target must have an n_eval x 1 value slot (@x) and an
# n_eval x p derivative slot (@dx); abort on the first violation.
for (k in seq_len(p)) {
    target_k <- training_y_d_derivative[[k]]
    stopifnot(
        has_dim(target_k@x, c(n_eval, 1)),
        has_dim(target_k@dx, c(n_eval, p))
    )
}
# Stops the most recently started timer, i.e. the automatic-differentiation step
time_gpad_mcmc <- tictoc::toc()

# -- Fit the model with the derivative data --------
# One gp_d() model per output dimension j, each timed individually.
model_gpads <- map(seq_len(p), function(j) {
    tictoc::tic(sprintf("Processed dimension %s/%s", j, p))
    # Values (@x) and derivatives (@dx) of the j-th target
    target_j <- training_y_d_derivative[[j]]
    # Noise level: average of column j of the "sigma" attribute
    noise_j <- mean(attr(training_y_d_mean_with_sd, "sigma")[, j])
    fitted_j <- gp_d(training_X,
                     y = target_j@x,
                     dy = target_j@dx,
                     sigma = noise_j,
                     sigma_d = 1e-3)
    tictoc::toc()
    fitted_j
})
# Stops the outer "Training with GPAD" timer
time_gpad <- tictoc::toc()

# -- Calculate the MSE --------
# Predictive means on the test rows, one model at a time, bound column-wise
testing_y_gpad <- reduce(
    map(model_gpads, function(model) predict_gp_d(model, testing_X)$mean),
    cbind
)

# Error against the analytical reference values
mse_gpad <- mse(testing_y_gpad, testing_y_truth)
message("Gaussian Process(AD) MSE: ", mse_gpad)




# Run GP with increasing size of grids and compare against GPAD ================


# Compare GP with GPAD
# For every grid size: simulate a fresh training set, run gibbs_sampler() on
# each training row, fit one gp() per output column and score the predictions
# on the shared test set. The GPAD quantities (testing_y_gpad, time_gpad,
# time_gpad_mcmc) are read from the enclosing environment and stored in every
# result next to the GP ones.
num_grid_points <- c(100, 500, 1000, 2000, 4000)
# The argument is named n_grid (not n_eval) so it does not shadow the global
# n_eval, which fixes the number of rows of testing_X.
results <- map(num_grid_points, function(n_grid) {
    message(sprintf("Processing %s evaluation points", n_grid))

    # - Fit training data with GP ----
    set.seed(SEED + 1)
    tictoc::tic("Training with GP")

    tictoc::tic("|-Running MCMC simulation")
    # Local training set; deliberately separate from the global training_X
    training_X <- matrix(rnorm(n_grid * p), n_grid, p)
    training_y <- training_X |>  # Posterior of b0
        map_row(f = function(b0) {
            # beta samples, summarised column by column
            gibbs_sampler(b0, B0, a0, d0, X, y)$beta |>
                map_col(mcmc_sliced_mean_with_sd) |>
                reduce(join_data_with_attr(c, c, "sigma"))
        }) |>
        reduce(join_data_with_attr(rbind, rbind, "sigma"))
    # Stops the inner MCMC timer
    time_gp_mcmc <- tictoc::toc()

    # One GP per column; the noise level is the mean of the matching "sigma"
    # column. NOTE(review): iter_col presumably passes (column, index) -- confirm.
    sigma_mat <- attr(training_y, "sigma")
    model_gps <- training_y |>
        iter_col(\(y, j) gp(training_X, y, sigma = mean(sigma_mat[, j])))
    # Stops the outer "Training with GP" timer (MCMC + fitting)
    time_gp <- tictoc::toc()

    # - Compare performance on testing dataset ----
    testing_y_gp <- model_gps |>
        map(\(model) predict_gp(model, testing_X)$mean) |>
        reduce(cbind)
    # Computed once so the logged and the stored value come from the same call
    mse_gp <- mse(testing_y_gp, testing_y_truth)
    message("Gaussian Process MSE: ", mse_gp)

    # - Collect and return results ----
    list(testing = list(X = testing_X,
                        y_truth = testing_y_truth,
                        y_gp = testing_y_gp,
                        y_gpad = testing_y_gpad),
         mse_gp = mse_gp,
         mse_gpad = mse(testing_y_gpad, testing_y_truth),
         time_gp = time_gp, time_gp_mcmc = time_gp_mcmc,
         time_gpad = time_gpad, time_gpad_mcmc = time_gpad_mcmc)
})


# Persist the comparison results. The target directory is created first
# (no-op if it already exists) so that a missing "output/" folder cannot make
# the final statement fail and discard the whole run.
output_path <- sprintf("output/GP_vs_GPAD_p_%s_%s.RDS", p, SEED)
dir.create(dirname(output_path), showWarnings = FALSE, recursive = TRUE)
saveRDS(results, output_path)
