#' Active learning with surrogate function
#'
#' Builds a surrogate model of `f` from the points in `X` and then refines it
#' iteratively: each iteration calls the model's `update()` method, prints the
#' model's `measure`, and checks it against `tol`. Iteration stops when the
#' convergence criteria are met or after `max_iter` iterations.
#'
#' @param X The initial evaluation points, passed to `surrogate`. For the
#' default surrogate this is a numeric matrix with one point per row.
#' @param f A function; the objective function.
#' @param surrogate A function that constructs the surrogate model; it is called
#' as `surrogate(X, f, parallel = , mc.cores = , options = , ...)`. The object
#' it returns must have an `update(parallel, mc.cores, options)` method and
#' `measure` and `sigma` fields. Defaults to `gp_surrogate$new`.
#' @param max_iter A non-negative integer; the maximum number of iterations.
#' With 0 the model is only constructed (or taken from `restart`), not updated.
#' @param tol A numeric vector of length one or two; the tolerances for the
#' entries of `model$measure`. For `gp_surrogate` these are the in-sample fit
#' (the Euclidean norm of the residuals) and the largest out-of-sample
#' predictive standard deviation. The tolerance is reached when every entry of
#' `model$measure` is strictly below the corresponding entry of `tol`. If `tol`
#' is missing, or for its `NA` entries, `max(model$sigma)` is used.
#' @param persistence A positive integer; the number of times (not necessarily
#' consecutively) the tolerance should be reached before convergence. Use -1 to
#' disable this option.
#' @param consecutive A positive integer; the number of consecutive times the
#' tolerance should be reached before convergence. Use -1 to disable this option.
#' @param parallel TRUE or FALSE; whether to use parallel computing.
#' @param mc.cores A positive integer; the number of cores to use for parallel
#' computing. Only works when `parallel = TRUE`.
#' @param restart A model from a previous run to continue updating. When it is
#' given, `surrogate` is not called, so `X`, `f` and `...` are not used.
#' @param options Optional argument to pass to \code{optim}; it is forwarded to
#' `surrogate` and to every `update()` call.
#' @param callback Optional function, called as `callback(model)` after every
#' update, including the one at which convergence is reached; this lets users
#' examine the fitting process.
#' @param ... Optional arguments to pass to the surrogate function.
#'
#' @return The model object (created by `surrogate`, or `restart`) after the
#' iterations have been applied to it.
#'
#' @examples
#' \dontrun{
#' library(gptools2)
#' s <- seq(-5, 5, length.out = 3)
#' X <- as.matrix(expand.grid(s, s))
#' f <- function(x) sin(x[1]) + x[2]
#'
#' model <- active_learning(
#'     X, f, sigma = 1e-3,
#'     max_iter = 100, tol = c(0.1, 0.01)
#' )
#'
#' # In-sample fit
#' pred_y <- predict_gp(model$model, X)
#' y <- map_row(X, f)
#' compare(y, pred_y$mean)
#'
#' # Out-of-sample performance
#' new_X <- as.matrix(expand.grid(runif(10, -4, 4), runif(10, -4, 4)))
#' new_y <- map_row(new_X, f)
#' new_pred_y <- predict_gp(model$model, new_X)
#' compare(new_y, new_pred_y$mean)
#' }
#' @export
active_learning <- function(X, f, surrogate = gp_surrogate$new, 
                            max_iter, tol, persistence = 1, consecutive = -1,  
                            parallel = FALSE, mc.cores = 1, 
                            restart, options = list(), 
                            callback, ...) {
    if (missing(restart)) {
        model <- surrogate(X, f, parallel = parallel, mc.cores = mc.cores, 
                           options = options, ...)
    } else {
        model <- restart
    }
    # Without an explicit tolerance, fall back to the largest observational noise.
    if (missing(tol)) tol <- max(model$sigma)
    if (any(is.na(tol))) tol[is.na(tol)] <- max(model$sigma)
    cat("Maximum number of iterations:", max_iter, "\n")
    cat("Convergence threshold:", tol, "\n")
    has_callback <- !missing(callback)
    n_in_a_row <- 0
    
    if (max_iter > 0) {
        cat("In-sample fit | Out-of-sample predictive variance", "\n")
        for (i in seq_len(max_iter)) {
            model$update(parallel = parallel, mc.cores = mc.cores, options = options)
            cat(model$measure, "\n")

            converged <- FALSE
            if (all(model$measure < tol)) {
                # `persistence` counts all hits; `n_in_a_row` only consecutive ones.
                persistence <- persistence - 1
                n_in_a_row <- n_in_a_row + 1
                converged <- persistence <= 0 && n_in_a_row >= consecutive
            } else {
                n_in_a_row <- 0
            }
            # Report every update, including the one that triggers convergence.
            if (has_callback) {
                callback(model)
            }
            if (converged) {
                message("Converged after ", i, " iterations.")
                break
            }
        }
    }
    model
}


