fix: RMReg

Daniel Kapla 2022-01-18 12:30:05 +01:00
parent 8da1950d02
commit 79f5e9781f
1 changed file with 99 additions and 106 deletions

@@ -7,10 +7,8 @@
 #' log-likelihood as the loss. \eqn{J(B) = f(\sigma(B))}, where \eqn{f} is a
 #' function of the singular values of \eqn{B}.
 #'
-#' The default parameterization is a nuclear norm penalized least squares regression.
-#'
-#' The least squares loss combined with \eqn{f(s) = \lambda \sum_i |s_i|}
-#' corresponds to the nuclear norm regularization problem.
+#' Currently, only the least squares problem with nuclear norm penalty is
+#' implemented.
 #'
 #' In case of \code{lambda = Inf} the maximum penalty \eqn{\lambda} is computed.
 #' In this case only the estimated maximum \eqn{\lambda} is returned as a single value.
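For reference, this is the objective the documentation describes, written out in R: the least squares loss plus the nuclear norm penalty \eqn{\lambda \sum_i \sigma_i(B)}. A minimal sketch only; the helper name objective is hypothetical and not part of the package.

# Sketch, not part of the commit: the penalized objective J described above.
objective <- function(B, beta, X, Z, y, lambda) {
    residuals <- y - Z %*% beta - X %*% c(B)                  # least squares part
    0.5 * sum(residuals^2) + lambda * sum(La.svd(B, 0, 0)$d)  # + nuclear norm
}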
@@ -24,39 +22,39 @@
 #' For regression with intercept set \code{Z = rep(1, n)}
 #' @param y univariate response vector
 #' @param lambda penalty term, if set to \code{Inf} max lambda is computed.
-#' @param loss loss function, part of the objective function
-#' @param penalty penalty function with a vector of the singular values if the
-#'  current iterate as arguments. The default function
-#'  \code{function(sigma) sum(sigma)} is the nuclear norm penalty.
-#' @param grad.loss gradient with respect to \eqn{B} of the loss function
-#'  (required, there is no support for numerical gradients)
+#' @param max.iter maximum number of gradient updates
+#' @param max.line.iter maximum number of line search iterations
 #' @param shape Shape of the matrix valued predictors. Required iff the
 #'  predictors \code{X} are provided in vectorized form, e.g. as a 2D matrix.
 #' @param step.size max. stepsize for gradient updates
-#' @param alpha iterative Nesterov momentum scaling values
 #' @param B0 initial value for optimization. Matrix of dimensions \eqn{p\times q}
-#' @param beta initial value of additional covatiates coefficient for \eqn{Z}
-#' @param max.iter maximum number of gadient updates
-#' @param max.line.iter maximum number of line search iterations
+#' @param beta0 initial value of additional covariates coefficient for \eqn{Z}
+#' @param alpha iterative Nesterov momentum scaling values
+#' @param eps precision for main loop break conditions
 #' @param logger logging callback invoked after every line search before break
 #'  condition checks. The expected function signature is of the form
-#'  \code{logger(iter, loss, penalty, grad, B, beta, step.size)}.
+#'  \code{function(iter, loss, penalty, B, beta, step.size)}.
 #'
 #' @export
-RMReg <- function(X, Z, y, lambda = 0,
-                  loss = function(B, beta, X, Z, y) 0.5 * sum((y - Z %*% beta - X %*% c(B))^2),
-                  penalty = function(sigma) sum(sigma),
-                  grad.loss = function(B, beta, X, Z, y) crossprod(X %*% c(B) + Z %*% beta - y, X),
-                  shape = dim(X)[-1],
-                  step.size = 1e-3,
-                  alpha = function(a, t) { (1 + sqrt(1 + (2 * a)^2)) / 2 },
+RMReg <- function(X, Z, y, lambda = 0, max.iter = 500L, max.line.iter = 50L,
+                  shape = dim(X)[-1], step.size = 1e-3,
                   B0 = array(0, dim = shape),
-                  beta = rep(0, NCOL(Z)),
-                  max.iter = 500,
-                  max.line.iter = ceiling(log(step.size / sqrt(.Machine$double.eps), 2)),
+                  beta0 = rep(0, NCOL(Z)),
+                  alpha = function(a, t) { (1 + sqrt(1 + (2 * a)^2)) / 2 },
+                  eps = .Machine$double.eps,
                   logger = NULL
 ) {
-    ### Check (prepair) params
+    # Define loss (without penalty)
+    loss <- function(B, beta, X, Z, y) 0.5 * sum((y - Z %*% beta - X %*% c(B))^2)
+    # gradient of loss (without penalty)
+    grad <- function(B, beta, X, Z, y) {
+        inner <- X %*% c(B) + Z %*% beta - y
+        list(beta = c(crossprod(inner, Z)), B = c(crossprod(inner, X)))
+    }
+    # # and the penalty function (as function of singular values)
+    # penalty <- function(sigma) sum(sigma)
+    # Check (prepare) params
     stopifnot(nrow(X) == length(y))
     if (!missing(shape)) {
         stopifnot(ncol(X) == prod(shape))
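A usage sketch of the revised signature on simulated data. Everything here (the dimensions, the rank-1 coefficient B.true, the choice of 0.1 * lambda.max) is made up for illustration; only RMReg and its parameters come from the code above.

# Sketch, not part of the commit: fit on simulated matrix-valued predictors.
n <- 100; p <- 5; q <- 7
X <- array(rnorm(n * p * q), dim = c(n, p, q))
B.true <- outer(rnorm(p), rnorm(q))                   # rank-1 true coefficient
y <- apply(X, 1, function(x) sum(x * B.true)) + rnorm(n, sd = 0.1)
dim(X) <- c(n, p * q)                                 # vectorized predictors
# lambda = Inf returns the maximum penalty estimate as a single value
lambda.max <- RMReg(X, Z = rep(1, n), y, lambda = Inf, shape = c(p, q))
fit <- RMReg(X, Z = rep(1, n), y, lambda = 0.1 * lambda.max, shape = c(p, q))
fit$singular.values                                   # soft-thresholded spectrum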
@@ -66,130 +64,125 @@ RMReg <- function(X, Z, y, lambda = 0,
     }
     if (missing(Z) || is.null(Z)) {
         Z <- matrix(0, nrow(X), 1)
-        ZZiZ <- NULL
-    } else {
-        if (!is.matrix(Z)) Z <- as.matrix(Z)
-        # Compute (Z' Z)^{-1} Z used to solve for beta. This is constant
-        # throughout and the variable name stands for "((Z' Z) Inverse) Z"
-        ZZiZ <- solve(crossprod(Z, Z), t(Z))
+    } else if (!is.matrix(Z)) {
+        Z <- as.matrix(Z)
     }
-    ### Set initial values
-    # Note: Naming convention; a name ending with 1 is the current iterate while
-    # names ending with 0 are the previous iterate value.
-    # Init singular values of B1 (require only current point, not previous B0)
+    # Set singular values of start matrix predictor coefficients
     if (missing(B0)) {
-        b1 <- rep(0, min(shape))
+        B1.sv <- rep(0, min(shape))
     } else {
-        b1 <- La.svd(B0, 0, 0)$d
+        B1.sv <- La.svd(B0, 0, 0)$d
     }
-    # Init current to previous (start position)
+    # initialize current and previous coefficients (start position)
     B1 <- B0
-    a0 <- 0
-    a1 <- 1
-    loss1 <- loss(B1, beta, X, Z, y)
-    # Start without, the nesterov momentum is zero anyway
-    no.nesterov <- TRUE # Set to FALSE after the first iteration
-    ### Repeat untill convergence
-    for (iter in 1:max.iter) {
-        # Extrapolation with Nesterov Momentum
+    beta1 <- beta0
+    alpha0 <- 0
+    alpha1 <- 1
+    loss0 <- loss1 <- loss(B1, beta1, X, Z, y)
+    # main descent loop
+    no.nesterov <- FALSE
+    for (iter in seq_len(max.iter)) {
         if (no.nesterov) {
+            # classic gradient step as fallback
             S <- B1
+            s <- beta1
         } else {
-            S <- B1 + ((a0 - 1) / a1) * (B1 - B0)
+            # momentum step (extrapolation using previous direction)
+            S <- B1 + ((alpha0 - 1) / alpha1) * (B1 - B0)
+            s <- beta1 + ((alpha0 - 1) / alpha1) * (beta1 - beta0)
         }
-        # Solve for beta at extrapolation point
-        if (!is.null(ZZiZ)) {
-            beta <- ZZiZ %*% (y - X %*% c(S))
-        }
-        # Compute Nesterov Gradient of the Loss
-        grad <- array(grad.loss(S, beta, X, Z, y), dim = shape)
-        # Line Search (executed at least once)
-        for (delta in step.size * 0.5^seq(0, max.line.iter - 1)) {
-            # (potential) next step with delta as stepsize for gradient update
-            A <- S - delta * grad
+        # compute (nesterov) gradient
+        G <- grad(S, s, X, Z, y)
+        # backtracking line search (executed at least once)
+        for (delta in step.size * 0.5^seq(0, max.line.iter - 1L)) {
+            # Gradient step with step size delta
+            A <- S - delta * G$B
+            beta.temp <- s - delta * G$beta
             if (lambda == Inf) {
-                # Application of Corollary 1 (only nuclear norm supported) to
-                # estimate maximum lambda. In this case (first time this line is
-                # hit when lambda set to Inf, then B is zero (ignore B0 param))
-                lambda.max <- max(La.svd(A, 0, 0)$d) / delta
-                return(lambda.max)
+                # Application of Corollary 1 for estimation of max lambda
+                # Return max lambda estimate
+                return(max(La.svd(A, 0, 0)$d) / delta)
             } else if (lambda > 0) {
                 # SVD of (potential) next step
                 svdA <- La.svd(A)
-                # Get (potential) next penalized iterate (nuclear norm version only)
-                b.temp <- pmax(0, svdA$d - delta * lambda) # Singular values of B.temp
-                B.temp <- svdA$u %*% (b.temp * svdA$vt)
+                # Next (possible) penalized iterate
+                B.temp.sv <- pmax(0, svdA$d - delta * lambda)
+                B.temp <- svdA$u %*% (B.temp.sv * svdA$vt)
             } else {
-                # in case of no penalization (pure least squares solution)
-                b.temp <- La.svd(A, 0, 0)$d
+                # in case of no penalization (pure least squares)
+                B.temp.sv <- La.svd(A, 0, 0)$d
                 B.temp <- A
             }
-            # Solve for beta at (potential) next step
-            if (!is.null(ZZiZ)) {
-                beta <- ZZiZ %*% (y - X %*% c(B.temp))
-            }
-            # Check line search break condition
+            # Check line search condition
             # h(B.temp) <= g(B.temp | S, delta)
             # \_ left _/ \_____ right _____/
             # where g(B.temp | S, delta) is the first order approx. of the loss
             # l(S) + <grad l(S), B - S> + | B - S |_F^2 / 2 delta + J(B)
-            left <- loss(B.temp, beta, X, Z, y) # + penalty(b.temp)
-            right <- loss(S, beta, X, Z, y) + sum(grad * (B1 - S)) +
-                norm(B1 - S, 'F')^2 / (2 * delta) # + penalty(b.temp)
+            left <- loss(B.temp, beta.temp, X, Z, y)
+            right <- loss(S, s, X, Z, y) +
+                sum(G$B * (B.temp - S)) + sum(G$beta * (beta.temp - s)) +
+                (norm(B.temp - S, 'F')^2 + sum((beta.temp - s)^2)) / (2 * delta)
             if (left <= right) {
                 break
             }
         }
-        # Evaluate loss at (potential) new parameters
-        loss.temp <- loss(B.temp, beta, X, Z, y)
+        # Evaluate loss to ensure descent after line search
+        loss.temp <- loss(B.temp, beta.temp, X, Z, y)
         # logging callback
         if (is.function(logger)) {
-            logger(iter, loss.temp, penalty(b.temp), grad, B1, beta, delta)
+            logger(iter, loss.temp, lambda * sum(B.temp.sv),
+                   B.temp, beta.temp, delta)
         }
-        # After gradient update enforce descent (stop if not decreasing)
-        if (loss.temp + penalty(b.temp) <= loss1 + penalty(b1)) {
-            no.nesterov <- FALSE
-            loss1 <- loss.temp
+        # after line search enforce descent
+        if (loss.temp + lambda * sum(B.temp.sv) <= loss1 + lambda * sum(B1.sv)) {
             B0 <- B1
-            B1 <- B.temp
-            b1 <- b.temp
+            B1 <- array(B.temp, shape)
+            B1.sv <- B.temp.sv
+            beta0 <- beta1
+            beta1 <- beta.temp
+            loss0 <- loss1
+            loss1 <- loss.temp
+            no.nesterov <- FALSE # always reset
         } else if (!no.nesterov) {
-            # Retry without Nesterov extrapolation
-            no.nesterov <- TRUE
+            no.nesterov <- TRUE # retry without momentum
             next
         } else {
-            break
+            break # failed even without momentum -> stop
        }
-        # If estimate is zero, stop algorithm
-        if (all(b.temp < .Machine$double.eps)) {
-            loss1 <- loss.temp
-            B1 <- array(0, dim = shape)
-            b1 <- rep(0, min(shape))
-            break
+        # check break conditions
+        if (sum(B1.sv) < eps) {
+            break # estimate is (numerically) zero -> stop
+        }
+        if ((sum(G$B^2) + sum(G$beta^2)) < eps * sum(unlist(Map(length, G)))) {
+            break # mean squared gradient is smaller than epsilon -> stop
+        }
+        if (abs(loss0 - loss1) < eps) {
+            break # decrease is smaller than epsilon -> stop
         }
-        # Update momentum scaling
-        a0 <- a1
-        a1 <- alpha(a1, iter)
-        # set step size to two times current delta
-        step.size <- 2 * delta
+        # update momentum scaling
+        alpha0 <- alpha1
+        alpha1 <- alpha(alpha1, iter)
     }
-    ### Degrees of Freedom estimate (TODO: this is like in `matrix_sparsereg.m`)
+    # Degrees of Freedom estimate (TODO: this is like in `matrix_sparsereg.m`)
     sigma <- c(La.svd(A, 0, 0)$d, rep(0, max(shape) - min(shape)))
-    df <- if (!is.null(ZZiZ)) { ncol(Z) } else { 0 }
-    for (i in seq_len(sum(b1 > 0))) {
+    df <- length(beta1)
+    for (i in seq_len(sum(B1.sv > 0))) {
         df <- df + 1 + sigma[i] * (sigma[i] - delta * lambda) * (
             sum(ifelse((1:shape[1]) != i, 1 / (sigma[i]^2 - sigma[1:shape[1]]^2), 0)) +
             sum(ifelse((1:shape[2]) != i, 1 / (sigma[i]^2 - sigma[1:shape[2]]^2), 0)))
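The core of the line search above is singular value soft-thresholding of the gradient step, i.e. the proximal operator of the nuclear norm. As a standalone sketch (the helper prox.nuclear is hypothetical), the same operation reads:

# Sketch, not part of the commit: prox of threshold * ||.||_* at A.
prox.nuclear <- function(A, threshold) {
    svdA <- La.svd(A)
    sv <- pmax(0, svdA$d - threshold)     # soft-threshold the singular values
    svdA$u %*% (sv * svdA$vt)             # rebuild the (now low rank) matrix
}
# prox.nuclear(S - delta * G$B, delta * lambda) reproduces B.temp above.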
@@ -198,14 +191,14 @@ RMReg <- function(X, Z, y, lambda = 0,
     # return estimates and some additional stats
     list(
         B = B1,
-        beta = if (is.null(ZZiZ)) { NULL } else { beta },
-        singular.values = b1,
+        beta = beta1,
+        singular.values = B1.sv,
         iter = iter,
         df = df,
         loss = loss1,
         lambda = lambda,
-        AIC = loss1 / var(y) + 2 * df,
-        BIC = loss1 / var(y) + log(nrow(X)) * df,
+        AIC = loss1 + 2 * df, # TODO: check this!
+        BIC = loss1 + log(nrow(X)) * df, # TODO: check this!
         call = match.call() # invoking function call, collects params like lambda
     )
 }
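The default alpha recursion is the FISTA momentum sequence t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2, so the extrapolation weight (alpha0 - 1) / alpha1 used above starts at 0 and grows towards 1. A quick sketch:

# Sketch, not part of the commit: momentum weights produced by the default alpha.
a <- 1
for (t in 1:5) {
    a.next <- (1 + sqrt(1 + (2 * a)^2)) / 2
    cat(sprintf("iter %d: weight = %.4f\n", t, (a - 1) / a.next))
    a <- a.next
}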