tensor_predictors/tensorPredictors/R/gmlm_tensor_normal.R

148 lines
5.3 KiB
R

#' Specialized version of GMLM for the tensor normal model
#'
#' The underlying algorithm is an ``iterative (block) coordinate descent''
#' method. Every iteration updates all `betas` (in random order) given the
#' current mode-wise covariances and precisions, recomputes the residuals and
#' then re-estimates the mode-wise covariances `Sigmas` and precisions
#' `Omegas` from these residuals.
#'
#' @param X array whose axis `sample.axis` enumerates the observations.
#' @param F array with the same number of axes as `X` and the observations on
#'  the same axis (the same axis permutation is applied to `X` and `F`).
#' @param sample.axis index of the axis enumerating the observations, defaults
#'  to the last axis of `X`.
#' @param max.iter maximum number of (block) coordinate descent iterations.
#' @param proj.betas,proj.Omegas optional lists with one entry per mode. An
#'  entry which is a function is applied to the corresponding `beta_j`
#'  (`Omega_j`) after its update and its result replaces the estimate, any
#'  other entry (e.g. `NULL`) is ignored. If the argument itself is not a list,
#'  no projection is performed.
#' @param logger optional function called once after initialization
#'  (`iter = 0L`) and once after every iteration with the named arguments
#'  `iter`, `betas`, `Omegas`, `resid` and `loss`.
#' @param cond.threshold a mode-wise covariance estimate whose largest
#'  eigenvalue exceeds `cond.threshold` times its smallest eigenvalue gets
#'  regularized by adding `0.2` times the largest eigenvalue to its diagonal.
#' @param eps relative tolerance of the stopping rule; the iteration stops as
#'  soon as `abs(loss.last - loss) < eps * abs(loss.last)`.
#'
#' @return list with the elements `eta1` (computed as `mlm(meanX, Sigmas)`),
#'  `betas` and `Omegas` carrying the initial betas as attribute `betas.init`.
#'
#' @export
gmlm_tensor_normal <- function(X, F, sample.axis = length(dim(X)),
    max.iter = 100L, proj.betas = NULL, proj.Omegas = NULL, logger = NULL,
    cond.threshold = 25, eps = 1e-6
) {
    # rearrange `X`, `F` such that the last axis enumerates observations
    if (!missing(sample.axis)) {
        axis.perm <- c(seq_along(dim(X))[-sample.axis], sample.axis)
        X <- aperm(X, axis.perm)
        F <- aperm(F, axis.perm)
        sample.axis <- length(dim(X))
    }

    # Get problem dimensions (observation dimensions)
    dimX <- head(dim(X), -1)
    dimF <- head(dim(F), -1)
    modes <- seq_along(dimX)

    # Ensure the Omega and beta projections are lists with one slot per mode.
    # Note: `rep(NULL, n)` is just `NULL`, `vector("list", n)` is the list of
    # `n` `NULL` entries which is actually intended here.
    if (!is.list(proj.Omegas)) {
        proj.Omegas <- vector("list", length(modes))
    }
    if (!is.list(proj.betas)) {
        proj.betas <- vector("list", length(modes))
    }

    # Loss derived from the log-likelihood (scaled and constants dropped).
    # The log-determinant term is a numerically more stable version of
    # `sum(log(mapply(det, Omegas)) / dimX)` which is itself equivalent to
    # `log(det(Omega)) / prod(nrow(Omega))` where
    # `Omega <- Reduce(kronecker, rev(Omegas))`.
    compute.loss <- function(R, Omegas) {
        log.dets <- vapply(Omegas, function(Omega) {
            sum(log(eigen(Omega, symmetric = TRUE, only.values = TRUE)$values))
        }, numeric(1))
        mean(R * mlm(R, Omegas)) - sum(log.dets / dimX)
    }

    ### Phase 1: Computing initial values (`dim<-` ensures we have an "array")
    meanX <- `dim<-`(rowMeans(X, dims = length(dimX)), dimX)
    meanF <- `dim<-`(rowMeans(F, dims = length(dimF)), dimF)

    # center X and F (the means are recycled along the trailing sample axis)
    X <- X - as.vector(meanX)
    F <- F - as.vector(meanF)

    # initialize the mode-wise covariances and precisions as identity matrices
    Sigmas <- Map(diag, dimX)
    Omegas <- Map(diag, dimX)

    # Per mode covariance directions
    # Note: (the directions are transposed!)
    cov.directions <- function(Sigma) {
        SVD <- La.svd(Sigma, nu = 0)
        sqrt(SVD$d) * SVD$vt
    }
    dirsX <- Map(cov.directions, mcov(X, sample.axis, center = FALSE))
    dirsF <- Map(cov.directions, mcov(F, sample.axis, center = FALSE))

    # initialization of betas ``covariance direction mappings``
    betas <- betas.init <- Map(function(dX, dF) {
        s <- min(ncol(dX), nrow(dF))
        idx <- seq_len(s)
        crossprod(dX[idx, , drop = FALSE], dF[idx, , drop = FALSE])
    }, dirsX, dirsF)

    # Residuals
    R <- X - mlm(F, Map(`%*%`, Sigmas, betas))

    # Initial value of the loss
    loss <- compute.loss(R, Omegas)

    # invoke the logger
    if (is.function(logger)) do.call(logger, list(
        iter = 0L, betas = betas, Omegas = Omegas,
        resid = R, loss = loss
    ))

    ### Phase 2: (Block) Coordinate Descent
    for (iter in seq_len(max.iter)) {

        # update every beta (in random order)
        for (j in sample.int(length(betas))) {
            FxB_j <- mlm(F, betas[-j], modes[-j])
            FxSB_j <- mlm(FxB_j, Sigmas[-j], modes[-j])
            betas[[j]] <- Omegas[[j]] %*% t(solve(
                mcrossprod(FxSB_j, FxB_j, j), mcrossprod(FxB_j, X, j)
            ))
            # Project `betas` onto their manifold
            if (is.function(proj_j <- proj.betas[[j]])) {
                betas[[j]] <- proj_j(betas[[j]])
            }
        }

        # Residuals
        R <- X - mlm(F, Map(`%*%`, Sigmas, betas))

        # Covariance Estimates
        Sigmas <- mcov(R, sample.axis, center = FALSE)

        # Computing `Omega_j`s, the j'th mode precision matrices, in
        # conjunction with regularization of the j'th mode covariance estimate
        # `Sigma_j`
        for (j in seq_along(Sigmas)) {
            # Compute min and max eigen values
            min_max <- range(eigen(Sigmas[[j]], TRUE, TRUE)$values)
            # The condition is approximately
            # `kappa(Sigmas[[j]]) > cond.threshold`
            if (min_max[2] > cond.threshold * min_max[1]) {
                Sigmas[[j]] <- Sigmas[[j]] +
                    diag(0.2 * min_max[2], nrow(Sigmas[[j]]))
            }
            # Compute (unconstrained but regularized) Omega_j as covariance
            # inverse
            Omegas[[j]] <- solve(Sigmas[[j]])
            # Project Omega_j to the Omega_j's manifold
            if (is.function(proj_j <- proj.Omegas[[j]])) {
                Omegas[[j]] <- proj_j(Omegas[[j]])
                # Reverse computation of `Sigma_j` as inverse of `Omega_j`
                # Order of projecting Omega_j and then recomputing Sigma_j is
                # important
                Sigmas[[j]] <- solve(Omegas[[j]])
            }
        }

        # store last loss and compute the new one
        loss.last <- loss
        loss <- compute.loss(R, Omegas)

        # invoke the logger
        if (is.function(logger)) do.call(logger, list(
            iter = iter, betas = betas, Omegas = Omegas, resid = R, loss = loss
        ))

        # check the break condition (relative change of the loss)
        if (abs(loss.last - loss) < eps * abs(loss.last)) {
            break
        }
    }

    # NOTE(review): `eta1` is computed with `Sigmas`, not `Omegas` -- confirm
    # this matches the intended parameterization.
    structure(
        list(eta1 = mlm(meanX, Sigmas), betas = betas, Omegas = Omegas),
        betas.init = betas.init
    )
}