121 lines
3.1 KiB
R
121 lines
3.1 KiB
R
#' @rdname matProj
#' @export
projSym <- function(A) {
  # Orthogonal projection onto the symmetric matrices: the average of a
  # (square) matrix and its transpose.
  (A + t(A)) / 2
}
#' @rdname matProj
#' @export
projDiag <- function(A) {
  # Keep only the main diagonal of `A` and set every other entry to zero; the
  # result has the same dimensions as `A`.
  # `nrow`/`ncol` are passed explicitly because `diag(x)` with a single
  # length-one `x` builds an identity matrix of size `x` (so a 1 x 1 input
  # was mangled), and without them a non-square `A` would come back square.
  diag(diag(A), nrow = nrow(A), ncol = ncol(A))
}
#' @rdname matProj
#' @export
.projBand <- function(dims, low, high) {
  # Factory for the projection onto banded `dims[1] x dims[2]` matrices that
  # keep `low` sub-diagonals and `high` super-diagonals (`Inf` = unbounded).
  # Validate eagerly: a `NULL` or `NA` bandwidth would otherwise yield an
  # empty or `NA` mask and the error would only surface later, if at all.
  stopifnot(
    is.numeric(dims), length(dims) == 2L,
    is.numeric(low), length(low) == 1L, !is.na(low),
    is.numeric(high), length(high) == 1L, !is.na(high)
  )
  # Row index minus column index: positive below the diagonal, negative above.
  offset <- .row(dims) - .col(dims)
  mask <- (-high <= offset) & (offset <= low)
  function(A) A * mask
}
#' @rdname matProj
#' @export
.projSymBand <- function(dims, low, high) {
  # Factory for the projection onto symmetric banded (square) matrices.
  stopifnot(
    is.numeric(dims), length(dims) == 2L, dims[[1L]] == dims[[2L]],
    is.numeric(low), length(low) == 1L, !is.na(low),
    is.numeric(high), length(high) == 1L, !is.na(high)
  )
  # A symmetric matrix inside band(low, high) has equal lower and upper
  # bandwidth, so the target set is the symmetric band of width
  # min(low, high). A non-symmetric mask (low != high) would destroy the
  # symmetry of the output; with a symmetric mask, symmetrisation and masking
  # commute and their composition is the orthogonal projection onto the set.
  width <- min(low, high)
  mask <- abs(.row(dims) - .col(dims)) <= width
  function(A) 0.5 * (A + t(A)) * mask
}
#' @rdname matProj
#' @export
.projPSD <- function(sym = FALSE) {
  # Reconstruct a symmetric matrix from its eigen decomposition with all
  # negative eigenvalues replaced by zero: V diag(max(0, lambda)) V^T.
  clip_eigen <- function(A) {
    decomp <- eigen(A, symmetric = TRUE)
    V <- decomp$vectors
    V %*% (pmax(0, decomp$values) * t(V))
  }

  if (!sym) {
    # Symmetrise first; `eigen(symmetric = TRUE)` only reads the lower
    # triangle and would otherwise ignore the upper one.
    return(function(A) clip_eigen(0.5 * (A + t(A))))
  }
  # The caller asserts the input is already symmetric.
  clip_eigen
}
#' @rdname matProj
#' @export
.projRank <- function(rank) {
  force(rank)
  function(A) {
    # Truncated SVD: keep the `r` leading singular triplets, where `r` never
    # exceeds the number of rows or columns of `A`.
    r <- min(dim(A), rank)
    decomp <- La.svd(A, nu = r, nv = r)
    # Scaling the rows of V^T by d equals diag(d) %*% V^T.
    decomp$u %*% (decomp$d[seq_len(r)] * decomp$vt)
  }
}
#' @rdname matProj
#' @export
.projSymRank <- function(rank) {
  force(rank)
  function(A) {
    # Cap the requested rank by the matrix dimensions.
    r <- min(dim(A), rank)
    # Truncated SVD of the symmetric part of `A`.
    decomp <- La.svd(0.5 * (A + t(A)), nu = r, nv = r)
    decomp$u %*% (decomp$d[seq_len(r)] * decomp$vt)
  }
}
#' @rdname matProj
#' @export
projStiefel <- function(A) {
  # Orthonormal factor Q of the polar decomposition A = Q P, computed from
  # the SVD A = U D V^T as Q = U V^T. Compared to a QR decomposition, the
  # polar decomposition is unique (for full-rank A), which makes it stable.
  parts <- La.svd(A)
  Q <- parts$u %*% parts$vt
  Q
}
# .projKron <- function(dims) {
#     ... # TODO: Implement this!
# }
#' @rdname matProj
#' @export
.projMaskedMean <- function(mask) {
  force(mask)
  function(A) {
    # Zero matrix shaped like `A` whose masked entries all hold the mean of
    # `A` over the same mask.
    out <- matrix(0, nrow = nrow(A), ncol = ncol(A))
    out[mask] <- mean(A[mask])
    out
  }
}
#' Projections onto matrix manifolds
#'
#' `matProj()` returns a function that maps a matrix onto the requested set
#' of structured matrices. The helper functions documented on this page
#' implement the individual projections; `matProj()` selects and configures
#' one of them.
#'
#' @param manifold Name of the target set, matched case-insensitively. One of
#'   `"Identity"`, `"Sym"`, `"Diag"`, `"TriDiag"`, `"SymTriDiag"`, `"Band"`,
#'   `"SymBand"`, `"PSD"`, `"Rank"`, `"SymRank"` or `"Stiefel"`.
#' @param dims Dimensions `c(nrow, ncol)` of the matrices to be projected;
#'   required by the band-type projections (`"TriDiag"`, `"SymTriDiag"`,
#'   `"Band"`, `"SymBand"`).
#' @param low,high Number of sub-diagonals (`low`) and super-diagonals
#'   (`high`) used by the band projections.
#' @param sym For `"PSD"`: if `TRUE`, the input is assumed to be symmetric
#'   already and is not symmetrised before the eigen decomposition.
#' @param rank Target rank for `"Rank"` and `"SymRank"`.
#' @param A A matrix to project.
#' @param mask Index (e.g. a logical matrix) selecting the entries that
#'   `.projMaskedMean()` replaces by their mean; all other entries become 0.
#'
#' @return A function of a single matrix argument that returns the
#'   projection of that matrix. An unknown `manifold` is an error.
#'
#' @examples
#' p <- 5
#' q <- 4
#' A <- matrix(rnorm(p * q), p, q)
#'
#' # General Matrices
#' matProj("Diag")(A)
#' matProj("TriDiag", dim(A))(A)
#' matProj("Band", dim(A), low = 1, high = 2)(A)
#' matProj("Rank", rank = 2)(A)
#' matProj("Stiefel")(A)
#'
#' # Symmetric projections need square matrices
#' S <- matrix(rnorm(p^2), p)
#'
#' matProj("Sym")(S)
#' matProj("SymTriDiag", dim(S))(S)
#' matProj("SymBand", dim(S), low = 1, high = 2)(S)
#' matProj("PSD")(S)
#' matProj("SymRank", rank = 1)(S)
#'
#' @rdname matProj
#'
#' @export
matProj <- function(manifold, dims = NULL, low = NULL, high = NULL,
                    sym = FALSE, rank = NULL) {
  # `switch()` only evaluates the matching alternative, so factories are
  # called (and their arguments checked) only for the requested manifold.
  switch(tolower(manifold),
    identity = identity,
    sym = projSym,
    diag = projDiag,
    tridiag = .projBand(dims, 1L, 1L),
    symtridiag = .projSymBand(dims, 1L, 1L),
    band = .projBand(dims, low, high),
    symband = .projSymBand(dims, low, high),
    psd = .projPSD(sym),
    rank = .projRank(rank),
    symrank = .projSymRank(rank),
    stiefel = projStiefel,
    # Fail loudly instead of silently returning `NULL`, which would only
    # surface later as "attempt to apply non-function".
    stop(sprintf("Unknown manifold '%s'.", manifold), call. = FALSE)
  )
}
# #' Basis of ....
# mat.proj.basis <- function(manifold, dims = NULL, low = NULL, high = NULL, sym = FALSE, rank = NULL) ...