#' Specialized version of GMLM for the tensor normal model
#'
#' The underlying algorithm is an ``iterative (block) coordinate descent''
#' method: the `betas` and the mode-wise precision matrices `Omegas` are
#' updated alternately until the relative change in the loss drops below
#' `eps` or `max.iter` iterations are reached.
#'
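#' @param X numeric (tensor valued) data array; one axis, by default the
#'   last, enumerates the observations (see `sample.axis`)
#' @param F numeric factor array sharing the sample axis with `X`
#' @param sample.axis index of the axis of `X` and `F` enumerating the
#'   observations
#' @param max.iter maximum number of (block) coordinate descent iterations
#' @param proj.betas,proj.Omegas optional lists of per-mode projection
#'   functions constraining the `beta` and `Omega` estimates to a manifold
#' @param logger optional callback invoked after every iteration with the
#'   arguments `iter`, `betas`, `Omegas`, `resid` and `loss`
#' @param cond.threshold condition number threshold above which a mode-wise
#'   covariance estimate is regularized before inversion
#' @param eps relative tolerance on the change of the loss, used as the
#'   break condition
#'
#' @examples
#' \dontrun{
#' # Minimal usage sketch: the dimensions and the random data are purely
#' # illustrative assumptions, not part of the interface.
#' n <- 100                                   # number of observations
#' X <- array(rnorm(n * 4 * 5), c(4, 5, n))   # 4 x 5 tensor observations
#' F <- array(rnorm(n * 2 * 3), c(2, 3, n))   # 2 x 3 factor tensors
#' fit <- gmlm_tensor_normal(X, F)
#' str(fit$betas)    # per-mode coefficients: 4 x 2 and 5 x 3 matrices
#' str(fit$Omegas)   # per-mode precision matrices: 4 x 4 and 5 x 5
#' }
#'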
#' @export
gmlm_tensor_normal <- function(X, F, sample.axis = length(dim(X)),
    max.iter = 100L, proj.betas = NULL, proj.Omegas = NULL, logger = NULL,
    cond.threshold = 25, eps = 1e-6
) {
    # rearrange `X`, `F` such that the last axis enumerates observations
    if (!missing(sample.axis)) {
        axis.perm <- c(seq_along(dim(X))[-sample.axis], sample.axis)
        X <- aperm(X, axis.perm)
        F <- aperm(F, axis.perm)
        sample.axis <- length(dim(X))
    }

    # Get problem dimensions (observation dimensions)
    dimX <- head(dim(X), -1)
    dimF <- head(dim(F), -1)
    modes <- seq_along(dimX)

    # Ensure the Omega and beta projections are lists of per-mode entries.
    # Note: `rep(NULL, n)` yields `NULL`, not a list, hence
    # `vector("list", n)` is used to create a list of `n` `NULL` elements.
    if (!is.list(proj.Omegas)) {
        proj.Omegas <- vector("list", length(modes))
    }
    if (!is.list(proj.betas)) {
        proj.betas <- vector("list", length(modes))
    }


    ### Phase 1: Computing initial values (`dim<-` ensures we have an "array")
    meanX <- `dim<-`(rowMeans(X, dims = length(dimX)), dimX)
    meanF <- `dim<-`(rowMeans(F, dims = length(dimF)), dimF)

    # center X and F
    X <- X - as.vector(meanX)
    F <- F - as.vector(meanF)

    # initialize the mode-wise scatter (`Sigmas`) and precision (`Omegas`)
    # estimates as identity matrices
    Sigmas <- Map(diag, dimX)
    Omegas <- Map(diag, dimX)

    # Per mode covariance directions
    # Note: (the directions are transposed!)
    dirsX <- Map(function(Sigma) {
        SVD <- La.svd(Sigma, nu = 0)
        sqrt(SVD$d) * SVD$vt
    }, mcov(X, sample.axis, center = FALSE))
    dirsF <- Map(function(Sigma) {
        SVD <- La.svd(Sigma, nu = 0)
        sqrt(SVD$d) * SVD$vt
    }, mcov(F, sample.axis, center = FALSE))
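    # (each of these `dirs` matrices is a square-root factor of the respective
    # mode-wise covariance estimate in the sense that `crossprod(dirs)`
    # recovers it)
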
    # initialization of the betas as ``covariance direction mappings''
    betas <- betas.init <- Map(function(dX, dF) {
        s <- min(ncol(dX), nrow(dF))
        crossprod(dX[1:s, , drop = FALSE], dF[1:s, , drop = FALSE])
    }, dirsX, dirsF)

    # Residuals
    R <- X - mlm(F, Map(`%*%`, Sigmas, betas))
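    # (the fitted conditional mean of `X` is `F` multiplied mode-wise by the
    # matrices `Sigmas[[j]] %*% betas[[j]]`)
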
    # Numerically more stable version of `sum(log(mapply(det, Omegas)) / dimX)`
    # which is itself equivalent to `log(det(Omega)) / prod(nrow(Omega))` where
    # `Omega <- Reduce(kronecker, rev(Omegas))`.
    det.Omega <- sum(mapply(function(Omega) {
        sum(log(eigen(Omega, TRUE, TRUE)$values))
    }, Omegas) / dimX)
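    # (this uses `det(A %x% B) == det(A)^nrow(B) * det(B)^nrow(A)`, applied
    # recursively, and sums log-eigenvalues instead of taking the log of a
    # possibly under- or overflowing determinant)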
    # Initial value of the loss, the negative log-likelihood up to scaling
    # and dropped constants
    loss <- mean(R * mlm(R, Omegas)) - det.Omega

    # invoke the logger
    if (is.function(logger)) do.call(logger, list(
        iter = 0L, betas = betas, Omegas = Omegas,
        resid = R, loss = loss
    ))


    ### Phase 2: (Block) Coordinate Descent
    for (iter in seq_len(max.iter)) {

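        # The mode-j update below solves the normal equations for `beta_j`
        # with all other parameters held fixed, mapped through the mode-j
        # precision `Omegas[[j]]` (one block of the coordinate descent).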
        # update every beta (in random order)
        for (j in sample.int(length(betas))) {
            FxB_j <- mlm(F, betas[-j], modes[-j])
            FxSB_j <- mlm(FxB_j, Sigmas[-j], modes[-j])
            betas[[j]] <- Omegas[[j]] %*% t(solve(
                mcrossprod(FxSB_j, FxB_j, j), mcrossprod(FxB_j, X, j)
            ))
            # Project `betas` onto their manifold
            if (is.function(proj_j <- proj.betas[[j]])) {
                betas[[j]] <- proj_j(betas[[j]])
            }
        }

        # Residuals
        R <- X - mlm(F, Map(`%*%`, Sigmas, betas))

        # Covariance Estimates
        Sigmas <- mcov(R, sample.axis, center = FALSE)

        # Computing the `Omega_j`s, the j'th mode precision matrices, in
        # conjunction with regularization of the j'th mode covariance
        # estimate `Sigma_j`
        for (j in seq_along(Sigmas)) {
            # Compute min and max eigenvalues
            min_max <- range(eigen(Sigmas[[j]], TRUE, TRUE)$values)
            # The condition is approximately `kappa(Sigmas[[j]]) > cond.threshold`
            if (min_max[2] > cond.threshold * min_max[1]) {
                Sigmas[[j]] <- Sigmas[[j]] + diag(0.2 * min_max[2], nrow(Sigmas[[j]]))
            }
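            # (the diagonal shift by `0.2 * min_max[2]` bounds the condition
            # number of the regularized `Sigma_j` by about 1.2 / 0.2 = 6 for a
            # positive semi-definite estimate)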
            # Compute (unconstrained but regularized) Omega_j as covariance inverse
            Omegas[[j]] <- solve(Sigmas[[j]])
            # Project Omega_j onto the Omega_j's manifold
            if (is.function(proj_j <- proj.Omegas[[j]])) {
                Omegas[[j]] <- proj_j(Omegas[[j]])
                # Reverse computation of `Sigma_j` as inverse of the projected
                # `Omega_j`; the order (first project `Omega_j`, then recompute
                # `Sigma_j`) is important
                Sigmas[[j]] <- solve(Omegas[[j]])
            }
        }

        # store last loss
        loss.last <- loss
        # numerically stable version of `sum(log(mapply(det, Omegas)) / dimX)`
        # (see the comment on the initial loss computation above)
        det.Omega <- sum(mapply(function(Omega) {
            sum(log(eigen(Omega, TRUE, TRUE)$values))
        }, Omegas) / dimX)
        # Compute new loss
        loss <- mean(R * mlm(R, Omegas)) - det.Omega

        # invoke the logger
        if (is.function(logger)) do.call(logger, list(
            iter = iter, betas = betas, Omegas = Omegas, resid = R, loss = loss
        ))

        # check the break condition (relative change in the loss below `eps`)
        if (abs(loss.last - loss) < eps * abs(loss.last)) {
            break
        }
    }


    structure(
        list(eta1 = mlm(meanX, Sigmas), betas = betas, Omegas = Omegas),
        betas.init = betas.init
    )
}