212 lines
7.5 KiB
R
212 lines
7.5 KiB
R
|
#' Gradient Descent Based Tensor Predictors method with Nesterov Accelerated
#' Momentum and Kronecker structure assumption for the residual covariance
#' `Delta = Delta.1 %x% Delta.2` (simple plugin version!)
#'
#' The residuals are `X - (Fy %x_3% alpha %x_2% beta)`. Starting values for
#' `alpha` and `beta` come from an (approximate) least squares fit of the
#' vectorized problem followed by a Kronecker product decomposition. They are
#' then refined by gradient steps with a backtracking (Armijo type) line search
#' and Nesterov extrapolation, falling back to a plain gradient step whenever
#' the extrapolated step does not decrease the loss.
#'
#' The covariance factors are plug-in estimates computed from the current
#' residuals: `Delta.1` (`q x q`) from the mode-3 and `Delta.2` (`p x p`) from
#' the mode-2 residual cross products, both divided by
#' `sqrt(n * sum(diag(Delta.1)))`.
#'
#' Relies on the helpers `mat`, `%x_2%`, `%x_3%`, `approx.kronecker`, `matpow`
#' and `%<-%` defined elsewhere (presumably mode-wise matricization, mode-wise
#' tensor-times-matrix products, a Kronecker product approximation, a
#' generalized matrix power and a destructuring assignment -- confirm against
#' their definitions).
#'
#' @param X predictors, either a `n x (p q)` matrix or a `n x p x q` 3-tensor
#'  with observations along the first dimension.
#' @param Fy response functional, either a vector of length `n`, a
#'  `n x (k r)` matrix or a `n x k x r` 3-tensor.
#' @param shape dimensions `c(p, q, k, r)`; required (must be supplied
#'  explicitly) whenever `X` or `Fy` is given as a matrix.
#' @param max.iter maximum number of (outer) gradient iterations.
#' @param max.line.iter maximum number of step sizes tried per line search.
#' @param step.size initial step size of the line search.
#' @param nesterov.scaling function `function(a, t)` giving the next momentum
#'  scaling from the current scaling `a` and the iteration number `t`.
#' @param eps tolerance used in all break conditions.
#' @param logger optional callback invoked as
#'  `logger(iter, loss, alpha, beta, Delta.1, Delta.2, step.size)` once before
#'  the first iteration (`iter = 0L`, `step.size = NA`) and after every line
#'  search with the last evaluated candidate.
#'
#' @return list with the final `loss` (negative log-likelihood as evaluated
#'  by this function), the estimates `alpha` (`q x r`) and `beta` (`p x k`),
#'  the covariance estimate `Delta = Delta.1 %x% Delta.2` together with its
#'  factors `Delta.1` and `Delta.2` (all belonging to the returned `alpha`,
#'  `beta`) and `break.reason` (`NA` if `max.iter` was reached).
#'
#' @export
kpir.kron <- function(X, Fy, shape = c(dim(X)[-1], dim(Fy)[-1]),
    max.iter = 500L, max.line.iter = 50L, step.size = 1e-3,
    nesterov.scaling = function(a, t) { 0.5 * (1 + sqrt(1 + (2 * a)^2)) },
    eps = .Machine$double.eps,
    logger = NULL
) {

    # Check if X and Fy have same number of observations
    stopifnot(nrow(X) == NROW(Fy))
    # The line search needs at least one candidate step size
    stopifnot(max.line.iter >= 1L)
    n <- nrow(X)                                # Number of observations

    # Get and check predictor dimensions (convert to 3-tensor if needed)
    if (length(dim(X)) == 2L) {
        stopifnot(!missing(shape))
        stopifnot(ncol(X) == prod(shape[1:2]))
        p <- as.integer(shape[1])               # Predictor "height"
        q <- as.integer(shape[2])               # Predictor "width"
    } else if (length(dim(X)) == 3L) {
        p <- dim(X)[2]
        q <- dim(X)[3]
    } else {
        stop("'X' must be a matrix or 3-tensor")
    }

    # Get and check response dimensions (and convert to 3-tensor if needed)
    if (!is.array(Fy)) {
        Fy <- as.array(Fy)
    }
    if (length(dim(Fy)) == 1L) {
        k <- r <- 1L
        dim(Fy) <- c(n, 1L, 1L)
    } else if (length(dim(Fy)) == 2L) {
        stopifnot(!missing(shape))
        stopifnot(ncol(Fy) == prod(shape[3:4]))
        k <- as.integer(shape[3])               # Response functional "height"
        r <- as.integer(shape[4])               # Response functional "width"
    } else if (length(dim(Fy)) == 3L) {
        k <- dim(Fy)[2]
        r <- dim(Fy)[3]
    } else {
        stop("'Fy' must be a vector, matrix or 3-tensor")
    }


    ### Step 1: (Approx) Least Squares solution for `X = Fy B' + epsilon`
    # Vectorize
    dim(Fy) <- c(n, k * r)
    dim(X) <- c(n, p * q)
    # Solve
    cpFy <- crossprod(Fy)                       # TODO: Check/Test and/or replace
    if (n <= k * r || qr(cpFy)$rank < k * r) {
        # In case of under-determined system replace the inverse in the normal
        # equation by the Moore-Penrose Pseudo Inverse
        B <- t(matpow(cpFy, -1) %*% crossprod(Fy, X))
    } else {
        # Compute OLS estimate by the Normal Equation
        B <- t(solve(cpFy, crossprod(Fy, X)))
    }

    # De-Vectorize (from now on tensor arithmetics)
    dim(Fy) <- c(n, k, r)
    dim(X) <- c(n, p, q)

    # Local helper: residuals, plug-in covariance factors and transformed
    # residuals at the parameters `(alpha, beta)`.
    fit.at <- function(alpha, beta) {
        resid <- X - (Fy %x_3% alpha %x_2% beta)
        Delta.1 <- tcrossprod(mat(resid, 3))
        Delta.2 <- tcrossprod(mat(resid, 2))
        # both factors share the same normalization
        scaling <- sqrt(n * sum(diag(Delta.1)))
        Delta.1 <- Delta.1 / scaling
        Delta.2 <- Delta.2 / scaling
        list(
            resid = resid,
            resid.trans = resid %x_3% solve(Delta.1) %x_2% solve(Delta.2),
            Delta.1 = Delta.1,
            Delta.2 = Delta.2
        )
    }
    # Local helper: negative log-likelihood of a `fit.at` result.
    neg.log.lik <- function(fit) {
        0.5 * (n * (p * log(det(fit$Delta.1)) + q * log(det(fit$Delta.2))) +
            sum(fit$resid.trans * fit$resid))
    }

    # Decompose `B = alpha x beta` into `alpha` and `beta`
    c(alpha0, beta0) %<-% approx.kronecker(B, c(q, r), c(p, k))

    # Residuals, covariance estimates and loss of the starting values.
    # `Delta.1`, `Delta.2` and `loss` always belong to the last accepted
    # iterate `(alpha1, beta1)`.
    start <- fit.at(alpha0, beta0)
    Delta.1 <- start$Delta.1
    Delta.2 <- start$Delta.2
    loss <- neg.log.lik(start)

    # Call history callback (logger) before the first iterate
    if (is.function(logger)) {
        logger(0L, loss, alpha0, beta0, Delta.1, Delta.2, NA)
    }


    ### Step 2: MLE with LS solution as starting value
    a0 <- 0
    a1 <- 1
    alpha1 <- alpha0
    beta1 <- beta0

    # main descent loop
    no.nesterov <- TRUE
    break.reason <- NA
    for (iter in seq_len(max.iter)) {
        if (no.nesterov) {
            # without extrapolation as fallback
            S.alpha <- alpha1
            S.beta <- beta1
        } else {
            # extrapolation using previous direction
            S.alpha <- alpha1 + ((a0 - 1) / a1) * (alpha1 - alpha0)
            S.beta  <- beta1  + ((a0 - 1) / a1) * ( beta1 -  beta0)
        }

        # Transformed residuals at the extrapolated point
        extra <- fit.at(S.alpha, S.beta)

        # Calculate Gradients (evaluated at the extrapolated point)
        grad.alpha <- tcrossprod(mat(extra$resid.trans, 3),
                                 mat(Fy %x_2% S.beta, 3))
        grad.beta <- tcrossprod(mat(extra$resid.trans, 2),
                                mat(Fy %x_3% S.alpha, 2))

        # Backtracking line search (Armijo type)
        # The `inner.prod` is used in the Armijo break condition but does not
        # depend on the step size.
        inner.prod <- sum(grad.alpha^2) + sum(grad.beta^2)

        # backtracking loop
        step.sizes <- step.size *
            0.618034^seq.int(0L, length.out = max.line.iter)
        for (delta in step.sizes) {
            # Update `alpha` and `beta` (note: add(+), the gradients are
            # already pointing into the negative slope direction of the loss
            # cause they are the gradients of the log-likelihood [NOT the
            # negative log-likelihood])
            alpha.temp <- S.alpha + delta * grad.alpha
            beta.temp  <- S.beta  + delta * grad.beta

            # Update Residuals, Covariance and transformed Residuals and
            # evaluate the negative log-likelihood
            trial <- fit.at(alpha.temp, beta.temp)
            loss.temp <- neg.log.lik(trial)

            # Armijo line search break condition
            if (loss.temp <= loss - 0.1 * delta * inner.prod) {
                break
            }
        }

        # Call logger (invoke history callback)
        if (is.function(logger)) {
            logger(iter, loss.temp, alpha.temp, beta.temp,
                   trial$Delta.1, trial$Delta.2, delta)
        }

        # Ensure descent
        if (loss.temp < loss) {
            # accept the candidate (parameters, covariances and loss together
            # such that the returned values are consistent on early exit)
            alpha0 <- alpha1
            alpha1 <- alpha.temp
            beta0 <- beta1
            beta1 <- beta.temp
            Delta.1 <- trial$Delta.1
            Delta.2 <- trial$Delta.2
            decrease <- loss - loss.temp
            loss <- loss.temp

            # check break conditions (in descent case)
            if (mean(abs(alpha1)) + mean(abs(beta1)) < eps) {
                break.reason <- "alpha, beta numerically zero"
                break   # basically, estimates are zero -> stop
            }
            if (inner.prod < eps * (p * q + r * k)) {
                break.reason <- "mean squared gradient is smaller than epsilon"
                break   # mean squared gradient is smaller than epsilon -> stop
            }
            if (decrease < eps) {
                break.reason <- "decrease is too small (slow)"
                break   # decrease is too small (slow) -> stop
            }

            no.nesterov <- FALSE    # always reset
        } else if (!no.nesterov) {
            no.nesterov <- TRUE     # retry without momentum
            next
        } else {
            break.reason <- "failed even without momentum"
            break       # failed even without momentum -> stop
        }

        # update momentum scaling
        a0 <- a1
        a1 <- nesterov.scaling(a1, iter)

        # Set next iter starting step.size to line searched step size
        # (while allowing it to increase)
        step.size <- 1.618034 * delta
    }

    list(
        loss = loss, alpha = alpha1, beta = beta1,
        Delta = Delta.1 %x% Delta.2, Delta.1 = Delta.1, Delta.2 = Delta.2,
        break.reason = break.reason
    )
}
|