# # Code variant based on Colin's JMboostE.R, with a corrected surv.risk function and
# # corrected gradients so that individuals without any time effect are handled correctly. Changes are in S2, S3 and surv.risk.

#' Main function to carry out boosting for a joint model
#'
#' The parameters below are those of the internal fitting routine `JMboost()`;
#' `JMboost.this()` builds them from model formulas and a data set.
#'
#' @param y vector containing the longitudinal outcome
#' @param Xl design matrix for the longitudinal part containing time-varying covariates
#' @param Xs design matrix for the survival part containing one measurement per individual and covariate
#' @param Xls design matrix for the shared part containing time-independent covariates (duplicated values per individual)
#' @param T_long longitudinal time points
#' @param T_surv observed survival times
#' @param delta censoring (event) indicator
#' @param id id vector labeling the rows of the longitudinal outcome/design matrices with the corresponding individuals
#' @param alpha starting value for the association parameter
#' @param lambda starting value for the (constant) baseline hazard
#' @param time.effect logical, whether a shared time effect shall be included
#' @param mstop_l,mstop_s,mstop_ls number of boosting iterations for the longitudinal, survival and shared update step, respectively
#' @param nyl,nys,nyls,nyr step lengths for the longitudinal fixed effects, survival, shared and random-effects updates, respectively
#'
#' The intercept of the longitudinal submodel always starts at 0.


JMboost.this <- function(long, shared, surv, data, survData = NULL, timeVar, idVar, control = list(mstop_l, mstop_s, mstop_ls)) {
  # Formula interface to JMboost(): builds the design matrices of the three
  # predictors from `data`, runs the boosting algorithm and labels the
  # resulting estimates.
  #
  # long      formula for the longitudinal outcome and its time-varying covariates
  # shared    one-sided formula for the shared predictor; if it contains
  #           `timeVar`, a shared fixed time effect is fitted instead of using
  #           time as an ordinary covariate
  # surv      formula with a two-column survival response (time, event
  #           indicator) and the survival-only covariates
  # data      long-format data, one row per longitudinal measurement
  # survData  one row per individual; defaults to the first row of each `idVar`
  #           value in `data`. Rows are matched to individuals by position
  #           (NOTE(review): ordering is not checked here).
  # timeVar   name of the time variable in `data`
  # idVar     name of the individual identifier in `data`
  # control   named list; `mstop_l`, `mstop_s` and `mstop_ls` are required,
  #           further entries override the defaults in `con` below
  #
  # Returns a list with the labelled coefficients, the random-effect estimates,
  # the formulas (intercept restored), control settings and coefficient paths
  # (`trajec`).

  # The default of `control` refers to undefined symbols, so check missing()
  # first to give the intended message instead of "object 'mstop_l' not found".
  if (missing(control) || is.null(names(control)) ||
      !all(c("mstop_l", "mstop_s", "mstop_ls") %in% names(control))) {
    stop("Stopping iterations for each gradient needed.")
  }

  ## Longitudinal part (the intercept is estimated separately by JMboost()) ----
  long <- update(long, ~ . - 1)
  mf <- model.frame(long, data = data)
  y <- model.response(mf)
  Xl <- model.matrix(long, data = data)
  if (ncol(Xl) == 0) {
    Xl <- NULL
  }

  ## Shared part -----------------------------------------------------------------
  # A `timeVar` term switches on the shared fixed time effect and is removed
  # from the covariates. drop = FALSE keeps Xls a named matrix when a single
  # covariate remains; an empty matrix is represented by NULL.
  shared <- update(shared, ~ . - 1)
  Xls <- model.matrix(shared, data = data)
  time.effect <- timeVar %in% colnames(Xls)
  if (time.effect) {
    Xls <- Xls[, colnames(Xls) != timeVar, drop = FALSE]
  }
  if (ncol(Xls) == 0) {
    Xls <- NULL
  }

  ## Survival part (one row per individual) --------------------------------------
  if (is.null(survData)) {
    survData <- data[!duplicated(data[[idVar]]), , drop = FALSE]
  }
  smf <- model.frame(surv, survData)
  surv_resp <- model.response(smf)
  T_surv <- surv_resp[, 1]
  delta <- surv_resp[, 2]
  Xs <- model.matrix(surv, data = survData)[, -1, drop = FALSE]
  if (ncol(Xs) == 0) {
    Xs <- NULL
  }

  # [[ ]] yields a plain vector for data frames and tibbles alike
  id <- data[[idVar]]
  T_long <- data[[timeVar]]

  con <- list(alpha = 1, lambda = 1, mstop_l = NULL, mstop_s = NULL, mstop_ls = NULL,
              nyl = .1, nys = .3, nyls = .1, nyr = .1, verbose = FALSE)
  con[names(control)] <- control

  fit <- do.call("JMboost", list(y = y, Xl = Xl, Xs = Xs, Xls = Xls, T_long = T_long, T_surv = T_surv,
                                 delta = delta, id = id, time.effect = time.effect,
                                 mstop_l = con$mstop_l, mstop_s = con$mstop_s, mstop_ls = con$mstop_ls,
                                 verbose = con$verbose, alpha = con$alpha, lambda = con$lambda,
                                 nyl = con$nyl, nys = con$nys, nyls = con$nyls, nyr = con$nyr))

  ## Label the estimates ---------------------------------------------------------
  est <- fit[names(fit) %in% c("int", "betal", "betas", "betals", "betat", "alpha", "lambda", "sigma2")]
  names(est$betal) <- colnames(Xl)
  names(est$betas) <- colnames(Xs)
  names(est$betals) <- colnames(Xls)
  names(est$int) <- "(Intercept)"
  names(est$betat) <- "time_effect"
  names(est$alpha) <- "Assoc"

  # Built with list() so that an empty shared part is kept as a NULL element
  # named "shared" (assigning NULL via `$<-` would drop it and indexing by name
  # would then create an <NA>-named entry).
  coefs <- list(
    long = c(est$int, if (!is.null(Xl)) est$betal),
    shared = c(if (!is.null(Xls)) est$betals, if (time.effect) est$betat),
    surv = c(if (!is.null(Xs)) est$betas, est$alpha),
    lambda = est$lambda,
    sigma2 = est$sigma2
  )

  out <- list(
    coefficients = coefs,
    randomest = fit[names(fit) %in% c("gamma0", "gamma1")],
    long = update(long, ~ . + 1),
    shared = update(shared, ~ . + 1),
    surv = surv,
    random = as.formula(paste0("~ ", timeVar, " | ", idVar)),
    idVar = idVar,
    timeVar = timeVar,
    control = con,
    trajec = fit[names(fit) %in% c("GAMMA0", "GAMMA1", "INT", "BETAL", "BETALS", "BETAT", "BETAS",
                                   "ALPHA", "LAMBDA", "SIGMA2", "LIKE")]
  )
  out
}

