JMpredSurv <- function(object, dat, Survdat = NULL, u, heuristic = TRUE) {
  # Conditional survival prediction for one subject from a fitted joint model:
  # returns S(u | b) / S(T_surv | b) for every prediction time in `u`.
  #
  # Args:
  #   object:    fitted joint-model object; reads $coefficients (shared, surv,
  #              lambda), $randomest and the formulas $long, $shared, $surv.
  #   dat:       the subject's longitudinal records as a data.frame (prepared
  #              with datPrep()) or an already prepared list.
  #   Survdat:   optional survival data forwarded to datPrep().
  #   u:         prediction time(s).
  #   heuristic: if TRUE, plug in the posterior mode of the random effects,
  #              found by maximising logpb() with optim(); if FALSE, or if the
  #              fitted random effects are all zero, b = (0, 0) is used.
  #
  # Returns a data.frame with columns `u` and `surv`. T_surv is the survival
  # time stored in the prepared data. If optim() fails, a warning is issued
  # and `surv` is NA.

  # Coefficients --------------------------------------------------------------
  shared <- object$coefficients$shared
  if (length(shared) == 0L || (length(shared) == 1L && isTRUE(shared == 0))) {
    betals <- 0
  } else {
    betals <- shared[names(shared) != "time_effect"]
  }
  betat <- unname(shared["time_effect"])
  if (length(betat) == 0L || is.na(betat)) betat <- 0
  n_surv <- length(object$coefficients$surv)
  # The last survival coefficient is the association parameter alpha.
  betas <- object$coefficients$surv[-n_surv]
  alpha <- object$coefficients$surv[n_surv]
  lambda <- object$coefficients$lambda

  # Data ----------------------------------------------------------------------
  # Prepared once here and reused by every logpb() call made by optim().
  if (!inherits(dat, "list")) {
    dat <- datPrep(long = object$long, shared = object$shared, surv = object$surv,
                   timeVar = "T_long", idVar = "id", data = dat, survData = Survdat)
  }

  # Random effects ------------------------------------------------------------
  gammas <- c(0, 0)
  if (heuristic) {
    # Optimising the individual random effects is pointless when the fitted
    # random effects show no variation at all.
    no_variation <- all(vapply(object$randomest, function(x) all(x == 0),
                               logical(1)))
    if (!no_variation) {
      gammas <- tryCatch(
        optim(par = c(0, 0), fn = logpb, object = object, dat = dat,
              control = list(fnscale = -1), method = "BFGS")$par,
        error = function(e) {
          warning("JMpredSurv: could not find the posterior mode of the ",
                  "random effects (", conditionMessage(e), "); returning NA.",
                  call. = FALSE)
          c(NA_real_, NA_real_)
        }
      )
    }
  }

  # Survival ------------------------------------------------------------------
  Xs <- dat$Xs
  Xls <- dat$Xls
  Ti <- dat$T_surv

  # Rows at each subject's last longitudinal measurement. ave() keeps the
  # original row order; unlist(tapply()) would order them by id level.
  last <- dat$T_long == ave(dat$T_long, dat$id, FUN = max)

  Xls_un <- if (is.null(Xls)) 0 else as.matrix(Xls[last, , drop = FALSE])
  if (is.null(Xs)) Xs <- 0

  Sdat <- list("Xs" = Xs, "Xls_un" = Xls_un)
  theta <- list("alpha" = alpha, "lambda" = lambda, "betas" = betas,
                "betals" = betals, "betat" = betat)

  # S() returns a one-column matrix. Matrices of different sizes cannot be
  # divided, so flatten both before forming S(u) / S(Ti).
  S.u <- as.vector(S(T_surv = u, gammas = gammas, dat = Sdat, theta = theta))
  S.Ti <- as.vector(S(T_surv = Ti, gammas = gammas, dat = Sdat, theta = theta))

  data.frame(u = u, surv = S.u / S.Ti)
}

