#' Per mode (axis) MLE
#'
#' Starts from the fit returned by `kpir.ls` and then repeatedly updates, in
#' random order, the per mode parameter matrices `alphas` and the per mode
#' covariance matrices `Deltas`.
#'
#' @param X predictor array with observations indexed along `sample.axis`.
#' @param Fy response array with the same number of axes as `X` and the same
#'  size along `sample.axis`. A non-array `Fy` (scalar response case) is
#'  reshaped to an array with size 1 along every axis except `sample.axis`.
#' @param sample.axis index of the sample mode, a.k.a. observation axis index
#' @param center if `TRUE`, the means over the sample axis are subtracted from
#'  `X` and `Fy` before fitting and returned as `meanX` and `meanFy`.
#' @param max.iter number of update sweeps performed after initialization.
#' @param max.init.iter passed as `max.iter` to `kpir.ls` (initial values).
#' @param eps passed to `kpir.ls`; not used by the update sweeps themselves
#'  (there is no break condition yet).
#' @param logger optional function invoked as
#'  `logger("mle", iter, alphas, Deltas)` before every sweep (with the
#'  previous iteration index) and once after the last sweep. It is also
#'  forwarded to `kpir.ls`.
#'
#' @return list with elements `alphas` and `Deltas` (lists named by the
#'  predictor axis indices) as well as `meanX` and `meanFy` (`NA` if
#'  `center` is `FALSE`).
#'
#' @export
kpir.mle <- function(X, Fy, sample.axis = 1L, center = TRUE,
    max.iter = 50L, max.init.iter = 10L, eps = sqrt(.Machine$double.eps),
    logger = NULL
) {
    ### Step 0: Setup/Initialization
    if (!is.array(Fy)) {
        # scalar response case (add new axes of size 1)
        Fy.dim <- rep(1L, length(dim(X)))
        Fy.dim[sample.axis] <- dim(X)[sample.axis]
        dim(Fy) <- Fy.dim
    }
    # Check dimensions and matching of axis (tensor order)
    stopifnot(exprs = {
        length(dim(X)) == length(dim(Fy))
        dim(X)[sample.axis] == dim(Fy)[sample.axis]
    })
    # warn about occurrence of an axis without reduction
    if (any(dim(Fy)[-sample.axis] >= dim(X)[-sample.axis])) {
        warning("Degenerate case 'any(dim(Fy)[-sample.axis] >= dim(X)[-sample.axis])'")
    }

    # extract dimensions (for easier handling as local variables)
    modes <- seq_along(dim(X))[-sample.axis] # predictor axis indices
    n <- dim(X)[sample.axis]                 # observation count (scalar)
    p <- dim(X)[-sample.axis]                # predictor dimensions (vector)
    r <- length(dim(X)) - 1L                 # predictor rank (tensor order)

    if (center) {
        # Means for X and Fy (a.k.a. average elements over the sample axis)
        meanX <- apply(X, modes, mean, simplify = TRUE)
        meanFy <- apply(Fy, modes, mean, simplify = TRUE)
        # Center both X and Fy (`sweep` subtracts by default)
        X <- sweep(X, modes, meanX)
        Fy <- sweep(Fy, modes, meanFy)
    } else {
        meanX <- meanFy <- NA
    }

    ### Step 1: Initial values
    # data is already centered (if requested), so never center again
    ls.fit <- kpir.ls(X, Fy, sample.axis = sample.axis, center = FALSE,
        max.iter = max.init.iter, eps = eps, logger = logger)
    alphas <- ls.fit$alphas
    Deltas <- ls.fit$Deltas
    # compute residuals
    R <- X - mlm(Fy, alphas, modes)
    # Compute covariance inverses
    Delta.invs <- Map(solve, Deltas)
    # multiply inverse Deltas with alphas (cache of `Delta_j^-1 alpha_j`)
    Delta.inv.alphas <- Map(`%*%`, Delta.invs, alphas)

    ### Step 2: Iterative Updating
    # Number of completed sweeps. Tracked separately from the loop variable
    # such that the final logger call gets a valid index if `max.iter` is 0.
    last.iter <- 0L
    for (iter in seq_len(max.iter)) {

        # Invoke logger for previous iterate
        if (is.function(logger)) {
            logger("mle", iter - 1L, alphas, Deltas)
        }

        # random order cyclic updating
        for (j in sample(2 * r)) {
            # toggle between updating alpha j <= r and Delta j > r
            if (j <= r) {
                # Update `alpha_j`
                XxDi <- mlm(X, Delta.invs[-j], modes[-j])
                Fxa <- mlm(Fy, alphas[-j], modes[-j])
                FxDia <- mlm(Fy, Delta.inv.alphas[-j], modes[-j])
                alphas[[j]] <- mcrossprod(XxDi, Fxa, modes[j]) %*%
                    solve(mcrossprod(FxDia, Fxa, modes[j]))
                # Keep the cached `Delta_j^-1 alpha_j` in sync with the new
                # `alpha_j`, otherwise later updates of the other alphas
                # (before `Delta_j` is revisited) use a stale product.
                Delta.inv.alphas[[j]] <- Delta.invs[[j]] %*% alphas[[j]]
                # Recompute Residuals (with updated alpha)
                R <- X - mlm(Fy, alphas, modes)
            } else {
                j <- j - r # map from [r + 1; 2 r] to [1; r]
                # Update `Delta_j`
                Deltas[[j]] <- (p[j] / (n * prod(p))) *
                    mcrossprod(mlm(R, Delta.invs[-j], modes[-j]), R, modes[j])
                # Recompute `Delta_j^-1`
                Delta.invs[[j]] <- solve(Deltas[[j]])
                # as well as `Delta_j^-1 alpha_j`
                Delta.inv.alphas[[j]] <- Delta.invs[[j]] %*% alphas[[j]]
            }
        }

        last.iter <- iter

        # TODO: add some kind of break condition
    }

    # Before returning, call logger for the final iteration
    if (is.function(logger)) {
        logger("mle", last.iter, alphas, Deltas)
    }

    list(
        alphas = structure(alphas, names = as.character(modes)),
        Deltas = structure(Deltas, names = as.character(modes)),
        meanX = meanX,
        meanFy = meanFy
    )
}