tensor_predictors/tensorPredictors/R/mat_mani_projections.R

121 lines
3.1 KiB
R

# Orthogonal projection onto the symmetric matrices: returns the symmetric
# part (A + A^T) / 2 of the square matrix `A`.
#' @rdname matProj
#' @export
projSym <- function(A) {
  (A + t(A)) / 2
}
# Orthogonal projection onto the diagonal matrices of the same shape as `A`:
# the main diagonal of `A` is kept, every other entry is set to zero.
#
# `nrow`/`ncol` are passed explicitly because `diag(x)` alone interprets its
# argument by length. A length-one `x` is taken as a size, so a 1x1 input
# `matrix(k)` would become a k x k identity. A non-square p x q input would
# otherwise collapse to a square min(p, q) x min(p, q) matrix.
#' @rdname matProj
#' @export
projDiag <- function(A) {
  diag(diag(A), nrow = nrow(A), ncol = ncol(A))
}
# Factory for the projection onto band matrices of shape `dims` that keep
# `low` sub-diagonals and `high` super-diagonals. The 0/1 band mask is built
# once here; the returned closure multiplies it into its argument.
#' @rdname matProj
#' @export
.projBand <- function(dims, low, high) {
  # Signed distance of every entry from the main diagonal (row - col):
  # positive below the diagonal, negative above it.
  offset <- .row(dims) - .col(dims)
  in_band <- (offset <= low) & (-high <= offset)
  function(A) {
    A * in_band
  }
}
# Factory for the projection onto symmetric band matrices of (square) shape
# `dims`. The input is replaced by its symmetric part, then every entry
# outside the band is zeroed.
#
# A symmetric matrix that vanishes outside the (low, high) band also vanishes
# outside the symmetric band of width min(low, high). The mask must
# therefore be that symmetric band. Masking the symmetric part with an
# asymmetric band (low != high) would produce a non-symmetric result, and
# the map would not be idempotent, i.e. not a projection.
#' @rdname matProj
#' @export
.projSymBand <- function(dims, low, high) {
  stopifnot(
    is.numeric(low), length(low) == 1L,
    is.numeric(high), length(high) == 1L
  )
  width <- min(low, high)
  mask <- abs(.row(dims) - .col(dims)) <= width
  function(A) projSym(A) * mask
}
# Factory for the projection onto positive semi-definite matrices. The input
# is first replaced by its symmetric part. With `sym = TRUE` this step is
# skipped and the caller vouches that `A` is already symmetric; per the
# `eigen()` docs only its lower triangle is then used. The eigenvalues of
# that symmetric matrix are clipped at zero and the matrix is reassembled.
#' @rdname matProj
#' @export
.projPSD <- function(sym = FALSE) {
  symmetrize <- if (sym) identity else function(A) 0.5 * (A + t(A))
  function(A) {
    decomp <- eigen(symmetrize(A), symmetric = TRUE)
    V <- decomp$vectors
    # V diag(pmax(0, lambda)) V^T, computed by scaling the rows of t(V)
    # instead of forming the diagonal matrix.
    V %*% (pmax(0, decomp$values) * t(V))
  }
}
# Factory for the projection onto matrices of rank at most `rank`, via the
# truncated SVD. The rank is capped at min(dim(A)); `rank = NULL` therefore
# keeps every singular value.
#' @rdname matProj
#' @export
.projRank <- function(rank) {
  force(rank)
  function(A) {
    k <- min(dim(A), rank)
    s <- La.svd(A, nu = k, nv = k)
    # U_k diag(d_k) V_k^T, computed by scaling the rows of V_k^T.
    s$u %*% (s$d[seq_len(k)] * s$vt)
  }
}
# Factory for the rank-restricted projection of the symmetric part of `A`.
# It computes the truncated SVD of (A + A^T) / 2, keeping at most `rank`
# singular values. The rank is capped at min(dim(A)); `NULL` keeps all.
#' @rdname matProj
#' @export
.projSymRank <- function(rank) {
  force(rank)
  function(A) {
    k <- min(dim(A), rank)
    S <- 0.5 * (A + t(A))
    s <- La.svd(S, nu = k, nv = k)
    # U_k diag(d_k) V_k^T, computed by scaling the rows of V_k^T.
    s$u %*% (s$d[seq_len(k)] * s$vt)
  }
}
# Projection onto the Stiefel manifold: the result has orthonormal columns,
# or orthonormal rows when `A` is wide. It returns the orthogonal factor Q
# of the polar decomposition A = Q P, computed from the thin SVD
# A = U D V^T as Q = U V^T. Compared to a QR decomposition, the polar factor
# is unique (for A of full rank), which makes this projection stable.
#' @rdname matProj
#' @export
projStiefel <- function(A) {
  with(La.svd(A), u %*% vt)
}
# .projKron <- function(dims) {
# ... # TODO: Implement this!
# }
# Factory for the projection onto matrices that are constant on `mask` and
# zero elsewhere. Every entry selected by `mask` becomes the mean of A over
# those entries; all other entries are set to 0. `mask` is used as the index
# in `A[mask]`, presumably a logical matrix shaped like A (confirm with
# callers).
#' @rdname matProj
#' @export
.projMaskedMean <- function(mask) {
  force(mask)
  function(A) {
    out <- matrix(0, nrow = nrow(A), ncol = ncol(A))
    out[mask] <- mean(A[mask])
    out
  }
}
#' Projections onto matrix manifolds
#'
#' `matProj()` returns a function mapping a matrix onto the selected matrix
#' manifold. The helpers `projSym()`, `projDiag()` and `projStiefel()`
#' project a matrix directly. The factories `.projBand()`, `.projSymBand()`,
#' `.projPSD()`, `.projRank()`, `.projSymRank()` and `.projMaskedMean()` build
#' the projection functions used by `matProj()`.
#'
#' @param manifold name of the target manifold (case insensitive), one of
#'  `"identity"`, `"sym"`, `"diag"`, `"tridiag"`, `"symtridiag"`, `"band"`,
#'  `"symband"`, `"psd"`, `"rank"`, `"symrank"` or `"stiefel"`.
#' @param dims dimensions of the matrices to be projected; required for the
#'  `"tridiag"`, `"symtridiag"`, `"band"` and `"symband"` manifolds.
#' @param low,high number of sub- and super-diagonals kept by the band
#'  projections.
#' @param sym for `"psd"`: if `TRUE`, the input is assumed to be symmetric and
#'  is not symmetrized before the eigen decomposition.
#' @param rank maximal rank for `"rank"` and `"symrank"`; `NULL` imposes no
#'  rank restriction.
#' @param A a matrix to be projected.
#' @param mask index (e.g. a logical matrix) of the entries averaged by
#'  `.projMaskedMean()`.
#'
#' @return `matProj()` and the `.proj*()` factories return a function of a
#'  single matrix argument which computes its projection. `projSym()`,
#'  `projDiag()` and `projStiefel()` return the projected matrix. `matProj()`
#'  signals an error for an unknown `manifold`.
#'
#' @examples
#' p <- 5
#' q <- 4
#' A <- matrix(rnorm(p * q), p, q)
#'
#' # General Matrices
#' matProj("Diag")(A)
#' matProj("TriDiag", dim(A))(A)
#' matProj("Band", dim(A), low = 1, high = 2)(A)
#' matProj("Rank", rank = 2)(A)
#' matProj("Stiefel")(A)
#'
#' # Symmetric projections need square matrices
#' S <- matrix(rnorm(p^2), p)
#'
#' matProj("Sym")(S)
#' matProj("SymTriDiag", dim(S))(S)
#' matProj("SymBand", dim(S), low = 1, high = 2)(S)
#' matProj("PSD")(S)
#' matProj("SymRank", rank = 1)(S)
#'
#' @rdname matProj
#'
#' @export
matProj <- function(manifold, dims = NULL, low = NULL, high = NULL,
                    sym = FALSE, rank = NULL) {
  # Only the matched branch is evaluated, so e.g. `dims` is required only
  # for the band manifolds.
  switch(tolower(manifold),
    identity   = identity,
    sym        = projSym,
    diag       = projDiag,
    tridiag    = .projBand(dims, 1L, 1L),
    symtridiag = .projSymBand(dims, 1L, 1L),
    band       = .projBand(dims, low, high),
    symband    = .projSymBand(dims, low, high),
    psd        = .projPSD(sym),
    rank       = .projRank(rank),
    symrank    = .projSymRank(rank),
    stiefel    = projStiefel,
    # Fail loudly instead of silently returning NULL for a typo.
    stop(sprintf("Unknown manifold '%s'.", manifold), call. = FALSE)
  )
}
# #' Basis of ....
# mat.proj.basis <- function(manifold, dims = NULL, low = NULL, high = NULL, sym = FALSE, rank = NULL) ...