# Negative Gaussian log-likelihood of the longitudinal submodel: y is modelled
# with mean etal + etals and variance sigma2.
long.risk <- function(y, etal, etals, sigma2) {
  log_dens <- dnorm(y, mean = etal + etals, sd = sqrt(sigma2), log = TRUE)
  -sum(log_dens)
}

# Negative log-likelihood of the survival submodel with constant baseline
# hazard `lambda` and hazard lambda * exp(etas + alpha * eta_ls(t)), where the
# shared predictor eta_ls(t) is linear in t with slope gamma1 + betat and takes
# the value `etals_un` at the observed time `T_surv`; `delta` is the event
# indicator.
#
# The cumulative hazard over [0, T_surv] is used in closed form. When
# alpha * slope == 0 the integrand is constant in t; this case also covers
# alpha == 0, where the closed form would evaluate to 0/0. expm1() avoids
# cancellation in exp(alpha * slope * T_surv) - 1 for small alpha * slope.
surv.risk <- function(alpha, lambda, etas, etals_un, T_surv, delta, gamma1, betat) {
  slope <- betat + gamma1
  integral <- ifelse(alpha * slope != 0,
                     lambda * exp(etas) * exp(alpha * (etals_un - slope * T_surv)) *
                       expm1(alpha * slope * T_surv) / (alpha * slope),
                     lambda * exp(etas) * exp(alpha * etals_un) * T_surv)
  -sum(delta * (log(lambda) + etas + alpha * etals_un) - integral)
}