#' Gaussian Process surrogate function
#'
#' An R6 class that fits a Gaussian Process (with \link{gp}) to evaluations of
#' a function `f`, and proposes the next point at which to evaluate `f`. It is
#' the default `surrogate` used by `active_learning()`.
#' @export
gp_surrogate <- R6::R6Class("gp_surrogate", public = list(
    #' @field X An n-by-p numeric matrix where n is the number of data points,
    #' p is the number of dimensions; the predictor variables.
    X = NA,

    #' @field ys A numeric vector of length n; the response variable.
    ys = NA,
    
    #' @field sigma A numeric vector of length n, or a scalar shared by all
    #' data points; the observational noise.
    sigma = NA,

    #' @field true_f A function; the function of which the surrogate is sought.
    true_f = NA,

    #' @field model A variable for storing the fitted model returned by \link{gp}.
    model = NA,

    #' @field sample_fun A function; the function to sample the next set of
    #' points to evaluate using the surrogate function. By default, it samples
    #' from the bounding hypercube of the variable `X`.
    sample_fun = NA,

    #' @field select_fun A function; the function to select the next point
    #' to evaluate using the true function. It is called with the candidate
    #' points and their predictive variances. By default, it picks the point
    #' that has the highest predictive variance.
    select_fun = NA,

    #' @field in_sample_fit The in-sample fit of the fitted model to the
    #' training data X: the Euclidean norm of the residuals.
    in_sample_fit = NA,

    #' @field next_to_consider The next set of points to evaluate using
    #' the surrogate function.
    next_to_consider = NA,

    #' @field out_of_sample_uncertainty The out-of-sample uncertainty
    #' associated with `next_to_consider`; the predictive variance.
    out_of_sample_uncertainty = NA,

    #' @field next_to_evaluate The next point to evaluate using the true function.
    next_to_evaluate = NA,

    #' @field measure The summary measure of the fitted model: a length-two
    #' vector of `in_sample_fit` and the largest out-of-sample predictive
    #' standard deviation.
    measure = NA,

    #' @field debug TRUE or FALSE, whether to use the debug mode.
    debug = NA,

    #' @description Constructor of the Gaussian Process surrogate model
    #' @param X A numeric matrix; the training data.
    #' @param f A function; the function of which the surrogate is sought.
    #' @param kernel The kernel function of the Gaussian Process.
    #' @param sigma The regularisation parameter of the Gaussian Process; it is
    #' used as the observational noise of every point for which `f` does not
    #' attach a `"sigma"` attribute to its value.
    #' @param sample_fun A function; the function to sample the next set of
    #' points to evaluate using the surrogate function.
    #' @param select_fun A function; the function to select the next point
    #' to evaluate using the true function.
    #' @param debug TRUE or FALSE, whether to use the debug mode.
    #' @param sample_n A positive integer; the number of samples to draw in each iteration.
    #' Note that this parameter is used only when `sample_fun` is not provided.
    #' @param parallel TRUE or FALSE; whether to use parallel computing.
    #' @param mc.cores A positive integer; the number of cores to use for parallel
    #' computing. Only works when `parallel = TRUE`.
    #' @param options Optional argument to pass to \code{optim}.
    #' @param ... Currently ignored; these arguments are not forwarded to \link{gp}.
    initialize = function(X, f, kernel = squared_exponential(),
                          sigma = 0, sample_fun, select_fun,
                          debug = FALSE, sample_n = 20, 
                          parallel = FALSE, mc.cores = 1, 
                          options = list(), ...) {
        if (missing(sample_fun)) {
            # Default candidates: `sample_n` uniform draws from the bounding box of X.
            cube_X <- find_bounding_hypercube(X)
            sample_fun <- function() sample_from_hypercube(cube_X, sample_n)
        }
        if (missing(select_fun)) {
            select_fun <- point_of_max_score
        }
        
        self$sample_fun <- sample_fun
        self$select_fun <- select_fun
        self$true_f <- f
        self$debug <- debug

        self$X <- X
        self$ys <- map_row(X, f, parallel = parallel, mc.cores = mc.cores)
        private$sigma_0 <- sigma
        self$sigma <- extract_observational_noise(self$ys, private$sigma_0)
        self$model <- gp(self$X, self$ys, kernel, self$sigma, options)

        self$measure_fit_and_risk()
    },

    #' @description Update a fitted model with active learning: evaluate the
    #' true function at `next_to_evaluate`, append the result to the training
    #' data, refit the Gaussian Process with its current kernel and recompute
    #' the measures.
    #' @param parallel TRUE or FALSE; whether to use parallel computing.
    #' @param mc.cores A positive integer; the number of cores to use for parallel
    #' computing. Only works when `parallel = TRUE`.
    #' @param options Optional argument to pass to \code{optim}
    update = function(parallel = FALSE, mc.cores = 1, options = list()) {
        n_old <- length(self$ys)
        new_X <- self$next_to_evaluate
        self$X <- rbind(self$X, new_X)
        new_y <- map_row(new_X, self$true_f, 
                         parallel = parallel, mc.cores = mc.cores)
        self$ys <- c(self$ys, new_y)

        new_sigma <- extract_observational_noise(new_y, private$sigma_0, verbose = FALSE)
        # A single shared noise level stays a scalar (as passed to gp() at
        # construction); otherwise expand both parts so that `sigma` has exactly
        # one entry per element of `ys`.
        if (!(length(self$sigma) == 1 && identical(self$sigma, new_sigma))) {
            self$sigma <- c(rep_len(self$sigma, n_old),
                            rep_len(new_sigma, length(new_y)))
        }
        self$model <- gp(self$X, self$ys, self$model$kernel, self$sigma, options)

        self$measure_fit_and_risk()
    },

    #' @description Evaluate the fitted model using the in-sample fit and out-of-sample variance
    measure_fit_and_risk = function() {
        py <- predict_gp(self$model, self$X)
        self$in_sample_fit <- norm(self$ys - py$mean, "2")

        # Draw fresh candidates and score them by their predictive variance.
        self$next_to_consider <- self$sample_fun()
        self$out_of_sample_uncertainty <- diag(predict_gp(self$model, self$next_to_consider)$covariance)
        self$next_to_evaluate <- self$select_fun(self$next_to_consider, self$out_of_sample_uncertainty)

        self$measure <- c(self$in_sample_fit, sqrt(max(self$out_of_sample_uncertainty)))
    }
), private = list(
    sigma_0 = NA
))


