2022-05-27 18:11:48 +00:00
|
|
|
#' Per mode (axis) MLE
#'
#' Starts from the `kpir.ls` estimates and then repeatedly updates the per mode
#' parameter matrices `alphas` and the per mode covariances `Deltas`, one
#' component at a time, in a freshly drawn random order in every iteration.
#'
#' @param X predictor array, observations are indexed by the `sample.axis`
#' @param Fy array of the same order (number of axes) as `X` with the same size
#'  along the `sample.axis`. If `Fy` is not an array (scalar response case) it
#'  is reshaped to size `dim(X)[sample.axis]` along the sample axis and size
#'  `1` along all other axes.
#' @param sample.axis index of the sample mode, a.k.a. observation axis index
#' @param center if `TRUE`, the means over the sample axis are subtracted from
#'  `X` and `Fy` before estimation
#' @param max.iter number of cyclic update iterations to perform
#' @param max.init.iter passed as `max.iter` to `kpir.ls`, which computes the
#'  initial values
#' @param eps passed on to `kpir.ls`. Note that the main iteration of this
#'  function has no stopping criterion (yet), it always performs `max.iter`
#'  iterations.
#' @param logger `NULL` or a function. A function is passed on to `kpir.ls`
#'  and additionally called as `logger("mle", iter, alphas, Deltas)`; before
#'  every iteration with the previous iterate (starting with `iter = 0`) and
#'  once more after the last iteration.
#'
#' @return list with elements `alphas` and `Deltas` (lists named by the
#'  predictor axis indices) as well as `meanX` and `meanFy` (both `NA` if
#'  `center` is `FALSE`)
#'
#' @export
kpir.mle <- function(X, Fy, sample.axis = 1L, center = TRUE,
    max.iter = 50L, max.init.iter = 10L, eps = sqrt(.Machine$double.eps),
    logger = NULL
) {
    ### Step 0: Setup/Initialization
    if (!is.array(Fy)) {
        # scalar response case (add new axis of size 1)
        dim(Fy) <- ifelse(seq_along(dim(X)) == sample.axis, dim(X)[sample.axis], 1L)
    }
    # Check dimensions and matching of axis (tensor order)
    stopifnot(exprs = {
        length(dim(X)) == length(dim(Fy))
        dim(X)[sample.axis] == dim(Fy)[sample.axis]
    })
    # warn about occurrence of an axis without reduction
    if (any(dim(Fy)[-sample.axis] >= dim(X)[-sample.axis])) {
        warning("Degenerate case 'any(dim(Fy)[-sample.axis] >= dim(X)[-sample.axis])'")
    }

    # extract dimensions (for easier handling as local variables)
    modes <- seq_along(dim(X))[-sample.axis]    # predictor axis indices
    n <- dim(X)[sample.axis]                    # observation count (scalar)
    p <- dim(X)[-sample.axis]                   # predictor dimensions (vector)
    r <- length(dim(X)) - 1L                    # predictor rank (tensor order)

    if (center) {
        # Means for X and Fy (a.k.a. mean of the elements over the sample axis)
        meanX <- apply(X, modes, mean, simplify = TRUE)
        meanFy <- apply(Fy, modes, mean, simplify = TRUE)
        # Center both X and Fy
        X <- sweep(X, modes, meanX)
        Fy <- sweep(Fy, modes, meanFy)
    } else {
        meanX <- meanFy <- NA
    }

    ### Step 1: Initial values
    # (data is already centered if requested, hence `center = FALSE`)
    ls.fit <- kpir.ls(X, Fy, sample.axis = sample.axis, center = FALSE,
        max.iter = max.init.iter, eps = eps, logger = logger)
    alphas <- ls.fit$alphas
    Deltas <- ls.fit$Deltas
    # compute residuals
    R <- X - mlm(Fy, alphas, modes)
    # Compute covariance inverses
    Delta.invs <- Map(solve, Deltas)
    # multiply Delta inverses with alphas (cache of `Delta_j^-1 alpha_j`)
    Delta.inv.alphas <- Map(`%*%`, Delta.invs, alphas)

    ### Step 2: Iterative Updating
    # Number of completed iterations, tracked separately from the loop variable
    # such that the final logger call is well defined for `max.iter = 0`
    iter.done <- 0L
    for (iter in seq_len(max.iter)) {
        # Invoke logger for previous iterate
        if (is.function(logger)) {
            logger("mle", iter - 1L, alphas, Deltas)
        }

        # random order cyclic updating
        for (j in sample(2 * r)) {
            # toggle between updating alpha j <= r and Delta j > r
            if (j <= r) {
                # Update `alpha_j`
                XxDi <- mlm(X, Delta.invs[-j], modes[-j])
                Fxa <- mlm(Fy, alphas[-j], modes[-j])
                FxDia <- mlm(Fy, Delta.inv.alphas[-j], modes[-j])
                alphas[[j]] <- mcrossprod(XxDi, Fxa, modes[j]) %*%
                    solve(mcrossprod(FxDia, Fxa, modes[j]))

                # Keep the cached `Delta_j^-1 alpha_j` in sync with the new
                # `alpha_j`, otherwise a subsequent update of another alpha
                # uses a product computed from the previous `alpha_j`
                Delta.inv.alphas[[j]] <- Delta.invs[[j]] %*% alphas[[j]]

                # Recompute Residuals (with updated alpha)
                R <- X - mlm(Fy, alphas, modes)
            } else {
                j <- j - r # map from [r + 1; 2 r] to [1; r]

                # Update `Delta_j`
                Deltas[[j]] <- (p[j] / (n * prod(p))) *
                    mcrossprod(mlm(R, Delta.invs[-j], modes[-j]), R, modes[j])

                # Recompute `Delta_j^-1`
                Delta.invs[[j]] <- solve(Deltas[[j]])
                # as well as `Delta_j^-1 alpha_j`
                Delta.inv.alphas[[j]] <- Delta.invs[[j]] %*% alphas[[j]]
            }
        }

        iter.done <- iter

        # TODO: add some kind of break condition
    }

    # Before returning, call logger for the final iteration
    if (is.function(logger)) {
        logger("mle", iter.done, alphas, Deltas)
    }

    list(alphas = structure(alphas, names = as.character(modes)),
        Deltas = structure(Deltas, names = as.character(modes)),
        meanX = meanX, meanFy = meanFy)
}
|