wip: RMReg

This commit is contained in:
Daniel Kapla 2021-12-14 20:05:47 +01:00
parent 1531da7f19
commit 74f1ce6ecf
1 changed file with 35 additions and 9 deletions


@@ -12,6 +12,9 @@
 #' The least squares loss combined with \eqn{f(s) = \lambda \sum_i |s_i|}
 #' corresponds to the nuclear norm regularization problem.
 #'
+#' In case of \code{lambda = Inf} the maximum penalty \eqn{\lambda} is computed.
+#' In this case the return value is only this estimate, as a single value.
+#'
 #' @param X the signal data, either as a 3D tensor or a 2D matrix. In case of a
 #' 3D tensor the axes are assumed to be \eqn{n\times p\times q}, meaning the
 #' first dimension holds the observations while the second and third are the
@@ -20,8 +23,7 @@
 #' @param Z additional covariate vector (can be \code{NULL} if not required.
 #' For regression with intercept set \code{Z = rep(1, n)})
 #' @param y univariate response vector
-#' @param lambda penalty term (note: range between 0 and max. singular value
-#' of the least squares solution gives non-trivial results)
+#' @param lambda penalty term, if set to \code{Inf} the maximum penalty is computed
 #' @param loss loss function, part of the objective function
 #' @param grad.loss gradient with respect to \eqn{B} of the loss function
 #' (required, there is no support for numerical gradients)
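A quick usage sketch of the new lambda = Inf mode documented above. The data and the penalty grid are hypothetical, and it assumes the least squares defaults of RMReg apply:

    # Hypothetical data: n = 100 observations of 5 x 7 matrix-valued signals
    n <- 100
    X <- array(rnorm(n * 5 * 7), dim = c(n, 5, 7))
    Z <- rep(1, n)    # intercept, as described in the @param Z docs
    y <- rnorm(n)
    # With lambda = Inf only the maximum penalty is returned, as a single value
    lambda.max <- RMReg(X, Z, y, lambda = Inf)
    # Refit along a grid of non-trivial penalties below the computed maximum
    fits <- lapply(lambda.max * seq(0.1, 0.9, by = 0.2),
                   function(lambda) RMReg(X, Z, y, lambda = lambda))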
@@ -68,21 +70,30 @@ RMReg <- function(X, Z, y, lambda = 0,
     }

     ### Set initial values
-    # singular values of B1 (require only current point, not previous B0)
+    # Note on naming convention: a name ending with 1 is the current iterate
+    # while a name ending with 0 is the previous iterate.
+    # Init singular values of B1 (requires only the current point, not the previous B0)
     if (missing(B0)) {
         b1 <- rep(0, min(shape))
     } else {
         b1 <- La.svd(B0, 0, 0)$d
     }
+    # Init current to previous (start position)
     B1 <- B0
     a0 <- 0
     a1 <- 1
     loss1 <- loss(B1, beta, X, Z, y)
+    # Start without extrapolation, the Nesterov momentum is zero anyway
+    no.nesterov <- TRUE

     ### Repeat until convergence
     for (iter in 1:max.iter) {
-        # Extrapolation (Nesterov Momentum)
-        S <- B1 + ((a0 - 1) / a1) * (B1 - B0)
+        # Extrapolation with Nesterov momentum
+        if (no.nesterov) {
+            S <- B1
+        } else {
+            S <- B1 + ((a0 - 1) / a1) * (B1 - B0)
+        }

         # Solve for beta at extrapolation point
         if (!is.null(ZZiZ)) {
@@ -97,12 +108,18 @@ RMReg <- function(X, Z, y, lambda = 0,

         # (potential) next step with delta as stepsize for gradient update
         A <- S - delta * grad

-        if (lambda > 0) {
+        if (lambda == Inf) {
+            # Application of Corollary 1 (only the nuclear norm is supported)
+            # to estimate the maximum lambda. The first time this line is hit
+            # with lambda set to Inf, B is zero (the B0 param is ignored).
+            lambda.max <- max(La.svd(A, 0, 0)$d) / delta
+            return(lambda.max)
+        } else if (lambda > 0) {
             # SVD of (potential) next step
             svdA <- La.svd(A)
             # Get (potential) next penalized iterate (nuclear norm version only)
-            b.temp <- pmax(0, svdA$d - lambda)         # Singular values of B.temp
+            b.temp <- pmax(0, svdA$d - delta * lambda) # Singular values of B.temp
             B.temp <- svdA$u %*% (b.temp * svdA$vt)
         } else {
             # in case of no penalization (pure least squares solution)
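The penalized update in this hunk is the proximal operator of the nuclear norm: soft-thresholding of the singular values at threshold delta * lambda. A standalone sketch of that operator (prox_nuclear is a hypothetical helper, not part of this file):

    # Proximal operator of the nuclear norm at threshold t:
    #   argmin_B  0.5 * ||B - A||_F^2 + t * ||B||_*
    prox_nuclear <- function(A, t) {
        svdA <- La.svd(A)
        d <- pmax(0, svdA$d - t)    # soft-thresholded singular values
        svdA$u %*% (d * svdA$vt)
    }
    A <- matrix(rnorm(12), 3, 4)
    B <- prox_nuclear(A, t = 0.1)   # t corresponds to delta * lambda above

This also motivates the lambda.max rule in the new lambda == Inf branch: the update collapses to zero exactly when delta * lambda >= max(La.svd(A, 0, 0)$d), i.e. for lambda >= max(La.svd(A, 0, 0)$d) / delta.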
@@ -131,16 +148,24 @@ RMReg <- function(X, Z, y, lambda = 0,
         # After gradient update enforce descent (stop if not decreasing)
         loss.temp <- loss(B.temp, beta, X, Z, y)
         if (loss.temp + penalty(b.temp) <= loss1 + penalty(b1)) {
+            no.nesterov <- FALSE
             loss1 <- loss.temp
             B0 <- B1
             B1 <- B.temp
             b1 <- b.temp
+        } else if (!no.nesterov) {
+            # Retry without Nesterov extrapolation
+            no.nesterov <- TRUE
+            next
         } else {
             break
         }

-        # Stop if estimate is zero
-        if (all(b1 < .Machine$double.eps)) {
+        # If estimate is zero, stop the algorithm
+        if (all(b.temp < .Machine$double.eps)) {
+            loss1 <- loss.temp
+            B1 <- array(0, dim = shape)
+            b1 <- rep(0, min(shape))
             break
         }
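The accept/retry/stop logic above makes the iteration monotone: a momentum step is accepted only if it decreases the penalized objective, otherwise the same step is retried without extrapolation, and the loop stops only if even the plain step fails to descend. A self-contained scalar sketch of this control flow (the toy objective f and the a-sequence update are illustrative assumptions, not taken from this file):

    f <- function(x) (x - 2)^2                      # toy objective
    grad.step <- function(s) s - 0.1 * 2 * (s - 2)  # gradient step, stepsize 0.1
    x0 <- x1 <- 0; a0 <- 0; a1 <- 1; no.nesterov <- TRUE
    for (iter in 1:100) {
        s <- if (no.nesterov) x1 else x1 + ((a0 - 1) / a1) * (x1 - x0)
        x.temp <- grad.step(s)
        if (f(x.temp) <= f(x1)) {
            no.nesterov <- FALSE    # accept, momentum allowed next iteration
            x0 <- x1
            x1 <- x.temp
        } else if (!no.nesterov) {
            no.nesterov <- TRUE     # retry this iteration without momentum
            next
        } else {
            break                   # no descent even without momentum
        }
        a0 <- a1
        a1 <- (1 + sqrt(1 + 4 * a0^2)) / 2   # assumed FISTA momentum sequence
    }
    x1    # approaches the minimizer 2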
@@ -166,6 +191,7 @@ RMReg <- function(X, Z, y, lambda = 0,
         iter = iter,
         df = df,
         loss = loss1,
+        lambda = lambda,
         AIC = loss1 / var(y) + 2 * df,
         BIC = loss1 / var(y) + log(nrow(X)) * df,
         call = match.call() # invoking function call, collects params like lambda
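Since the returned list now records lambda alongside AIC and BIC, penalty selection over a grid reads directly off the fits. A short sketch reusing the hypothetical fits list from the first example:

    BICs <- sapply(fits, function(fit) fit$BIC)
    best <- fits[[which.min(BICs)]]
    best$lambda    # penalty of the BIC-optimal fit, now part of the return value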