#' Map f over each row of a matrix x
#'
#' Applies `f` to every row of `x` and concatenates the results with `c()`.
#' Attributes that `f` attaches to its return values (for example a `"sigma"`
#' noise estimate) are collected across rows and set on the result.
#'
#' @param x A matrix.
#' @param f A function that takes each row of x as input.
#' @param ... Additional parameters to pass to `set_lapply`, such as
#' `parallel` and `mc.cores`; other parameters for `mclapply` can be supplied
#' as well.
#' @return The concatenated results of `f`; `NULL` when `x` has no rows.
#' @export
map_row <- function(x, f, ...) {
    lapply2 <- set_lapply(...)
    # seq_len() so that a zero-row `x` maps to nothing instead of rows c(1, 0).
    fx <- lapply2(seq_len(NROW(x)), function(i) f(x[i, ]))
    res <- combine_value(fx)
    attributes(res) <- combine_attr(fx)
    res
}

# Choose the lapply-like function used by `map_row`.
#
# Returns base `lapply` when `parallel` is FALSE. Otherwise it returns a
# wrapper around `parallel::mclapply` that runs on `mc.cores` cores and also
# forwards any further arguments given here in `...` (e.g. `mc.preschedule`)
# to every call. The returned function is called like `lapply(X, FUN, ...)`.
set_lapply <- function(parallel = FALSE, mc.cores = 2, ...) {
    # Plain if/else: ifelse() is meant for vectors, not for choosing functions.
    if (!parallel) {
        return(lapply)
    }
    mc_args <- list(...)
    function(...) {
        do.call(parallel::mclapply,
                c(list(...), mc_args, list(mc.cores = mc.cores)))
    }
}

