#' Specialized version of GMLM for the tensor normal model
#'
#' After centering `X` and `F`, the conditional mean of `X` given `F` is
#' modeled as the multi-linear product `mlm(F, Sigma_j %*% beta_j)` (over all
#' modes `j`), with mode-wise covariances `Sigma_j` and their inverses
#' (precisions) `Omega_j`.
#'
#' The underlying algorithm is an ``iterative (block) coordinate descent''
#' method: every sweep updates each `beta_j` (in random order) while all other
#' parameters are held fixed, then re-estimates the mode-wise covariances from
#' the residuals.
#'
#' @param X observation array; the axis `sample.axis` enumerates observations.
#' @param F predictor array with the same number of axes as `X`; it is
#'   permuted with the same axis permutation as `X`.
#' @param sample.axis index of the observation axis of `X` and `F`.
#' @param max.iter maximum number of coordinate descent sweeps.
#' @param proj.betas optional list with one entry per (non-sample) mode of `X`;
#'   entries that are functions project the corresponding `beta_j` after each
#'   update. Anything other than a list disables all beta projections.
#' @param proj.Omegas like `proj.betas`, but applied to the `Omega_j`s.
#' @param logger optional function, called after initialization (`iter = 0`)
#'   and after every sweep with the named arguments `iter`, `betas`, `Omegas`,
#'   `resid` and `loss`.
#' @param cond.threshold if the ratio of the largest to the smallest eigenvalue
#'   of a mode covariance estimate exceeds this value, `0.2` times its largest
#'   eigenvalue is added to its diagonal (regularization).
#' @param eps relative tolerance on the change of the loss (break condition).
#'
#' @return list with components `eta1`, `betas` and `Omegas`; the attribute
#'   `betas.init` holds the initial betas. `eta1` is chosen such that
#'   `mlm(eta1 + mlm(F, betas), Sigmas)` reproduces the fitted conditional mean
#'   for the original (uncentered) `F`.
#'
#' @export
gmlm_tensor_normal <- function(X, F, sample.axis = length(dim(X)),
    max.iter = 100L, proj.betas = NULL, proj.Omegas = NULL, logger = NULL,
    cond.threshold = 25, eps = 1e-6
) {
    # rearrange `X`, `F` such that the last axis enumerates observations
    if (!missing(sample.axis)) {
        axis.perm <- c(seq_along(dim(X))[-sample.axis], sample.axis)
        X <- aperm(X, axis.perm)
        F <- aperm(F, axis.perm)
        sample.axis <- length(dim(X))
    }

    # Get problem dimensions (observation dimensions)
    dimX <- head(dim(X), -1)
    dimF <- head(dim(F), -1)
    modes <- seq_along(dimX)

    # Ensure the Omega and beta projection arguments are lists (one entry per
    # mode; `NULL` entries mean "no projection")
    if (!is.list(proj.Omegas)) {
        proj.Omegas <- vector("list", length(modes))
    }
    if (!is.list(proj.betas)) {
        proj.betas <- vector("list", length(modes))
    }

    ### Local helpers

    # Scaled principal directions of a covariance matrix.
    # Note: the directions are transposed, row `k` is `sqrt(d_k) * t(v_k)`
    cov.dirs <- function(Sigma) {
        SVD <- La.svd(Sigma, nu = 0)
        sqrt(SVD$d) * SVD$vt
    }

    # Residuals of the (centered) model `X ~ mlm(F, Sigma_j %*% beta_j)`
    resid.of <- function(Sigmas, betas) {
        X - mlm(F, Map(`%*%`, Sigmas, betas))
    }

    # Negative log-likelihood (scaled, constants dropped).
    # The log-determinant term is a numerically more stable version of
    # `sum(log(mapply(det, Omegas)) / dimX)`, which equals
    # `log(det(Omega)) / nrow(Omega)` for `Omega <- Reduce(kronecker, rev(Omegas))`.
    loss.of <- function(R, Omegas) {
        log.dets <- vapply(Omegas, function(Omega) {
            sum(log(eigen(Omega, symmetric = TRUE, only.values = TRUE)$values))
        }, numeric(1))
        mean(R * mlm(R, Omegas)) - sum(log.dets / dimX)
    }

    # Invoke the (optional) user supplied logger
    log.state <- function(iter, betas, Omegas, R, loss) {
        if (is.function(logger)) {
            do.call(logger, list(
                iter = iter, betas = betas, Omegas = Omegas, resid = R,
                loss = loss
            ))
        }
    }

    ### Phase 1: Computing initial values (`dim<-` ensures we have an "array")
    meanX <- `dim<-`(rowMeans(X, dims = length(dimX)), dimX)
    meanF <- `dim<-`(rowMeans(F, dims = length(dimF)), dimF)

    # center X and F
    X <- X - as.vector(meanX)
    F <- F - as.vector(meanF)

    # initialize the mode-wise covariances and precisions as identities
    Sigmas <- Map(diag, dimX)
    Omegas <- Map(diag, dimX)

    # Per mode (transposed) covariance directions
    dirsX <- Map(cov.dirs, mcov(X, sample.axis, center = FALSE))
    dirsF <- Map(cov.dirs, mcov(F, sample.axis, center = FALSE))

    # initialization of betas as ``covariance direction mappings``
    betas <- betas.init <- Map(function(dX, dF) {
        s <- seq_len(min(ncol(dX), nrow(dF)))
        crossprod(dX[s, , drop = FALSE], dF[s, , drop = FALSE])
    }, dirsX, dirsF)

    # Initial residuals and loss
    R <- resid.of(Sigmas, betas)
    loss <- loss.of(R, Omegas)
    log.state(0L, betas, Omegas, R, loss)

    ### Phase 2: (Block) Coordinate Descent
    for (iter in seq_len(max.iter)) {
        # update every beta (in random order)
        for (j in sample.int(length(betas))) {
            FxB_j <- mlm(F, betas[-j], modes[-j])
            FxSB_j <- mlm(FxB_j, Sigmas[-j], modes[-j])
            # Closed form minimizer in `beta_j` with all other parameters fixed
            betas[[j]] <- Omegas[[j]] %*% t(solve(
                mcrossprod(FxSB_j, FxB_j, j),
                mcrossprod(FxB_j, X, j)
            ))
            # Project `beta_j` onto its manifold
            if (is.function(proj_j <- proj.betas[[j]])) {
                betas[[j]] <- proj_j(betas[[j]])
            }
        }

        # Residuals
        R <- resid.of(Sigmas, betas)

        # Covariance Estimates
        Sigmas <- mcov(R, sample.axis, center = FALSE)

        # Compute `Omega_j`, the j'th mode precision matrices, in conjunction
        # with regularization of the j'th mode covariance estimate `Sigma_j`
        for (j in seq_along(Sigmas)) {
            # Compute min and max eigenvalues
            min_max <- range(eigen(Sigmas[[j]], symmetric = TRUE,
                                   only.values = TRUE)$values)
            # The condition is approximately `kappa(Sigmas[[j]]) > cond.threshold`
            if (min_max[2] > cond.threshold * min_max[1]) {
                Sigmas[[j]] <- Sigmas[[j]] +
                    diag(0.2 * min_max[2], nrow(Sigmas[[j]]))
            }
            # Compute (unconstrained but regularized) Omega_j as covariance inverse
            Omegas[[j]] <- solve(Sigmas[[j]])
            # Project Omega_j onto its manifold
            if (is.function(proj_j <- proj.Omegas[[j]])) {
                Omegas[[j]] <- proj_j(Omegas[[j]])
                # Recompute `Sigma_j` as the inverse of the projected `Omega_j`.
                # The order (project Omega_j first, then recompute Sigma_j)
                # is important.
                Sigmas[[j]] <- solve(Omegas[[j]])
            }
        }

        # store last loss and compute the new one
        loss.last <- loss
        loss <- loss.of(R, Omegas)

        log.state(iter, betas, Omegas, R, loss)

        # check the break condition (relative change of the loss)
        if (abs(loss.last - loss) < eps * abs(loss.last)) {
            break
        }
    }

    # Intercept in natural parameter space. The fitted mean is
    #   meanX + mlm(F - meanF, Sigma_j %*% beta_j) = mlm(eta1 + mlm(F, betas), Sigmas)
    # which gives `eta1 = mlm(meanX, Omegas) - mlm(meanF, betas)`. The original
    # code used `mlm(meanX, Sigmas)`, i.e. the covariances instead of the
    # precisions, and ignored the internal centering of `F`.
    eta1 <- mlm(meanX, Omegas) - mlm(meanF, betas)

    structure(
        list(eta1 = eta1, betas = betas, Omegas = Omegas),
        betas.init = betas.init
    )
}