wip: RMReg
This commit is contained in:
parent
1531da7f19
commit
74f1ce6ecf
|
@ -12,6 +12,9 @@
|
||||||
#' The least squares loss combined with \eqn{f(s) = \lambda \sum_i |s_i|}
|
#' The least squares loss combined with \eqn{f(s) = \lambda \sum_i |s_i|}
|
||||||
#' corresponds to the nuclear norm regularization problem.
|
#' corresponds to the nuclear norm regularization problem.
|
||||||
#'
|
#'
|
||||||
|
#' In case of \code{lambda = Inf} the maximum penalty \eqn{\lambda} is computed.
|
||||||
|
#' In this case the return value is only estimate as a single value.
|
||||||
|
#'
|
||||||
#' @param X the singnal data ether as a 3D tensor or a 2D matrix. In case of a
|
#' @param X the singnal data ether as a 3D tensor or a 2D matrix. In case of a
|
||||||
#' 3D tensor the axis are assumed to be \eqn{n\times p\times q} meaning the
|
#' 3D tensor the axis are assumed to be \eqn{n\times p\times q} meaning the
|
||||||
#' first dimension are the observations while the second and third are the
|
#' first dimension are the observations while the second and third are the
|
||||||
|
@ -20,8 +23,7 @@
|
||||||
#' @param Z additional covariate vector (can be \code{NULL} if not required.
|
#' @param Z additional covariate vector (can be \code{NULL} if not required.
|
||||||
#' For regression with intercept set \code{Z = rep(1, n)})
|
#' For regression with intercept set \code{Z = rep(1, n)})
|
||||||
#' @param y univariate response vector
|
#' @param y univariate response vector
|
||||||
#' @param lambda penalty term (note: range between 0 and max. signular value
|
#' @param lambda penalty term, if set to \code{Inf}
|
||||||
#' of the least squares solution gives non-trivial results)
|
|
||||||
#' @param loss loss function, part of the objective function
|
#' @param loss loss function, part of the objective function
|
||||||
#' @param grad.loss gradient with respect to \eqn{B} of the loss function
|
#' @param grad.loss gradient with respect to \eqn{B} of the loss function
|
||||||
#' (required, there is no support for numerical gradients)
|
#' (required, there is no support for numerical gradients)
|
||||||
|
@ -68,21 +70,30 @@ RMReg <- function(X, Z, y, lambda = 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
### Set initial values
|
### Set initial values
|
||||||
# singular values of B1 (require only current point, not previous B0)
|
# Note: Naming convention; a name ending with 1 is the current iterate while
|
||||||
|
# names ending with 0 are the previous iterate value.
|
||||||
|
# Init singular values of B1 (require only current point, not previous B0)
|
||||||
if (missing(B0)) {
|
if (missing(B0)) {
|
||||||
b1 <- rep(0, min(shape))
|
b1 <- rep(0, min(shape))
|
||||||
} else {
|
} else {
|
||||||
b1 <- La.svd(B0, 0, 0)$d
|
b1 <- La.svd(B0, 0, 0)$d
|
||||||
}
|
}
|
||||||
|
# Init current to previous (start position)
|
||||||
B1 <- B0
|
B1 <- B0
|
||||||
a0 <- 0
|
a0 <- 0
|
||||||
a1 <- 1
|
a1 <- 1
|
||||||
loss1 <- loss(B1, beta, X, Z, y)
|
loss1 <- loss(B1, beta, X, Z, y)
|
||||||
|
|
||||||
|
# Start without, the nesterov momentum is zero anyway
|
||||||
|
no.nesterov <- TRUE
|
||||||
### Repeat untill convergence
|
### Repeat untill convergence
|
||||||
for (iter in 1:max.iter) {
|
for (iter in 1:max.iter) {
|
||||||
# Extrapolation (Nesterov Momentum)
|
# Extrapolation with Nesterov Momentum
|
||||||
S <- B1 + ((a0 - 1) / a1) * (B1 - B0)
|
if (no.nesterov) {
|
||||||
|
S <- B1
|
||||||
|
} else {
|
||||||
|
S <- B1 + ((a0 - 1) / a1) * (B1 - B0)
|
||||||
|
}
|
||||||
|
|
||||||
# Solve for beta at extrapolation point
|
# Solve for beta at extrapolation point
|
||||||
if (!is.null(ZZiZ)) {
|
if (!is.null(ZZiZ)) {
|
||||||
|
@ -97,12 +108,18 @@ RMReg <- function(X, Z, y, lambda = 0,
|
||||||
# (potential) next step with delta as stepsize for gradient update
|
# (potential) next step with delta as stepsize for gradient update
|
||||||
A <- S - delta * grad
|
A <- S - delta * grad
|
||||||
|
|
||||||
if (lambda > 0) {
|
if (lambda == Inf) {
|
||||||
|
# Application of Corollary 1 (only nuclear norm supported) to
|
||||||
|
# estimate maximum lambda. In this case (first time this line is
|
||||||
|
# hit when lambda set to Inf, then B is zero (ignore B0 param))
|
||||||
|
lambda.max <- max(La.svd(A, 0, 0)$d) / delta
|
||||||
|
return(lambda.max)
|
||||||
|
} else if (lambda > 0) {
|
||||||
# SVD of (potential) next step
|
# SVD of (potential) next step
|
||||||
svdA <- La.svd(A)
|
svdA <- La.svd(A)
|
||||||
|
|
||||||
# Get (potential) next penalized iterate (nuclear norm version only)
|
# Get (potential) next penalized iterate (nuclear norm version only)
|
||||||
b.temp <- pmax(0, svdA$d - lambda) # Singular values of B.temp
|
b.temp <- pmax(0, svdA$d - delta * lambda) # Singular values of B.temp
|
||||||
B.temp <- svdA$u %*% (b.temp * svdA$vt)
|
B.temp <- svdA$u %*% (b.temp * svdA$vt)
|
||||||
} else {
|
} else {
|
||||||
# in case of no penalization (pure least squares solution)
|
# in case of no penalization (pure least squares solution)
|
||||||
|
@ -131,16 +148,24 @@ RMReg <- function(X, Z, y, lambda = 0,
|
||||||
# After gradient update enforce descent (stop if not decreasing)
|
# After gradient update enforce descent (stop if not decreasing)
|
||||||
loss.temp <- loss(B.temp, beta, X, Z, y)
|
loss.temp <- loss(B.temp, beta, X, Z, y)
|
||||||
if (loss.temp + penalty(b.temp) <= loss1 + penalty(b1)) {
|
if (loss.temp + penalty(b.temp) <= loss1 + penalty(b1)) {
|
||||||
|
no.nesterov <- FALSE
|
||||||
loss1 <- loss.temp
|
loss1 <- loss.temp
|
||||||
B0 <- B1
|
B0 <- B1
|
||||||
B1 <- B.temp
|
B1 <- B.temp
|
||||||
b1 <- b.temp
|
b1 <- b.temp
|
||||||
|
} else if (!no.nesterov) {
|
||||||
|
# Retry without Nesterov extrapolation
|
||||||
|
no.nesterov <- TRUE
|
||||||
|
next
|
||||||
} else {
|
} else {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
# Stop if estimate is zero
|
# If estimate is zero, stop algorithm
|
||||||
if (all(b1 < .Machine$double.eps)) {
|
if (all(b.temp < .Machine$double.eps)) {
|
||||||
|
loss1 <- loss.temp
|
||||||
|
B1 <- array(0, dim = shape)
|
||||||
|
b1 <- rep(0, min(shape))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -166,6 +191,7 @@ RMReg <- function(X, Z, y, lambda = 0,
|
||||||
iter = iter,
|
iter = iter,
|
||||||
df = df,
|
df = df,
|
||||||
loss = loss1,
|
loss = loss1,
|
||||||
|
lambda = lambda,
|
||||||
AIC = loss1 / var(y) + 2 * df,
|
AIC = loss1 / var(y) + 2 * df,
|
||||||
BIC = loss1 / var(y) + log(nrow(X)) * df,
|
BIC = loss1 / var(y) + log(nrow(X)) * df,
|
||||||
call = match.call() # invocing function call, collects params like lambda
|
call = match.call() # invocing function call, collects params like lambda
|
||||||
|
|
Loading…
Reference in New Issue