# Concatenate a list of per-row results into one vector with `c()`;
# an empty list gives NULL.
combine_value <- function(xs) {
    do.call(what = c, args = xs)
}

# Combine the attributes of a list of values into vector-valued attributes.
#
# Attribute names are taken from the first element of `xs`. For each name, the
# values from all elements are gathered with `sapply()`. An element that lacks
# the attribute contributes `NA`, so a partially missing "sigma" becomes an NA
# entry (which `extract_observational_noise()` fills with the default noise)
# rather than turning the result into a list of NULLs. Names are matched
# exactly, not partially. Returns NULL for an empty list.
combine_attr <- function(xs) {
    if (length(xs) < 1) return(NULL)
    keys <- names(attributes(xs[[1]]))
    res <- lapply(keys, function(key) {
        sapply(xs, function(x) {
            value <- attr(x, key, exact = TRUE)
            if (is.null(value)) NA else value
        })
    })
    setNames(res, keys)
}


# Return the row of `X` with the largest score (the first one on ties),
# kept as a one-row matrix.
point_of_max_score <- function(X, scores) {
    best <- which.max(scores)
    X[best, , drop = FALSE]
}


# Axis-aligned bounding box of the rows of `X`: one column per column of `X`,
# with the column minimum in the first row and the maximum in the second.
find_bounding_hypercube <- function(X) {
    apply(X, MARGIN = 2L, FUN = range)
}


# Draw `n` points uniformly at random from an axis-aligned box.
#
# `X` is a 2-row matrix whose first row holds the lower bounds and whose second
# row holds the upper bounds of each dimension (as returned by
# `find_bounding_hypercube`). Returns an n-by-ncol(X) matrix with one sampled
# point per row, carrying the column names of `X`. Columns are drawn one after
# another, so for a fixed seed the draws are the same as the previous
# apply()-based implementation.
sample_from_hypercube <- function(X, n) {
    stopifnot(nrow(X) == 2)
    # Build the matrix explicitly: apply() would simplify to a plain vector
    # when n == 1, losing the one-point-per-row shape callers rely on.
    draws <- lapply(seq_len(ncol(X)), function(j) runif(n, X[1, j], X[2, j]))
    out <- matrix(as.numeric(unlist(draws)), nrow = n, ncol = ncol(X))
    colnames(out) <- colnames(X)
    out
}


# Per-point observational noise for the values `ys`.
#
# Uses the "sigma" attribute attached to `ys` (via `get_sigma`), replacing every
# missing entry with `sigma_default`; when there is no attribute at all the
# result is just `sigma_default`. If `verbose`, a message announces that
# noise estimates supplied by `f` were found.
extract_observational_noise <- function(ys, sigma_default, verbose = TRUE) {
    noise <- get_sigma(ys)
    has_estimates <- length(noise) > 0 && !all(is.na(noise))
    if (has_estimates && verbose) {
        message("Detect observational noise estimates provided by f. I will take advantage of this information.")
    }

    missing_entries <- is.na(noise)
    noise[missing_entries] <- sigma_default
    noise
}

# The "sigma" attribute of `y` (noise estimates attached by `f`), or NA when
# `y` has no such attribute.
get_sigma <- function(y) {
    noise <- attr(y, "sigma")
    if (is.null(noise)) NA else noise
}

# Null-coalescing operator: `x` unless it is NULL, in which case `y`.
# `y` is only evaluated when it is needed.
`%||%` <- function(x, y) {
    if (!is.null(x)) {
        return(x)
    }
    y
}
