180 lines
5.6 KiB
R
180 lines
5.6 KiB
R
#' Implementation of the CVE method as a Riemannian Conjugate Gradient method.
#'
#' Minimizes the loss evaluated by `grad` over `p x q` matrices from the
#' Stiefel manifold, where `p = ncol(X)` and `q = p - k`.
#'
#' Each attempt proceeds as follows:
#' * It starts from a random point drawn by `rStiefl`.
#' * It moves along the direction `Z = W V`, where `W` is a skew-symmetric
#'   matrix built from the Cayley transform matrix `A`.
#' * It retracts back onto the manifold with `retractStiefl`.
#' * It selects the step size by Armijo backtracking.
#'
#' The best result over all attempts is returned.
#'
#' @param X Data matrix; rows are samples, columns are data dimensions.
#' @param Y Responses, passed unchanged to `grad` (presumably one per row of
#'   `X` -- TODO confirm against `grad`).
#' @param k Dimension of the SDR space. The optimization runs over the
#'   complement dimension `q = ncol(X) - k`.
#' @param nObs Passed to `estimate.bandwidth` when `h` is not given.
#' @param h Bandwidth. If missing or not numeric, it is estimated by
#'   `estimate.bandwidth(X, k, nObs)`.
#' @param tau Initial step size. Every line search restarts from this value.
#' @param tol Break tolerance, internally scaled by `sqrt(2 * q)`.
#' @param rho Armijo sufficient-decrease constant.
#' @param slack Not referenced directly in this function body. It is visible
#'   to `grad` through the shared environment (TODO confirm whether it is used).
#' @param epochs Maximum number of optimization steps per attempt.
#' @param attempts Number of random restarts; must be at least 1.
#' @param max.linesearch.iter Maximum number of step-size trials (halvings)
#'   per epoch; must be at least 1. If no trial satisfies the Armijo
#'   condition, the last (smallest) step is taken anyway.
#' @param logger Optional function. It is called with this function's
#'   evaluation environment once per attempt before the first epoch (with
#'   `epoch = 0` and `error = NA`), and after every step.
#'
#' @return A list with components `loss` (best loss over all attempts),
#'   `V` (the corresponding `p x q` matrix), `B` (`null(V)`) and `h`
#'   (the bandwidth used).
#'
#' @references A Riemannian Conjugate Gradient Algorithm with Implicit Vector
#' Transport for Optimization on the Stiefel Manifold
#' @keywords internal
#' @export
cve_rcg <- function(X, Y, k,
                    nObs = sqrt(nrow(X)),
                    h = NULL,
                    tau = 1.0,
                    tol = 1e-4,
                    rho = 1e-4, # For Armijo condition.
                    slack = 0,
                    epochs = 50L,
                    attempts = 10L,
                    max.linesearch.iter = 20L,
                    logger = NULL
) {
  # Without at least one attempt there is no result to return. Without at
  # least one line-search trial no step `V.tau` would exist.
  stopifnot(attempts >= 1, max.linesearch.iter >= 1)

  # Give this local copy of `grad` the current evaluation environment. This
  # lets it read, and modify (e.g. `loss`), the local variables defined here.
  environment(grad) <- environment()

  # Get dimensions.
  n <- nrow(X) # Number of samples.
  p <- ncol(X) # Data dimensions.
  q <- p - k   # Complement dimension of the SDR space.

  # Save initial learning rate `tau`.
  tau.init <- tau
  # Adapt tolerance for the break condition. The projection distance
  # ||V V' - V.tau V.tau'||_F between rank-q projections is at most
  # sqrt(2 * q).
  tol <- sqrt(2 * q) * tol

  # Estimate bandwidth if not given.
  if (missing(h) || !is.numeric(h)) {
    h <- estimate.bandwidth(X, k, nObs)
  }

  # Compute persistent data.
  # Lookup indexes for symmetry, lower/upper triangular parts and
  # vectorization.
  pair.index <- elem.pairs(seq(n))
  i <- pair.index[1, ] # `i` indices of `(i, j)` pairs
  j <- pair.index[2, ] # `j` indices of `(i, j)` pairs
  # Index of vectorized matrix, for lower and upper triangular part.
  lower <- ((i - 1) * n) + j
  upper <- ((j - 1) * n) + i

  # Create all pairwise differences of rows of `X`.
  X_diff <- X[i, , drop = FALSE] - X[j, , drop = FALSE]
  # Identity matrix.
  I_p <- diag(1, p)

  # Init tracking of current best (across multiple attempts).
  V.best <- NULL
  loss.best <- Inf

  # Start loop for multiple attempts.
  for (attempt in seq_len(attempts)) {
    # Reset learning rate `tau`.
    tau <- tau.init

    # Sample a `(p, q)` dimensional matrix from the Stiefel manifold as
    # optimization start value.
    V <- rStiefl(p, q)

    # Initial loss and gradient. `loss` must exist in this frame before the
    # call, because `grad(..., loss.out = TRUE)` is expected to set it
    # through the shared environment.
    loss <- Inf
    G <- grad(X, Y, V, h, loss.out = TRUE, persistent = TRUE)
    # Set last loss (aka, loss after applying the step).
    loss.last <- loss

    # Cayley transform matrix `A` (skew-symmetric).
    A <- (G %*% t(V)) - (V %*% t(G))
    A.last <- A

    # Initial search direction: steepest descent.
    W <- -A
    Z <- W %*% V

    # Compute directional derivative.
    loss.prime <- sum(G * Z) # Tr(G^T Z)

    # Call logger with initial values before starting optimization.
    if (is.function(logger)) {
      epoch <- 0 # Set epoch count to 0 (only relevant for logging).
      error <- NA
      logger(environment())
    }

    ## Start optimization loop.
    for (epoch in seq_len(epochs)) {
      # New directional derivative.
      loss.prime <- sum(G * Z)

      # Backtracking line search, starting from the initial step size.
      tau <- tau.init
      for (iter in seq_len(max.linesearch.iter)) {
        V.tau <- retractStiefl(V + tau * Z)
        # Loss at position after a step.
        loss <- grad(X, Y, V.tau, h,
                     loss.only = TRUE, persistent = TRUE)
        # Stop as soon as the Armijo condition is fulfilled.
        if (loss <= loss.last + (rho * tau * loss.prime)) {
          break
        }
        # Reduce step-size and continue linesearch.
        tau <- tau / 2
      }

      # Distance between the projections before and after the step.
      error <- norm(V %*% t(V) - V.tau %*% t(V.tau), type = "F")

      # Perform step with found step-size.
      V <- V.tau
      loss.last <- loss

      # Call logger.
      if (is.function(logger)) {
        logger(environment())
      }

      # Check break condition.
      # Note: the division by `sqrt(2 * q)` is included in `tol`.
      if (error < tol) {
        break
      }

      # Compute gradient at new position.
      G <- grad(X, Y, V, h, persistent = TRUE)
      # Store last `A` for `beta` computation.
      A.last <- A
      # Cayley transform matrix `A`.
      A <- (G %*% t(V)) - (V %*% t(G))

      # Check second break condition.
      if (norm(A, type = "F") < tol) {
        break
      }

      # Directional derivative of the old direction `Z` at the new point.
      loss.prime <- sum(G * Z)

      # The conjugate term is used only while `loss.prime < 0`. Otherwise
      # restart with `beta = 0` (steepest descent). `beta` is the FR-PR
      # hybrid: `beta.PR` clamped to `[-beta.FR, beta.FR]`, including the
      # boundary cases.
      if (loss.prime < 0) {
        A.last.norm2 <- norm(A.last, type = "F")^2
        beta.FR <- norm(A, type = "F")^2 / A.last.norm2
        beta.PR <- sum(A * (A - A.last)) / A.last.norm2
        beta <- max(-beta.FR, min(beta.PR, beta.FR))
      } else {
        beta <- 0
      }

      # Update direction.
      W <- -A + beta * W
      Z <- W %*% V
    }

    # Check if current attempt improved previous ones.
    if (loss < loss.best) {
      loss.best <- loss
      V.best <- V
    }
  }

  list(
    loss = loss.best,
    V = V.best,
    B = null(V.best),
    h = h
  )
}
|