103 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
			
		
		
	
	
			103 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
#' Per mode (axis) MLE
#'
#' Alternating maximum likelihood estimation of the per-mode coefficient
#' matrices `alphas` and covariance matrices `Deltas`. Initial values are
#' taken from `kpir.ls()`; afterwards every `alpha_j` and `Delta_j` is updated
#' once per sweep, in a freshly drawn random order for each sweep.
#'
#' @param X predictor array; the axis `sample.axis` indexes the observations
#' @param Fy response array with the same number of axes as `X` and the same
#'  extent along `sample.axis`. A non-array `Fy` (scalar response) is given
#'  dimensions equal to 1 along every axis except `sample.axis`.
#' @param sample.axis index of the sample mode, a.k.a. observation axis index
#' @param center logical; if `TRUE`, the means over the sample axis are
#'  subtracted from `X` and `Fy` before fitting
#' @param max.iter number of update sweeps (each sweep updates every `alpha_j`
#'  and every `Delta_j` once)
#' @param max.init.iter passed as `max.iter` to `kpir.ls()` (initialization)
#' @param eps tolerance, passed on to `kpir.ls()`
#' @param logger optional function, called as
#'  `logger("mle", iter, alphas, Deltas)` before every sweep (with the index of
#'  the previous iterate) and once more after the last sweep
#'
#' @return A list with elements `alphas` and `Deltas` (lists of matrices, named
#'  by predictor axis index), and `meanX`, `meanFy` (the subtracted means if
#'  `center = TRUE`, otherwise `NA`).
#'
#' @export
kpir.mle <- function(X, Fy, sample.axis = 1L, center = TRUE,
    max.iter = 50L, max.init.iter = 10L, eps = sqrt(.Machine$double.eps),
    logger = NULL
) {
    ### Step 0: Setup/Initialization
    if (!is.array(Fy)) {
        # scalar response case (add new axis of size 1)
        dim(Fy) <- ifelse(seq_along(dim(X)) == sample.axis, dim(X)[sample.axis], 1L)
    }
    # Check dimensions and matching of axis (tensor order)
    stopifnot(exprs = {
        length(dim(X)) == length(dim(Fy))
        dim(X)[sample.axis] == dim(Fy)[sample.axis]
    })
    # warn about occurrence of an axis without reduction
    if (any(dim(Fy)[-sample.axis] >= dim(X)[-sample.axis])) {
        warning("Degenerate case 'any(dim(Fy)[-sample.axis] >= dim(X)[-sample.axis])'")
    }

    # extract dimensions (for easier handling as local variables)
    modes <- seq_along(dim(X))[-sample.axis]    # predictor axis indices
    n <- dim(X)[sample.axis]                    # observation count (scalar)
    p <- dim(X)[-sample.axis]                   # predictor dimensions (vector)
    r <- length(dim(X)) - 1L                    # predictor rank (tensor order)

    if (center) {
        # Means of X and Fy over the sample axis (one mean per cell of the
        # remaining axes)
        meanX <- apply(X, modes, mean, simplify = TRUE)
        meanFy <- apply(Fy, modes, mean, simplify = TRUE)
        # Center both X and Fy
        X <- sweep(X, modes, meanX)
        Fy <- sweep(Fy, modes, meanFy)
    } else {
        meanX <- meanFy <- NA
    }

    # NOTE(review): `mlm(A, Bs, ms)` is assumed to multiply A by `Bs[[k]]`
    # along axis `ms[k]`, and `mcrossprod(A, B, m)` to compute the mode-m
    # unfolding product A_(m) B_(m)'; both are defined elsewhere -- confirm.

    ### Step 1: Initial values
    ls.fit <- kpir.ls(X, Fy, sample.axis = sample.axis, center = FALSE,
        max.iter = max.init.iter, eps = eps, logger = logger)
    alphas <- ls.fit$alphas
    Deltas <- ls.fit$Deltas
    # compute residuals
    R <- X - mlm(Fy, alphas, modes)
    # Compute covariance inverses
    Delta.invs <- Map(solve, Deltas)
    # cache the products `Delta_j^-1 alpha_j` (Delta inverses times alphas)
    Delta.inv.alphas <- Map(`%*%`, Delta.invs, alphas)

    ### Step 2: Iterative Updating
    # Tracked explicitly: after a `for` loop over an empty sequence
    # (max.iter = 0) the loop variable does not hold an iteration count.
    last.iter <- 0L
    for (iter in seq_len(max.iter)) {

        # Invoke logger for previous iterate
        if (is.function(logger)) {
            logger("mle", iter - 1L, alphas, Deltas)
        }

        # random order cyclic updating
        for (j in sample(2 * r)) {
            # toggle between updating alpha j <= r and Delta j > r
            if (j <= r) {
                # Update `alpha_j`
                XxDi  <- mlm(X,  Delta.invs[-j],       modes[-j])
                Fxa   <- mlm(Fy,           alphas[-j], modes[-j])
                FxDia <- mlm(Fy, Delta.inv.alphas[-j], modes[-j])
                alphas[[j]] <- mcrossprod(XxDi, Fxa, modes[j]) %*%
                    solve(mcrossprod(FxDia, Fxa, modes[j]))

                # Keep the cached `Delta_j^-1 alpha_j` in sync with the new
                # `alpha_j`; subsequent `alpha_k` updates (k != j) in this
                # sweep read it through `FxDia`.
                Delta.inv.alphas[[j]] <- Delta.invs[[j]] %*% alphas[[j]]

                # Recompute Residuals (with updated alpha)
                R <- X - mlm(Fy, alphas, modes)
            } else {
                j <- j - r  # map from [r + 1; 2 r] to [1; r]

                # Update `Delta_j`
                Deltas[[j]] <- (p[j] / (n * prod(p))) *
                    mcrossprod(mlm(R, Delta.invs[-j], modes[-j]), R, modes[j])

                # Recompute `Delta_j^-1`
                Delta.invs[[j]] <- solve(Deltas[[j]])
                # as well as `Delta_j^-1 alpha_j`
                Delta.inv.alphas[[j]] <- Delta.invs[[j]] %*% alphas[[j]]
            }
        }

        last.iter <- iter
        # TODO: add some kind of break condition
    }

    # Before returning, call logger for the final iteration
    if (is.function(logger)) {
        logger("mle", last.iter, alphas, Deltas)
    }

    list(alphas = structure(alphas, names = as.character(modes)),
         Deltas = structure(Deltas, names = as.character(modes)),
         meanX = meanX, meanFy = meanFy)
}
 |