# Least-squares base learner: regresses y on an intercept and x.
# Returns list(int = intercept, slp = slope, RSS = residual sum of squares).
# If the normal equations cannot be solved (e.g. X'X singular because x is
# constant), all coefficients are set to 0 so the learner contributes nothing.
mylm <- function(y, x) {
  X <- cbind(1, x)
  # Solve the normal equations once instead of inverting X'X twice.
  beta <- tryCatch(solve(crossprod(X), crossprod(X, y)),
                   error = function(e) rep(0, ncol(X)))
  RSS <- sum((y - X %*% beta)^2)
  list("int" = beta[1], "slp" = beta[2], "RSS" = RSS)
}

# Component-wise gradient boosting for the joint model (fitting routine behind
# JMboost.this()). Each iteration m performs
#   S1 (m <= mstop_l):  one update of the longitudinal predictor (one column of
#                       Xl, the random effects, or the shared time effect),
#   S2 (m <= mstop_s):  one update of the survival covariates Xs and lambda,
#   S3 (m <= mstop_ls): one update of the shared covariates Xls, fitted on the
#                       stacked longitudinal and survival gradients,
#   S4:                 re-estimation of sigma2 and alpha by 1-d optimisation.
# y, Xl, Xls, T_long and id have one entry/row per longitudinal measurement;
# Xs, T_surv and delta have one per individual, in order of first appearance
# of the individuals in `id` (NOTE(review): matched by position, not checked).
# Returns the final estimates (lower case) and their paths over the iterations
# (upper case; column/entry m = state after iteration m).
JMboost <- function(y, Xl = NULL, Xs = NULL, Xls = NULL, delta, T_long, T_surv, id, time.effect = TRUE,
                    alpha = 1, lambda = 1, mstop_l, mstop_s, mstop_ls, nyl = .1, nys = .3, nyls = .1, nyr = .1, verbose = FALSE) {

  mstop <- max(mstop_l, mstop_s, mstop_ls)
  n <- length(id)
  N <- length(unique(id))

  ## Individuals -----------------------------------------------------------------
  # Relabel the individuals 1..N in order of first appearance. The former
  # rep(order(unique(id)), table(id)) only labelled rows correctly when the
  # data were sorted by id; for sorted data both labelings coincide.
  id <- match(id, unique(id))

  ## Random-effects design and smoothers ----------------------------------------
  # Column i of Xr is the random-intercept indicator of individual i; column
  # N + i carries that individual's measurement times (random slope).
  Xr <- matrix(0, nrow = n, ncol = 2 * N)
  obs <- seq_len(n)
  Xr[cbind(obs, id)] <- 1
  Xr[cbind(obs, N + id)] <- T_long
  # Ridge smoother matrices for random intercepts (SA) and slopes (SB); the
  # penalty comes from mboost's internal df2lambda() for 4 degrees of freedom.
  XrA <- Xr[, seq_len(N), drop = FALSE]
  lambdaran <- mboost:::df2lambda(XrA, 4, weights = 1)[2]
  SA <- solve(t(XrA) %*% XrA + lambdaran * diag(N)) %*% t(XrA)
  XrB <- Xr[, N + seq_len(N), drop = FALSE]
  lambdaran <- mboost:::df2lambda(XrB, 4, weights = 1)[2]
  SB <- solve(t(XrB) %*% XrB + lambdaran * diag(N)) %*% t(XrB)

  ## Starting values -------------------------------------------------------------
  # A random-intercept model only supplies the starting residual variance; the
  # fixed and random intercepts start at 0.
  offset <- nlme::lme(y ~ 1, random = ~ 1 | id, data = data.frame(y = as.vector(y), id = id))
  int <- 0
  gamma0 <- rep(0, N)
  gamma1 <- rep(0, N)
  # lme() reports the residual standard deviation; sigma2 is a variance
  # everywhere else (long.risk() uses sd = sqrt(sigma2)).
  sigma2 <- offset$sigma^2

  # Absent covariate blocks are represented by the scalar 0 so that the matrix
  # products in predictors() reduce to a zero contribution.
  pl <- 0
  ps <- 0
  pls <- 0
  betal <- 0
  betas <- 0
  betals <- 0
  betat <- 0
  if (is.null(Xl)) {
    Xl <- 0
  } else {
    Xl <- as.matrix(Xl)
    pl <- ncol(Xl)
    betal <- rep(0, pl)
  }
  if (is.null(Xs)) {
    Xs <- 0
  } else {
    Xs <- as.matrix(Xs)
    ps <- ncol(Xs)
    betas <- rep(0, ps)
  }
  if (is.null(Xls)) {
    Xls <- 0
    Xls_un <- 0
  } else {
    Xls <- as.matrix(Xls)
    pls <- ncol(Xls)
    # one row per individual (its first measurement) for the survival part;
    # drop = FALSE keeps the N x pls shape even for a single individual
    Xls_un <- Xls[!duplicated(id), , drop = FALSE]
    betals <- rep(0, pls)
  }

  ## Storage of the coefficient paths ---------------------------------------------
  GAMMA0 <- matrix(0, nrow = N, ncol = mstop)
  GAMMA1 <- matrix(0, nrow = N, ncol = mstop)
  BETAL <- matrix(0, nrow = pl, ncol = mstop)
  BETAS <- matrix(0, nrow = ps, ncol = mstop)
  BETALS <- matrix(0, nrow = pls, ncol = mstop)
  BETAT <- rep(0, mstop)
  INT <- rep(0, mstop)
  ALPHA <- rep(0, mstop)
  LAMBDA <- rep(0, mstop)
  SIGMA2 <- rep(0, mstop)

  ## Local helpers ------------------------------------------------------------------
  # Current values of all predictors, read from this function's frame:
  #   l     longitudinal predictor (intercept + Xl), per measurement
  #   s     survival predictor (Xs), per individual (scalar 0 without Xs)
  #   ls    shared predictor at the measurement times, per measurement
  #   ls_un shared predictor at the survival time T_surv, per individual
  predictors <- function() {
    list(
      l = as.vector(int + Xl %*% betal),
      s = as.vector(Xs %*% betas),
      ls = as.vector(Xr %*% c(gamma0, gamma1) + as.vector(Xls %*% betals) + T_long * betat),
      ls_un = as.vector(gamma0 + gamma1 * T_surv + as.vector(Xls_un %*% betals) + betat * T_surv)
    )
  }

  # Cumulative hazard over [0, T_surv] of
  #   lambda * exp(etas + alpha * (etals_un - slope * (T_surv - t))),
  # i.e. with a shared predictor linear in time. Closed form when
  # alpha * slope != 0 (expm1 avoids cancellation for small alpha * slope);
  # otherwise the integrand is constant in t (this includes alpha == 0, where
  # the closed form would be 0/0).
  cum_hazard <- function(alpha, lambda, etas, etals_un, slope) {
    ifelse(alpha * slope != 0,
           lambda * exp(etas) * exp(alpha * (etals_un - slope * T_surv)) *
             expm1(alpha * slope * T_surv) / (alpha * slope),
           lambda * exp(etas) * exp(alpha * etals_un) * T_surv)
  }

  for (m in seq_len(mstop)) {

    ## S1: longitudinal predictor ---------------------------------------------------
    # Candidates: each column of Xl, the random effects (ridge smoothers) and
    # the shared time effect; only the candidate with the smallest RSS is used.
    if (m <= mstop_l) {
      eta <- predictors()
      u <- (y - eta$l - eta$ls) / sigma2
      fits <- matrix(0, 3, pl + 1 + as.numeric(time.effect))
      if (pl > 0) {
        for (i in seq_len(pl)) {
          fit <- mylm(u, Xl[, i])
          fits[, i] <- c(fit$int, fit$slp, fit$RSS)
        }
      } else if (!time.effect) {
        # no learner with an intercept competes, so move the intercept directly
        int <- int + nyl * mean(u)
      }
      rfit <- rbind(SA, SB) %*% u
      fits[3, pl + 1] <- sum((u - Xr %*% rfit)^2)
      if (time.effect) {
        fit <- mylm(u, T_long)
        fits[, pl + 2] <- c(fit$int, fit$slp, fit$RSS)
      }
      best <- which.min(fits[3, ])
      if (best == pl + 1) {
        gamma0 <- gamma0 + nyr * rfit[seq_len(N)]
        gamma1 <- gamma1 + nyr * rfit[N + seq_len(N)]
      } else if (best == pl + 2) {
        betat <- betat + nyl * fits[2, best]
        int <- int + nyl * fits[1, best]
      } else {
        betal[best] <- betal[best] + nyl * fits[2, best]
        int <- int + nyl * fits[1, best]
      }
    }
    INT[m] <- int
    BETAT[m] <- betat
    BETAL[, m] <- betal
    GAMMA0[, m] <- gamma0
    GAMMA1[, m] <- gamma1

    ## S2: survival covariates and baseline hazard ---------------------------------
    if (m <= mstop_s && ps > 0) {
      eta <- predictors()
      # gradient of the survival log-likelihood w.r.t. the survival predictor
      u <- delta - cum_hazard(alpha, lambda, eta$s, eta$ls_un, betat + gamma1)
      fits <- matrix(0, 3, ps)
      for (i in seq_len(ps)) {
        fit <- mylm(u, Xs[, i])
        fits[, i] <- c(fit$int, fit$slp, fit$RSS)
      }
      best <- which.min(fits[3, ])
      betas[best] <- betas[best] + nys * fits[2, best]
      # the learner's intercept is absorbed into the log baseline hazard
      lambda <- lambda * exp(nys * fits[1, best])
    }
    LAMBDA[m] <- lambda
    BETAS[, m] <- betas

    ## S3: shared covariates -----------------------------------------------------------
    # Xls enters both submodels, so each learner is fitted to the stacked
    # longitudinal (per measurement) and survival (per individual) gradients.
    if (m <= mstop_ls && pls > 0) {
      eta <- predictors()
      u_l <- (y - eta$l - eta$ls) / sigma2
      # derivative w.r.t. a shift of the shared predictor: alpha * (delta - Lambda)
      u_s <- alpha * (delta - cum_hazard(alpha, lambda, eta$s, eta$ls_un, betat + gamma1))
      u <- c(u_l, u_s)
      fits <- matrix(0, 3, pls)
      for (i in seq_len(pls)) {
        fit <- mylm(u, c(Xls[, i], Xls_un[, i]))
        fits[, i] <- c(fit$int, fit$slp, fit$RSS)
      }
      best <- which.min(fits[3, ])
      betals[best] <- betals[best] + nyls * fits[2, best]
      int <- int + nyls * fits[1, best]
      lambda <- lambda * exp(nyls * alpha * fits[1, best])
    }
    INT[m] <- int
    BETALS[, m] <- betals
    LAMBDA[m] <- lambda

    ## S4: nuisance parameters ---------------------------------------------------------
    eta <- predictors()
    sigma2 <- optimize(long.risk, y = y, etal = eta$l, etals = eta$ls,
                       interval = c(0, 100))$minimum
    SIGMA2[m] <- sigma2

    # alpha is searched within +- max(0.1 * |alpha|, 0.1) of its current value
    optim.int <- c(min(alpha - 0.1 * abs(alpha), alpha - 0.1),
                   max(alpha + 0.1 * abs(alpha), alpha + 0.1))
    alpha <- optimize(surv.risk, lambda = lambda, etas = eta$s, etals_un = eta$ls_un,
                      delta = delta, gamma1 = gamma1, betat = betat, T_surv = T_surv,
                      interval = optim.int)$minimum
    ALPHA[m] <- alpha

    if (verbose) {
      print(m)
    }
  }

  list(GAMMA0 = GAMMA0, GAMMA1 = GAMMA1, BETAL = BETAL, BETAS = BETAS,
       BETALS = BETALS, BETAT = BETAT, INT = INT, ALPHA = ALPHA, LAMBDA = LAMBDA, SIGMA2 = SIGMA2,
       gamma0 = gamma0, gamma1 = gamma1, betal = betal, betas = betas, betals = betals,
       betat = betat, int = int, alpha = alpha, lambda = lambda, sigma2 = sigma2)
}

