#' Gradient Descent based Tensor Predictors method with Nesterov Accelerated
#' Momentum
#'
#' Estimates `alpha` and `beta` in the model
#' `vec(X_i) = (alpha %x% beta) vec(Fy_i) + epsilon_i` with error covariance
#' `Delta`. It does so by minimizing the negative log-likelihood
#' `0.5 * (n * log|Delta| + sum_i r_i' Delta^-1 r_i)`, where `r_i` are the
#' residuals and `Delta` is the residual covariance of the current coefficients.
#' The minimization uses gradient descent with Nesterov momentum and an Armijo
#' type backtracking line search.
#'
#' @param X predictors, either an `n x p x q` array or an `n x (p q)` matrix
#'  (the matrix form requires `shape`).
#' @param Fy response functions, either a length `n` vector, an `n x (k r)`
#'  matrix (requires `shape`) or an `n x k x r` array.
#' @param shape integer vector `c(p, q, k, r)` of the predictor and response
#'  function dimensions; required if `X` or `Fy` is given in matrix form.
#' @param max.iter maximum number of descent iterations.
#' @param max.line.iter maximum number of step sizes tried per line search.
#' @param step.size initial step size of the first line search.
#' @param nesterov.scaling function `(a, t)` computing the next momentum scaling
#'  from the current scaling `a` and iteration `t`.
#' @param max.init.iter maximum number of iterations for the `kpir.ls` based
#'  initial estimate (only used for `init.method = "ls"`).
#' @param init.method initial estimate, `"ls"` for `kpir.ls` or `"vlp"` for a
#'  Van Loan and Pitsianis decomposition of the OLS estimate.
#' @param eps tolerance used in the stopping rules (also passed to `kpir.ls`).
#' @param logger optional function called as
#'  `logger(iter, loss, alpha, beta, Delta, step.size)`. It is called once for
#'  the initial estimate (`iter = 0`, `step.size = NA`) and after every line
#'  search with the candidate iterate, which is not necessarily accepted.
#'
#' @return list with components `loss` (negative log-likelihood), `alpha`
#'  (`q x r`), `beta` (`p x k`) and `Delta` (`pq x pq` covariance estimate),
#'  all belonging to the same (last accepted) iterate.
#'
#' @export
kpir.momentum <- function(X, Fy, shape = c(dim(X)[-1], dim(Fy)[-1]),
    max.iter = 500L, max.line.iter = 50L, step.size = 1e-3,
    nesterov.scaling = function(a, t) 0.5 * (1 + sqrt(1 + (2 * a)^2)),
    max.init.iter = 20L, init.method = c("ls", "vlp"),
    eps = .Machine$double.eps,
    logger = NULL
) {

    # Check if X and Fy have same number of observations
    stopifnot(nrow(X) == NROW(Fy))
    n <- nrow(X)                        # Number of observations

    # Get and check predictor dimensions
    if (length(dim(X)) == 2L) {
        stopifnot(!missing(shape))
        stopifnot(ncol(X) == prod(shape[1:2]))
        p <- as.integer(shape[1])       # Predictor "height"
        q <- as.integer(shape[2])       # Predictor "width"
    } else if (length(dim(X)) == 3L) {
        p <- dim(X)[2]
        q <- dim(X)[3]
        dim(X) <- c(n, p * q)
    } else {
        stop("'X' must be a matrix or 3-tensor")
    }

    # Get and check response dimensions
    if (!is.array(Fy)) {
        Fy <- as.array(Fy)
    }
    if (length(dim(Fy)) == 1L) {
        k <- r <- 1L
        dim(Fy) <- c(n, 1L)
    } else if (length(dim(Fy)) == 2L) {
        stopifnot(!missing(shape))
        stopifnot(ncol(Fy) == prod(shape[3:4]))
        k <- as.integer(shape[3])       # Response functional "height"
        r <- as.integer(shape[4])       # Response functional "width"
    } else if (length(dim(Fy)) == 3L) {
        k <- dim(Fy)[2]
        r <- dim(Fy)[3]
        dim(Fy) <- c(n, k * r)
    } else {
        stop("'Fy' must be a vector, matrix or 3-tensor")
    }

    # Residuals, covariance estimate and transformed residuals for the
    # coefficients `alpha`, `beta`. `matpow` is used as a robust inversion
    # algorithm (Moore-Penrose Pseudo Inverse in case of singular `Delta`).
    # NOTE: reads `X` and `Fy` at call time, they must be in matrix form then.
    fit.at <- function(alpha, beta) {
        resid <- X - tcrossprod(Fy, kronecker(alpha, beta))
        Delta <- crossprod(resid) / n
        list(resid = resid, Delta = Delta,
             resid.trans = resid %*% matpow(Delta, -1))
    }

    # Negative log-likelihood (up to an additive constant) of a `fit.at`
    # result. `determinant` on the log scale avoids over-/underflow of `det`
    # for large `p * q`. It also gives `log|det|` instead of `NaN` if round-off
    # makes the determinant of the (positive semi-definite) `Delta` negative.
    neg.loglik <- function(fit) {
        log.det <- as.numeric(determinant(fit$Delta, logarithm = TRUE)$modulus)
        0.5 * (n * log.det + sum(fit$resid.trans * fit$resid))
    }


    ### Step 1: (Approx) Least Squares initial estimate
    init.method <- match.arg(init.method)
    if (init.method == "ls") {
        dim(X) <- c(n, p, q)
        dim(Fy) <- c(n, k, r)
        ls.fit <- kpir.ls(X, Fy, max.iter = max.init.iter, sample.axis = 1L, eps = eps)
        c(beta0, alpha0) %<-% ls.fit$alphas
        dim(X) <- c(n, p * q)
        dim(Fy) <- c(n, k * r)
    } else { # Van Loan and Pitsianis
        # solution for `X = Fy B' + epsilon`
        cpFy <- crossprod(Fy)               # TODO: Check/Test and/or replace
        if (n <= k * r || qr(cpFy)$rank < k * r) {
            # In case of under-determined system replace the inverse in the normal
            # equation by the Moore-Penrose Pseudo Inverse
            B <- t(matpow(cpFy, -1) %*% crossprod(Fy, X))
        } else {
            # Compute OLS estimate by the Normal Equation
            B <- t(solve(cpFy, crossprod(Fy, X)))
        }

        # Decompose `B = alpha x beta` into `alpha` and `beta`
        c(alpha0, beta0) %<-% approx.kronecker(B, c(q, r), c(p, k))
    }

    # Loss and covariance of the initial estimate. `loss` and `Delta` always
    # belong to the last accepted iterate `(alpha1, beta1)`.
    fit <- fit.at(alpha0, beta0)
    loss <- neg.loglik(fit)
    Delta <- fit$Delta

    # Call history callback (logger) before the first iterate
    if (is.function(logger)) {
        logger(0L, loss, alpha0, beta0, Delta, NA)
    }


    ### Step 2: MLE with LS solution as starting value
    a0 <- 0
    a1 <- 1
    alpha1 <- alpha0
    beta1 <- beta0

    # Number of parameters (`alpha` is `q x r`, `beta` is `p x k`), used to
    # turn the squared gradient norm into a mean squared gradient.
    n.params <- q * r + p * k

    # main descent loop
    no.nesterov <- TRUE
    for (iter in seq_len(max.iter)) {
        if (no.nesterov) {
            # without extrapolation as fallback
            S.alpha <- alpha1
            S.beta  <- beta1
        } else {
            # extrapolation using previous direction
            momentum <- (a0 - 1) / a1
            S.alpha <- alpha1 + momentum * (alpha1 - alpha0)
            S.beta  <-  beta1 + momentum * ( beta1 -  beta0)
        }

        # Extrapolated Residuals, Covariance and transformed Residuals
        S.fit <- fit.at(S.alpha, S.beta)

        # Sum over kronecker prod by observation (Face-Splitting Product)
        KR <- colSums(rowKronecker(Fy, S.fit$resid.trans))
        dim(KR) <- c(p, q, k, r)

        # (Nesterov) `alpha` Gradient
        R.alpha <- aperm(KR, c(2, 4, 1, 3))
        dim(R.alpha) <- c(q * r, p * k)
        grad.alpha <- c(R.alpha %*% c(S.beta))

        # (Nesterov) `beta` Gradient
        R.beta <- aperm(KR, c(1, 3, 2, 4))
        dim(R.beta) <- c(p * k, q * r)
        grad.beta <- c(R.beta %*% c(S.alpha))

        # Backtracking line search (Armijo type)
        # The `inner.prod` is used in the Armijo break condition but does not
        # depend on the step size.
        inner.prod <- sum(grad.alpha^2) + sum(grad.beta^2)

        # backtracking loop; the sufficient decrease is measured against the
        # loss of the last accepted iterate (not of the extrapolated point)
        for (delta in step.size * 0.618034^seq.int(0L, length.out = max.line.iter)) {
            # Update `alpha` and `beta` (note: add(+), the gradients are already
            # pointing into the negative slope direction of the loss cause they are
            # the gradients of the log-likelihood [NOT the negative log-likelihood])
            alpha.temp <- S.alpha + delta * grad.alpha
            beta.temp  <- S.beta  + delta * grad.beta

            # Update Residuals, Covariance, transformed Residuals and loss
            fit.temp <- fit.at(alpha.temp, beta.temp)
            loss.temp <- neg.loglik(fit.temp)

            # Armijo line search break condition
            if (loss.temp <= loss - 0.1 * delta * inner.prod) {
                break
            }
        }

        # Call logger (invoke history callback) with the candidate iterate
        if (is.function(logger)) {
            logger(iter, loss.temp, alpha.temp, beta.temp, fit.temp$Delta, delta)
        }

        # Ensure descent
        if (loss.temp < loss) {
            loss.decrease <- loss - loss.temp

            # accept the candidate (keeping `loss` and `Delta` in sync with it
            # before any of the break conditions below can end the loop)
            alpha0 <- alpha1
            alpha1 <- alpha.temp
            beta0  <- beta1
            beta1  <- beta.temp
            loss   <- loss.temp
            Delta  <- fit.temp$Delta

            # check break conditions (in descent case)
            if (mean(abs(alpha1)) + mean(abs(beta1)) < eps) {
                break   # basically, estimates are zero -> stop
            }
            if (inner.prod < eps * n.params) {
                break   # mean squared gradient is smaller than epsilon -> stop
            }
            if (loss.decrease < eps) {
                break   # decrease is too small (slow) -> stop
            }

            no.nesterov <- FALSE    # always reset
        } else if (!no.nesterov) {
            no.nesterov <- TRUE     # retry without momentum
            next
        } else {
            break       # failed even without momentum -> stop
        }

        # update momentum scaling
        a0 <- a1
        a1 <- nesterov.scaling(a1, iter)

        # Set next iter starting step.size to line searched step size
        # (while allowing it to increase)
        step.size <- 1.618034 * delta
    }

    list(loss = loss, alpha = alpha1, beta = beta1, Delta = Delta)
}