# logpb() expects `dat` to contain the data of a single subject (one person at a time).
logpb <- function(gammas, object, dat, Survdat = NULL) {
  # Log of the unnormalised posterior density of one subject's random
  # intercept/slope b = (gamma0, gamma1), given the subject's longitudinal
  # outcomes and survival up to the last longitudinal measurement:
  #   log p(b) + log p(y | b) + log S(T | b).
  # JMpredSurv() maximises it with optim(control = list(fnscale = -1)).
  #
  # Args:
  #   gammas:  numeric(2), c(gamma0, gamma1).
  #   object:  fitted joint-model object; reads $coefficients (long, shared,
  #            surv, lambda, sigma2), $randomest and, when `dat` still needs
  #            preparing, the formulas $long, $shared and $surv.
  #   dat:     a data.frame for ONE subject (prepared with datPrep()) or an
  #            already prepared list.
  #   Survdat: optional survival data forwarded to datPrep().
  #
  # Returns a single number. Errors if `dat` holds more than one subject.
  gamma0 <- gammas[1]
  gamma1 <- gammas[2]

  # Estimated coefficients ----------------------------------------------------
  betal <- object$coefficients$long
  shared <- object$coefficients$shared
  if (length(shared) == 0L || (length(shared) == 1L && isTRUE(shared == 0))) {
    betals <- 0
  } else {
    betals <- shared[names(shared) != "time_effect"]
  }
  betat <- unname(shared["time_effect"])
  if (length(betat) == 0L || is.na(betat)) betat <- 0
  n_surv <- length(object$coefficients$surv)
  # The last survival coefficient is the association parameter alpha.
  betas <- object$coefficients$surv[-n_surv]
  alpha <- object$coefficients$surv[n_surv]
  lambda <- object$coefficients$lambda
  sigma2 <- object$coefficients$sigma2
  # Random-effects covariance: empirical covariance of the fitted
  # subject-level random effects.
  D <- var(Reduce("cbind", object$randomest))

  # Subject data --------------------------------------------------------------
  if (!inherits(dat, "list")) {
    dat <- datPrep(long = object$long, shared = object$shared, surv = object$surv,
                   timeVar = "T_long", idVar = "id", data = dat, survData = Survdat)
  }
  if (length(unique(dat$id)) != 1L) {
    stop("logpb() expects the data of a single subject", call. = FALSE)
  }

  y <- dat$y
  T_long <- dat$T_long
  # Conditioning time: the subject's last longitudinal measurement.
  T_surv <- max(T_long)
  Xl <- if (is.null(dat$Xl)) 0 else dat$Xl
  Xs <- if (is.null(dat$Xs)) 0 else dat$Xs
  if (is.null(dat$Xls)) {
    Xls <- 0
    Xls_un <- 0
  } else {
    Xls <- dat$Xls
    # Shared-covariate row at the last measurement (the first one if tied).
    Xls_un <- as.matrix(Xls[which.max(T_long), , drop = FALSE])
  }

  # Log-density terms ---------------------------------------------------------
  log_prior <- mvtnorm::dmvnorm(c(gamma0, gamma1), mean = rep(0, 2), sigma = D,
                                log = TRUE)

  etal <- as.vector(cbind(1, Xl) %*% betal)
  etas <- as.vector(Xs %*% betas)
  # Shared predictor at every measurement and at T_surv. It combines the
  # random intercept and slope, the shared covariates and the fixed time effect.
  etals <- gamma0 + gamma1 * T_long + as.vector(Xls %*% betals) + betat * T_long
  etals_un <- gamma0 + gamma1 * T_surv + as.vector(Xls_un %*% betals) +
    betat * T_surv

  log_lik_long <- sum(dnorm(y, mean = etal + etals, sd = sqrt(sigma2),
                            log = TRUE))

  # Cumulative hazard up to T_surv, in closed form. The log-hazard changes
  # linearly in time at rate alpha * (gamma1 + betat). When that rate is zero
  # (including alpha == 0), the hazard is constant and the integral is
  # hazard * T_surv; this also avoids dividing by zero.
  rate <- alpha * (betat + gamma1)
  cumhaz <- if (isTRUE(rate != 0)) {
    lambda * exp(etas) *
      (exp(alpha * etals_un) -
         exp(alpha * (etals_un - (gamma1 + betat) * T_surv))) / rate
  } else {
    lambda * exp(etas) * exp(alpha * etals_un) * T_surv
  }

  # Surviving up to T_surv contributes log S(T_surv | b) = -cumhaz.
  unname(log_prior + log_lik_long - cumhaz)
}

