## Plot functions --------------------------------------------------------------
library(plotly)

# Plot the posterior mean returned by `predict_gp(model$model, prior)` as a
# 3D surface over the points in `prior`.
#
# The columns of `prior` followed by the posterior mean are relabelled
# x, y, z (so `prior` presumably holds two input columns — confirm with
# callers). When `add_markers` is TRUE, the points the model was evaluated at
# (`model$X` with responses `model$ys`) are overlaid as markers. `...` is
# forwarded to `plot_surface()`.
plot_manifold <- function(prior, model, add_markers = FALSE, ...) {
    # Bind inputs and outputs column-wise and label them x, y, z.
    as_xyz <- function(inputs, outputs) {
        xyz <- data.frame(inputs, outputs)
        names(xyz) <- c("x", "y", "z")
        xyz
    }

    mean_grid <- as_xyz(prior, predict_gp(model$model, prior)$mean)
    fig <- plot_surface(mean_grid, FALSE, ...)

    if (add_markers) {
        fig <- wrap_add_markers(fig, as_xyz(model$X, model$ys))
    }
    fig
}

# Draw `x` as a 3D surface via `surface3d()` (which receives `...`). Unless
# `add_markers` is FALSE, the same points are also overlaid as markers.
plot_surface <- function(x, add_markers = TRUE, ...) {
    surface <- surface3d(x, ...)
    if (!add_markers) {
        return(surface)
    }
    wrap_add_markers(surface, x)
}

# Draw a 3D scatter plot of `data`, using its first three columns by default.
# # Basic examples
# scatter3d(data = iris)
# scatter3d(data = mtcars, ylab = "some y label")
# scatter3d(data = iris, x = "Sepal.Length", y = "Sepal.Width", z = "Petal.Width")
#
# # Overlay
# scatter3d(data = mtcars) %>%
#   wrap_add_markers(data = iris)
scatter3d <- function(data, x = NULL, y = NULL, z = NULL,
                      xlab = NULL, ylab = NULL, zlab = NULL,
                      titlefont = 12, tickfont = 12, ...) {
    # All three axes share the same font sizes; only the title text differs.
    # A NULL title is kept as an explicit `title = NULL` entry.
    axis_spec <- function(title) {
        list(title = title,
             titlefont = list(size = titlefont),
             tickfont = list(size = tickfont))
    }

    # `...` goes to plot_ly(); x/y/z select columns (or expressions) of `data`.
    fig <- plot_ly(...)
    fig <- wrap_add_markers(fig, data, x, y, z)
    layout(fig, scene = list(xaxis = axis_spec(xlab),
                             yaxis = axis_spec(ylab),
                             zaxis = axis_spec(zlab)))
}

# Add a 3D marker trace to the plotly object `p`, reading points from `data`.
#
# `x`, `y` and `z` are strings. Each is either the name of a column in `data`
# or an R expression evaluated against `data` (e.g. "Sepal.Length * 2"). When
# NULL they default to the first, second and third column names of `data`.
# Column names are turned into symbols directly, so non-syntactic names such
# as "Sepal Length" work; any other string is parsed as R code.
#
# Returns the result of `plotly::add_markers()`. Stops with an error when an
# axis has no column or expression, e.g. `data` has fewer than 3 columns and
# that axis was not supplied.
wrap_add_markers <- function(p, data, x = NULL, y = NULL, z = NULL) {
    cols <- names(data)
    if (is.null(x)) x <- cols[1]
    if (is.null(y)) y <- cols[2]
    if (is.null(z)) z <- cols[3]

    # Formulas get this function's frame as their environment, just as the
    # earlier `as.formula(paste(...))` construction did.
    formula_env <- environment()
    to_formula <- function(spec, axis) {
        spec <- as.character(spec)
        if (length(spec) != 1L || is.na(spec)) {
            stop("No column or expression given for the ", axis, " axis: ",
                 "`data` needs at least 3 columns or `", axis,
                 "` must be supplied.", call. = FALSE)
        }
        # A bare symbol avoids parsing, which fails on names like "a b";
        # anything that is not a column name is still parsed as an expression.
        rhs <- if (spec %in% cols) as.name(spec) else str2lang(spec)
        eval(call("~", rhs), formula_env)
    }

    # Build all three formulas eagerly so a bad axis fails here rather than
    # later inside plotly's lazy argument handling.
    fx <- to_formula(x, "x")
    fy <- to_formula(y, "y")
    fz <- to_formula(z, "z")

    add_markers(p, x = fx, y = fy, z = fz, data = data)
}

# Draw a 3D surface from grid data.
# **NOTE**: In `wrap_add_surface`, columns 1, 2 and 3 of `data` supply the x,
# y and z of `plotly::add_surface`: x and y are reduced to their unique values
# and column 3 is reshaped into a z matrix. `data` is assumed to be a complete
# grid as produced by `expand.grid(x, y)` (x varying fastest), with column 3
# holding the z value at each (x, y) point.
# # Basic examples
# g <- expand.grid(x = 1:10, y = 5:20)
# g$z <- g$x^2 + g$y^2
# g
# surface3d(g)
# surface3d(g, zlab = "some z label")
# # overlay
# surface3d(g) %>%
#   wrap_add_surface(2 * g)
surface3d <- function(data, xlab = NULL, ylab = NULL, zlab = NULL,
                      titlefont = 24, tickfont = 14, ...) {
    # Every axis uses the same font sizes; only the title text varies.
    # A NULL title is kept as an explicit `title = NULL` entry.
    axis_spec <- function(title) {
        list(title = title,
             titlefont = list(size = titlefont),
             tickfont = list(size = tickfont))
    }

    # `...` goes to plot_ly(); the surface itself comes from `data`.
    fig <- wrap_add_surface(plot_ly(...), data)
    layout(fig, scene = list(xaxis = axis_spec(xlab),
                             yaxis = axis_spec(ylab),
                             zaxis = axis_spec(zlab)))
}

# Add a surface trace to the plotly object `p` from long-format grid data.
#
# Columns 1, 2 and 3 of `data` hold the x coordinate, y coordinate and height
# of each grid point (e.g. `expand.grid(x, y)` plus a z column).
# `plotly::add_surface` takes the unique x and y values plus a z matrix whose
# rows follow y and whose columns follow x. Each point is placed in that matrix
# by looking up the positions of its x and y values:
# - For `expand.grid` ordering this gives the same matrix as filling row by row.
# - Shuffled rows are also handled.
# - Grid points absent from `data` are left as NA instead of being filled with
#   recycled values.
wrap_add_surface <- function(p, data) {
    xs <- data[[1]]
    ys <- data[[2]]
    x_axis <- unique(xs)
    y_axis <- unique(ys)

    # Start from an all-NA matrix; assigning data[[3]] coerces it to that
    # column's type (integer or double).
    heights <- matrix(NA, nrow = length(y_axis), ncol = length(x_axis))
    heights[cbind(match(ys, y_axis), match(xs, x_axis))] <- data[[3]]

    add_surface(p, x = x_axis, y = y_axis, z = heights)
}
