2019-08-30 19:16:52 +00:00
|
|
|
#' Simple implementation of the CVE method. 'Simple' means that this method is
#' a classic gradient descent (GD) method using no further tricks.
#'
#' Each step moves `V` along a Cayley transform built from the gradient
#' returned by `grad`. A step is rejected (and the step size `tau` halved) when
#' it increases the loss by more than `slack * loss.last`. The optimization is
#' restarted `attempts` times from random starting values (`rStiefl`) and the
#' attempt with the smallest final loss is kept.
#'
#' @param X data matrix; \code{nrow(X)} is the number of observations and
#'   \code{ncol(X)} the dimension \code{p}.
#' @param Y responses, passed through to \code{grad}.
#' @param k integer; the optimized matrix \code{V} has \code{q = p - k}
#'   columns.
#' @param nObs passed to \code{estimate.bandwidth} when \code{h} is estimated.
#' @param h bandwidth; estimated via \code{estimate.bandwidth(X, k, nObs)} if
#'   missing or not numeric.
#' @param tau initial step size (learning rate); reset for every attempt and
#'   halved after every rejected step.
#' @param tol convergence tolerance for the Frobenius distance between the
#'   projections \code{V V^T} before and after a step; internally scaled by
#'   \code{sqrt(2 * q)}.
#' @param slack relative loss increase tolerated before a step is rejected.
#' @param epochs maximum number of iterations per attempt.
#' @param attempts number of random restarts; must be at least 1.
#'
#' @return A list with components \code{loss.history} (an
#'   \code{epochs} x \code{attempts} matrix initialised to \code{NA}, presumably
#'   filled by \code{grad} when called with \code{loss.log = TRUE}),
#'   \code{loss} (final loss of the best attempt), \code{V} (its \code{V}),
#'   \code{B} (\code{null(V)} of the best attempt) and \code{h} (the bandwidth
#'   used).
#'
#' @keywords internal
#' @export
cve_simple <- function(X, Y, k,
                       nObs = sqrt(nrow(X)),
                       h = NULL,
                       tau = 1.0,
                       tol = 1e-3,
                       slack = 0,
                       epochs = 50L,
                       attempts = 10L
) {
    # At least one attempt is required, otherwise `V.best` stays NULL.
    stopifnot(attempts >= 1L)

    # Give this call a local copy of `grad` whose enclosing environment is the
    # current call frame. `grad` can then find (and via `<<-` modify) local
    # variables of this function, e.g. `loss` when called with
    # `loss.out = TRUE` (and presumably `loss.history` with `loss.log = TRUE`).
    environment(grad) <- environment()

    # Setup loss history (must exist before `grad` is first called).
    loss.history <- matrix(NA, epochs, attempts)

    # Get dimensions.
    n <- nrow(X)
    p <- ncol(X)
    q <- p - k

    # Save initial learning rate `tau`.
    tau.init <- tau
    # Adapt tolerance for the break condition: the Frobenius distance of two
    # rank-`q` orthogonal projections is at most `sqrt(2 * q)`.
    tol <- sqrt(2 * q) * tol

    # Estimate bandwidth if not given.
    if (missing(h) || !is.numeric(h)) {
        h <- estimate.bandwidth(X, k, nObs)
    }

    I_p <- diag(1, p)

    # Track the current best result over multiple attempts.
    V.best <- NULL
    loss.best <- Inf

    for (attempt in seq_len(attempts)) {
        # Reset step width `tau`.
        tau <- tau.init

        # Sample a `(p, q)` dimensional matrix from the Stiefel manifold as
        # optimization start value.
        V <- rStiefl(p, q)

        # Initial loss and gradient. `loss` must exist in this frame so that
        # `grad(..., loss.out = TRUE)` can set it.
        loss <- Inf
        G <- grad(X, Y, V, h, loss.out = TRUE) # `loss.out = TRUE` sets `loss`!
        # Loss at the current position `V`; kept in sync with `V` throughout.
        loss.last <- loss

        # Cayley transform matrix `A = G V^T - V G^T`.
        A <- tcrossprod(G, V) - tcrossprod(V, G)

        for (epoch in seq_len(epochs)) {
            # Apply learning rate `tau`.
            A.tau <- tau * A
            # Cayley step on the Stiefel manifold into direction of `G`;
            # solving the linear system avoids forming an explicit inverse.
            V.tau <- solve(I_p + A.tau, (I_p - A.tau) %*% V)

            # Loss at position after a step.
            loss <- grad(X, Y, V.tau, h, loss.only = TRUE)

            # Reject the step if the loss increased by more than allowed.
            if ((loss - loss.last) > slack * loss.last) {
                tau <- tau / 2
                next # Keep position and retry with smaller `tau`.
            }

            # Distance between the projections before and after the step.
            # Note: the division by `sqrt(2 * q)` is included in `tol`.
            error <- norm(tcrossprod(V) - tcrossprod(V.tau), type = "F")

            # Accept the step; `loss.last` must always describe `V`, also when
            # stopping right after this step.
            V <- V.tau
            loss.last <- loss

            # Stop on convergence; on the last epoch also stop here to skip
            # a gradient computation whose result would be ignored.
            if (error < tol || epoch >= epochs) {
                break
            }

            # Compute gradient at new position.
            # Note: `loss` will be updated too!
            G <- grad(X, Y, V, h, loss.out = TRUE, loss.log = TRUE)

            # Cayley transform matrix `A`.
            A <- tcrossprod(G, V) - tcrossprod(V, G)
        }

        # Compare using the loss at the final `V` (not the loss of a possibly
        # rejected last step).
        if (loss.last < loss.best) {
            loss.best <- loss.last
            V.best <- V
        }
    }

    list(
        loss.history = loss.history,
        loss = loss.best,
        V = V.best,
        B = null(V.best),
        h = h
    )
}
|