#include <Rcpp.h>
using namespace Rcpp;

// both_non_NA(a, b): true only when neither a nor b is flagged by ISNAN,
// i.e. both values can safely take part in arithmetic.
inline bool both_non_NA(double a, double b) {
    const bool any_missing = ISNAN(a) || ISNAN(b);
    return !any_missing;
}

// Pairwise squared Euclidean distances between the rows of X and the rows of Y.
//
// For every row i of X and every row j of Y, accumulates
// (X(i, k) - Y(j, k))^2 over the columns k of X, skipping any column in which
// either value is NA/NaN.
//
// X: n x p matrix.
// Y: m x q matrix. Only its first p columns are used, so q must be at least p
//    (the intended use is q == p).
// Returns an n x m matrix; entry (i, j) is 0 when no column has both values
// present.
// Raises an R error if Y has fewer columns than X.
// [[Rcpp::export]]
NumericMatrix kcov_C(NumericMatrix X, NumericMatrix Y) {
    const int n = X.nrow(), p = X.ncol();
    const int m = Y.nrow();

    // (row, col) indexing does not validate the column index against ncol(),
    // so a narrower Y would otherwise be read past its last column.
    if (Y.ncol() < p) {
        stop("kcov_C: Y must have at least as many columns as X");
    }

    // allocate the output matrix (zero-initialised)
    NumericMatrix out(n, m);

    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            double dist = 0.0;
            for (int k = 0; k < p; k++) {
                const double xi = X(i, k);
                const double yj = Y(j, k);
                // Missing values contribute nothing to the distance.
                if (both_non_NA(xi, yj)) {
                    dist += (xi - yj) * (xi - yj);
                }
            }
            out(i, j) = dist;
        }
    }
    return out;
}

// Fine-grain optimisation for supplementary analysis
//
// Fills an nrow(X) x (nrow(Y) * p) matrix, p = ncol(X), where
//
//     out(i, j * p + n) = D(i, j) / l2 * (X(i, n) - Y(j, n))
//
// and the difference is taken as 0 when either X(i, n) or Y(j, n) is NA/NaN.
// NOTE(review): judging by the name, this is presumably the derivative of a
// squared-exponential kernel with respect to its second argument, with D the
// kernel matrix and l2 the squared length-scale -- confirm against the caller.
//
// X:  nrow(X) x p matrix.
// Y:  nrow(Y) x q matrix; only its first p columns are used, so q >= p.
// D:  matrix with at least nrow(X) rows and nrow(Y) columns; D(i, j) scales
//     the block belonging to the pair (row i of X, row j of Y).
// l2: divisor applied to D(i, j). It is not checked; l2 == 0 yields
//     Inf/NaN entries through ordinary floating-point division.
// Raises an R error if Y has fewer columns than X or if D is too small.
// [[Rcpp::export]]
NumericMatrix kcov_se_df_dx2(NumericMatrix X, 
                             NumericMatrix Y,
                             NumericMatrix D, 
                             double l2) {
    const int nrowX = X.nrow(), p = X.ncol();
    const int nrowY = Y.nrow();

    // (row, col) indexing does not validate indices against the matrix
    // dimensions, so reject shapes that would be read out of range.
    if (Y.ncol() < p) {
        stop("kcov_se_df_dx2: Y must have at least as many columns as X");
    }
    if (D.nrow() < nrowX || D.ncol() < nrowY) {
        stop("kcov_se_df_dx2: D must have at least nrow(X) rows and nrow(Y) columns");
    }

    // allocate the output matrix
    NumericMatrix out(nrowX, nrowY * p);

    for (int i = 0; i < nrowX; i++) {
        for (int j = 0; j < nrowY; j++) {
            // Constant for the whole (i, j) block; same evaluation order as
            // writing D(i, j) / l2 * value inline.
            const double scale = D(i, j) / l2;
            for (int n = 0; n < p; n++) {
                const double xi = X(i, n);
                const double yj = Y(j, n);
                // Missing coordinates are treated as a zero difference.
                const double value = both_non_NA(xi, yj) ? (xi - yj) : 0.0;
                out(i, j * p + n) = scale * value;
            }
        }
    }
    return out;
}

// Fills an (nrow(X) * p) x (nrow(Y) * p) matrix, p = ncol(X), where
//
//     out(i * p + m, j * p + n) =
//         D(i, j) / l2 * ( [m == n] - (X(i, m) - Y(j, m)) * (X(i, n) - Y(j, n)) / l2 )
//
// with [m == n] equal to 1 when m == n and 0 otherwise. If any of the four
// values X(i, m), Y(j, m), X(i, n), Y(j, n) is NA/NaN, the bracketed term is
// taken as 0 (including the [m == n] part).
// NOTE(review): judging by the name, this is presumably the mixed second
// derivative of a squared-exponential kernel with respect to both arguments,
// with D the kernel matrix and l2 the squared length-scale -- confirm against
// the caller.
//
// X:  nrow(X) x p matrix.
// Y:  nrow(Y) x q matrix; only its first p columns are used, so q >= p.
// D:  matrix with at least nrow(X) rows and nrow(Y) columns.
// l2: divisor used in both places above. It is not checked; l2 == 0 yields
//     Inf/NaN entries through ordinary floating-point division.
// Raises an R error if Y has fewer columns than X or if D is too small.
// [[Rcpp::export]]
NumericMatrix kcov_se_d2f_dx1_dx2(NumericMatrix X, 
                                  NumericMatrix Y,
                                  NumericMatrix D,
                                  double l2) {
    const int nrowX = X.nrow(), p = X.ncol();
    const int nrowY = Y.nrow();

    // (row, col) indexing does not validate indices against the matrix
    // dimensions, so reject shapes that would be read out of range.
    if (Y.ncol() < p) {
        stop("kcov_se_d2f_dx1_dx2: Y must have at least as many columns as X");
    }
    if (D.nrow() < nrowX || D.ncol() < nrowY) {
        stop("kcov_se_d2f_dx1_dx2: D must have at least nrow(X) rows and nrow(Y) columns");
    }

    // allocate the output matrix
    NumericMatrix out(nrowX * p, nrowY * p);

    for (int i = 0; i < nrowX; i++) {
        for (int j = 0; j < nrowY; j++) {
            // Constant for the whole p x p block of the pair (i, j).
            const double scale = D(i, j) / l2;
            for (int m = 0; m < p; m++) {
                // Row-m quantities do not depend on n; load and test them once.
                const double xim = X(i, m);
                const double yjm = Y(j, m);
                const bool m_ok = both_non_NA(xim, yjm);
                for (int n = 0; n < p; n++) {
                    const double xin = X(i, n);
                    const double yjn = Y(j, n);
                    double value = 0.0;
                    if (both_non_NA(xin, yjn) && m_ok) {
                        if (m == n) {
                            value += 1;
                        }
                        value += - (xim - yjm) * (xin - yjn) / l2;
                    }
                    out(i * p + m, j * p + n) = scale * value;
                }
            }
        }
    }
    return out;
}