S <- function(T_surv, gammas, theta, dat) {
  # Conditional survival probability S(T_surv | b) = exp(-cumulative hazard)
  # for random intercept/slope b = (gamma0, gamma1).
  #
  # Args:
  #   T_surv: time(s) at which to evaluate survival; a vector is evaluated
  #           elementwise.
  #   gammas: c(gamma0, gamma1).
  #   theta:  list with alpha, lambda, betas, betals and betat.
  #   dat:    list with Xs (survival design) and Xls_un (shared design at the
  #           last longitudinal measurement).
  #
  # Returns a one-column matrix with one survival probability per value.
  gamma0 <- gammas[1]
  gamma1 <- gammas[2]

  alpha <- theta$alpha
  lambda <- theta$lambda
  betat <- theta$betat

  etas <- as.vector(dat$Xs %*% theta$betas)
  etals_un <- as.vector(gamma0 + gamma1 * T_surv +
                          as.vector(dat$Xls_un %*% theta$betals) +
                          betat * T_surv)

  # Closed-form cumulative hazard. The log-hazard changes linearly in time at
  # rate alpha * (gamma1 + betat). When that rate is zero (including
  # alpha == 0), the hazard is constant and the integral is hazard * T_surv.
  # A plain if/else (not ifelse() with a scalar test, which keeps only the
  # first element) preserves one value per element of T_surv.
  rate <- alpha * (betat + gamma1)
  integral <- if (isTRUE(rate != 0)) {
    lambda * exp(etas) *
      (exp(alpha * etals_un) -
         exp(alpha * (etals_un - (gamma1 + betat) * T_surv))) / rate
  } else {
    lambda * exp(etas) * exp(alpha * etals_un) * T_surv
  }

  exp(-as.matrix(unname(integral)))
}

datPrep <- function(long, shared, surv, data, survData = NULL, timeVar, idVar) {
  # Builds the responses and design matrices used by the joint model from
  # long-format data.
  #
  # Args:
  #   long:     two-sided formula for the longitudinal outcome.
  #   shared:   one-sided formula for the shared (association) part.
  #   surv:     formula whose response holds time in column 1 and status in
  #             column 2 (presumably a Surv() object; verify against callers).
  #   data:     long-format data, one row per measurement.
  #   survData: survival data; defaults to the first row of every subject in
  #             `data`.
  #   timeVar, idVar: names of the measurement-time and subject-id columns.
  #
  # Returns a list with y, Xl, Xs, Xls, T_long, T_surv, delta, id and
  # time.effect. The intercept is removed from Xl and Xls. The first column
  # of the survival design (presumably the intercept) is dropped from Xs.
  # Xl, Xs and Xls are NULL when they have no columns. If timeVar is a column
  # of the shared design, it is removed from Xls and time.effect is TRUE.
  #
  # NOTE(review): model.frame()/model.matrix() drop rows containing NA under
  # the default na.action, while id and T_long are taken from the full data.
  # Rows can misalign if `data` contains missing values.
  long <- update(long, ~ . - 1)
  mf <- model.frame(long, data = data)
  y <- model.response(mf)
  Xl <- model.matrix(long, data = data)
  if (ncol(Xl) == 0) Xl <- NULL

  shared <- update(shared, ~ . - 1)
  Xls <- model.matrix(shared, data = data)
  if (ncol(Xls) == 0) Xls <- NULL

  # [[ ]] returns a plain vector for data.frames, tibbles and data.tables alike.
  id <- data[[idVar]]
  T_long <- data[[timeVar]]

  if (is.null(survData)) {
    survData <- data[!duplicated(id), ]
  }
  surv_resp <- model.response(model.frame(surv, survData))
  T_surv <- surv_resp[, 1]
  delta <- surv_resp[, 2]
  Xs <- model.matrix(surv, data = survData)[, -1, drop = FALSE]
  if (ncol(Xs) == 0) Xs <- NULL

  time.effect <- timeVar %in% colnames(Xls)
  if (time.effect) {
    Xls <- Xls[, colnames(Xls) != timeVar, drop = FALSE]
  }

  list(y = y, Xl = Xl, Xs = Xs, Xls = Xls, T_long = T_long, T_surv = T_surv,
       delta = delta, id = id, time.effect = time.effect)
}
