From 74f1ce6ecfa8ac45d0d81f7b3f4659f05d18e6ec Mon Sep 17 00:00:00 2001
From: daniel
Date: Tue, 14 Dec 2021 20:05:47 +0100
Subject: [PATCH] wip: RMReg

---
 tensorPredictors/R/RMReg.R | 44 ++++++++++++++++++++++++++++++--------
 1 file changed, 35 insertions(+), 9 deletions(-)

diff --git a/tensorPredictors/R/RMReg.R b/tensorPredictors/R/RMReg.R
index 0429518..d1b40ee 100644
--- a/tensorPredictors/R/RMReg.R
+++ b/tensorPredictors/R/RMReg.R
@@ -12,6 +12,9 @@
 #' The least squares loss combined with \eqn{f(s) = \lambda \sum_i |s_i|}
 #' corresponds to the nuclear norm regularization problem.
 #'
+#' In case of \code{lambda = Inf} the maximum penalty \eqn{\lambda} is computed.
+#' In this case only this single estimated value is returned instead of a fit.
+#'
 #' @param X the singnal data ether as a 3D tensor or a 2D matrix. In case of a
 #'   3D tensor the axis are assumed to be \eqn{n\times p\times q} meaning the
 #'   first dimension are the observations while the second and third are the
@@ -20,8 +23,7 @@
 #' @param Z additional covariate vector (can be \code{NULL} if not required.
 #'   For regression with intercept set \code{Z = rep(1, n)})
 #' @param y univariate response vector
-#' @param lambda penalty term (note: range between 0 and max. signular value
-#'   of the least squares solution gives non-trivial results)
+#' @param lambda penalty term, if set to \code{Inf} the maximum penalty \eqn{\lambda} is estimated and returned
 #' @param loss loss function, part of the objective function
 #' @param grad.loss gradient with respect to \eqn{B} of the loss function
 #'   (required, there is no support for numerical gradients)
@@ -68,21 +70,30 @@ RMReg <- function(X, Z, y, lambda = 0,
     }

     ### Set initial values
-    # singular values of B1 (require only current point, not previous B0)
+    # Note: Naming convention; a name ending in 1 refers to the current iterate
+    # while a name ending in 0 refers to the previous iterate.
+    # Init singular values of B1 (require only current point, not previous B0)
     if (missing(B0)) {
         b1 <- rep(0, min(shape))
     } else {
         b1 <- La.svd(B0, 0, 0)$d
     }
+    # Init current to previous (start position)
     B1 <- B0
     a0 <- 0
     a1 <- 1
     loss1 <- loss(B1, beta, X, Z, y)
+    # Start without extrapolation, the Nesterov momentum is zero anyway
+    no.nesterov <- TRUE

     ### Repeat untill convergence
     for (iter in 1:max.iter) {
-        # Extrapolation (Nesterov Momentum)
-        S <- B1 + ((a0 - 1) / a1) * (B1 - B0)
+        # Extrapolation with Nesterov Momentum
+        if (no.nesterov) {
+            S <- B1
+        } else {
+            S <- B1 + ((a0 - 1) / a1) * (B1 - B0)
+        }

         # Solve for beta at extrapolation point
         if (!is.null(ZZiZ)) {
@@ -97,12 +108,18 @@ RMReg <- function(X, Z, y, lambda = 0,
         # (potential) next step with delta as stepsize for gradient update
         A <- S - delta * grad
-        if (lambda > 0) {
+        if (lambda == Inf) {
+            # Application of Corollary 1 (only nuclear norm supported) to
+            # estimate the maximum lambda. When lambda is set to Inf this is
+            # reached on the first iteration, where B is zero (B0 is ignored)
+            lambda.max <- max(La.svd(A, 0, 0)$d) / delta
+            return(lambda.max)
+        } else if (lambda > 0) {
             # SVD of (potential) next step
             svdA <- La.svd(A)
             # Get (potential) next penalized iterate (nuclear norm version only)
-            b.temp <- pmax(0, svdA$d - lambda)          # Singular values of B.temp
+            b.temp <- pmax(0, svdA$d - delta * lambda)  # Singular values of B.temp
             B.temp <- svdA$u %*% (b.temp * svdA$vt)
         } else {
             # in case of no penalization (pure least squares solution)
@@ -131,16 +148,24 @@ RMReg <- function(X, Z, y, lambda = 0,
         # After gradient update enforce descent (stop if not decreasing)
         loss.temp <- loss(B.temp, beta, X, Z, y)
         if (loss.temp + penalty(b.temp) <= loss1 + penalty(b1)) {
+            no.nesterov <- FALSE
             loss1 <- loss.temp
             B0 <- B1
             B1 <- B.temp
             b1 <- b.temp
+        } else if (!no.nesterov) {
+            # Retry without Nesterov extrapolation
+            no.nesterov <- TRUE
+            next
         } else {
             break
         }

-        # Stop if estimate is zero
-        if (all(b1 < .Machine$double.eps)) {
+        # If estimate is zero, stop the algorithm
+        if (all(b.temp < .Machine$double.eps)) {
+            loss1 <- loss.temp
+            B1 <- array(0, dim = shape)
+            b1 <- rep(0, min(shape))
             break
         }

@@ -166,6 +191,7 @@ RMReg <- function(X, Z, y, lambda = 0,
         iter = iter,
         df = df,
         loss = loss1,
+        lambda = lambda,
         AIC = loss1 / var(y) + 2 * df,
         BIC = loss1 / var(y) + log(nrow(X)) * df,
         call = match.call() # invocing function call, collects params like lambda
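
Reviewer note (not part of the patch): a minimal usage sketch of the new
lambda = Inf behaviour introduced above. The simulated data, the penalty grid,
and the assumption that the remaining RMReg arguments (loss, grad.loss, B0,
...) fall back to a least-squares default are illustrative only and are not
taken from the patch.

    library(tensorPredictors)

    set.seed(1)
    n <- 100; p <- 5; q <- 7
    X <- array(rnorm(n * p * q), dim = c(n, p, q))   # signal tensor, n x p x q
    Z <- rep(1, n)                                   # intercept-only covariate
    B <- outer(seq(-1, 1, length.out = p),
               seq(-1, 1, length.out = q))           # rank-1 true coefficient
    y <- apply(X, 1, function(x) sum(x * B)) + rnorm(n, sd = 0.1)

    # With lambda = Inf the new branch returns only the estimated maximum
    # penalty (a single number) instead of a fitted model.
    # (Assumes loss / grad.loss default to the least squares loss of the docs.)
    lambda.max <- RMReg(X, Z, y, lambda = Inf)

    # Penalties strictly below lambda.max should give non-trivial fits; the
    # returned BIC can then be compared across the grid.
    fits <- lapply(lambda.max * seq(0.9, 0.1, by = -0.2),
                   function(lambda) RMReg(X, Z, y, lambda = lambda))
    sapply(fits, `[[`, "BIC")

Since the patch also adds the used lambda to the return list, each fit in such
a grid search carries the penalty that produced its reported BIC.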