add: extended dist.subspace to support list of kronecker product components as well as needing less memory,

add: if for cond.threshold if non-finite, allows to disable regularization,
fix: TSIR's mode-wise sample covariance scaling factor
This commit is contained in:
Daniel Kapla 2025-05-05 14:34:26 +02:00
parent e2e8d19a0a
commit c4a93d25fa
2 changed files with 37 additions and 29 deletions

View File

@ -1,7 +1,7 @@
#' Subspace distance #' Subspace distance
#' #'
#' @param A,B Basis matrices as representations of elements of the Grassmann #' @param As,Bs Basis matrices or list of Kronecker components of basis matrices
#' manifold. #' as representations of elements of the Grassmann manifold.
#' @param is.ortho Boolean to specify if \eqn{A} and \eqn{B} are semi-orthogonal. #' @param is.ortho Boolean to specify if \eqn{A} and \eqn{B} are semi-orthogonal.
#' If false, the projection matrices are computed as #' If false, the projection matrices are computed as
#' \deqn{P_A = A (A' A)^{-1} A'} #' \deqn{P_A = A (A' A)^{-1} A'}
@ -14,36 +14,41 @@
#' subspaces of different dimensions" <arXiv:1407.0900> #' subspaces of different dimensions" <arXiv:1407.0900>
#' #'
#' @export #' @export
dist.subspace <- function (A, B, is.ortho = FALSE, normalize = FALSE, dist.subspace <- function (As, Bs, is.ortho = FALSE, normalize = FALSE,
tol = sqrt(.Machine$double.eps) tol = sqrt(.Machine$double.eps)
) { ) {
if (!is.matrix(A)) A <- as.matrix(A) As <- if (is.list(As)) Map(as.matrix, As) else list(as.matrix(As))
if (!is.matrix(B)) B <- as.matrix(B) Bs <- if (is.list(Bs)) Map(as.matrix, Bs) else list(as.matrix(Bs))
if (!is.ortho) { if (!is.ortho) {
As <- Map(function(A) {
qrA <- qr(A, tol) qrA <- qr(A, tol)
if (qrA$rank < ncol(A)) { if (qrA$rank < ncol(A)) {
A <- qr.Q(qrA)[, abs(diag(qr.R(qrA))) > tol, drop = FALSE] qr.Q(qrA)[, abs(diag(qr.R(qrA))) > tol, drop = FALSE]
} else { } else {
A <- qr.Q(qrA) qr.Q(qrA)
} }
}, As)
Bs <- Map(function(B) {
qrB <- qr(B, tol) qrB <- qr(B, tol)
if (qrB$rank < ncol(B)) { if (qrB$rank < ncol(B)) {
B <- qr.Q(qrB)[, abs(diag(qr.R(qrB))) > tol, drop = FALSE] qr.Q(qrB)[, abs(diag(qr.R(qrB))) > tol, drop = FALSE]
} else { } else {
B <- qr.Q(qrB) qr.Q(qrB)
} }
}, Bs)
} }
PA <- tcrossprod(A, A) rankA <- prod(sapply(As, ncol))
PB <- tcrossprod(B, B) rankB <- prod(sapply(Bs, ncol))
if (normalize) { c <- if (normalize) {
rankSum <- ncol(A) + ncol(B) rankSum <- rankA + rankB
c <- 1 / sqrt(max(1, min(rankSum, 2 * nrow(A) - rankSum))) sqrt(1 / max(1, min(rankSum, 2 * prod(sapply(As, nrow)) - rankSum)))
} else { } else {
c <- 1 1
} }
c * norm(PA - PB, type = "F") s <- prod(mapply(function(A, B) sum(crossprod(A, B)^2), As, Bs))
c * sqrt(max(0, rankA + rankB - 2 * s))
} }

View File

@ -105,12 +105,15 @@ gmlm_tensor_normal <- function(X, F, sample.axis = length(dim(X)),
# Computing `Omega_j`s, the j'th mode presition matrices, in conjunction # Computing `Omega_j`s, the j'th mode presition matrices, in conjunction
# with regularization of the j'th mode covariance estimate `Sigma_j` # with regularization of the j'th mode covariance estimate `Sigma_j`
for (j in seq_along(Sigmas)) { for (j in seq_along(Sigmas)) {
# Regularize Covariances
if (is.finite(cond.threshold)) {
# Compute min and max eigen values # Compute min and max eigen values
min_max <- range(eigen(Sigmas[[j]], TRUE, TRUE)$values) min_max <- range(eigen(Sigmas[[j]], TRUE, TRUE)$values)
# The condition is approximately `kappa(Sigmas[[j]]) > cond.threshold` # The condition is approximately `kappa(Sigmas[[j]]) > cond.threshold`
if (min_max[2] > cond.threshold * min_max[1]) { if (min_max[2] > cond.threshold * min_max[1]) {
Sigmas[[j]] <- Sigmas[[j]] + diag(0.2 * min_max[2], nrow(Sigmas[[j]])) Sigmas[[j]] <- Sigmas[[j]] + diag(0.2 * min_max[2], nrow(Sigmas[[j]]))
} }
}
# Compute (unconstraint but regularized) Omega_j as covariance inverse # Compute (unconstraint but regularized) Omega_j as covariance inverse
Omegas[[j]] <- solve(Sigmas[[j]]) Omegas[[j]] <- solve(Sigmas[[j]])
# Project Omega_j to the Omega_j's manifold # Project Omega_j to the Omega_j's manifold