2021-12-07 18:00:00 +00:00
|
|
|
#' Regularized Matrix Regression
#'
#' Solves the regularized problem
#' \deqn{\min_B h(B) = l(B) + J(B)}
#' for a matrix \eqn{B}, where \eqn{l} is a loss function; for the GLM, we use
#' the negative log-likelihood as the loss. The penalty is
#' \eqn{J(B) = f(\sigma(B))}, where \eqn{f} is a function of the singular
#' values \eqn{\sigma(B)} of \eqn{B}.
#'
#' The default parameterization is a nuclear norm penalized least squares
#' regression: the least squares loss combined with
#' \eqn{f(s) = \lambda \sum_i |s_i|} corresponds to the nuclear norm
#' regularization problem.
#'
#' The minimization uses accelerated (Nesterov) proximal gradient descent with
#' a backtracking line search. The proximal step is always the nuclear norm
#' proximal operator (soft-thresholding of the singular values by the current
#' step size times \code{lambda}); the \code{penalty} function is only used to
#' monitor the objective (descent check and logging).
#'
#' In case of \code{lambda = Inf} the maximum penalty \eqn{\lambda_{max}} is
#' computed as the spectral norm of the loss gradient at \eqn{B = 0}
#' (\code{B0} is ignored). In this case the return value is only this single
#' value and no model is fitted.
#'
#' @param X the signal data, either as a 3D tensor or a 2D matrix. In case of a
#'  3D tensor the axes are assumed to be \eqn{n\times p\times q}, meaning the
#'  first dimension indexes the observations while the second and third are
#'  the `image' dimensions. When the data is provided as a matrix it's assumed
#'  to be of shape \eqn{n\times p q} where each row is the vectorized `image'.
#' @param Z additional covariate vector or matrix (can be \code{NULL} if not
#'  required). For regression with intercept set \code{Z = rep(1, n)}.
#' @param y univariate response vector
#' @param lambda non-negative penalty weight. If set to \code{Inf}, only the
#'  maximum penalty \eqn{\lambda_{max}} is computed and returned.
#' @param loss loss function, part of the objective function, called as
#'  \code{loss(B, beta, X, Z, y)}
#' @param penalty penalty function taking the vector of singular values of the
#'  current iterate as its argument. The default function
#'  \code{function(sigma) lambda * sum(sigma)} is the nuclear norm penalty
#'  weighted by \code{lambda}.
#' @param grad.loss gradient with respect to \eqn{B} of the loss function,
#'  called as \code{grad.loss(B, beta, X, Z, y)}
#'  (required, there is no support for numerical gradients)
#' @param shape shape of the matrix valued predictors. Required iff the
#'  predictors \code{X} are provided in vectorized form, i.e. as a 2D matrix.
#' @param step.size maximum step size for gradient updates
#' @param alpha iterative Nesterov momentum scaling values, called as
#'  \code{alpha(a, iter)}
#' @param B0 initial value for optimization. Matrix of dimensions \eqn{p\times q}
#' @param beta initial value of the additional covariates coefficients for
#'  \eqn{Z}. When \code{Z} is given, \code{beta} is re-solved by least squares
#'  in every iteration.
#' @param max.iter maximum number of gradient updates (at least 1)
#' @param max.line.iter maximum number of line search iterations
#' @param logger logging callback invoked after every line search before the
#'  break condition checks. The expected function signature is of the form
#'  \code{logger(iter, loss, penalty, grad, B, beta, step.size)}, where
#'  \code{loss}, \code{penalty} and \code{beta} refer to the candidate next
#'  iterate, \code{grad} is the loss gradient at the extrapolation point and
#'  \code{B} is the current (last accepted) iterate.
#'
#' @return If \code{lambda = Inf} the numeric value \eqn{\lambda_{max}}.
#'  Otherwise a list with components \code{B}, \code{beta} (\code{NULL} if no
#'  \code{Z} was given), \code{singular.values}, \code{iter}, \code{df},
#'  \code{loss}, \code{lambda}, \code{AIC}, \code{BIC} and \code{call}.
#'
#' @export
RMReg <- function(X, Z, y, lambda = 0,
    loss = function(B, beta, X, Z, y) 0.5 * sum((y - Z %*% beta - X %*% c(B))^2),
    # Default arguments are evaluated in the function frame, hence `lambda`
    # refers to the `lambda` argument of this call.
    penalty = function(sigma) lambda * sum(sigma),
    grad.loss = function(B, beta, X, Z, y) crossprod(X %*% c(B) + Z %*% beta - y, X),
    shape = dim(X)[-1],
    step.size = 1e-3,
    alpha = function(a, t) { (1 + sqrt(1 + (2 * a)^2)) / 2 },
    B0 = array(0, dim = shape),
    beta = rep(0, NCOL(Z)),
    max.iter = 500,
    max.line.iter = ceiling(log(step.size / sqrt(.Machine$double.eps), 2)),
    logger = NULL
) {
    ### Check (prepare) params
    stopifnot(
        length(dim(X)) %in% c(2L, 3L),
        nrow(X) == length(y),
        lambda >= 0,
        max.iter >= 1
    )
    if (missing(shape)) {
        # The shape can only be deduced from a 3D tensor. Force it before X
        # gets vectorized below (the default `dim(X)[-1]` is lazy).
        stopifnot(length(dim(X)) == 3L)
        shape <- dim(X)[-1]
    }
    if (length(dim(X)) == 3L) {
        # Vectorize each observation's `image' into a row of an n x pq matrix
        # (column major, consistent with `c(B)`).
        dim(X) <- c(nrow(X), prod(dim(X)[-1]))
    }
    stopifnot(ncol(X) == prod(shape))

    if (missing(Z) || is.null(Z)) {
        Z <- matrix(0, nrow(X), 1)
        ZZiZ <- NULL
    } else {
        # Compute (Z' Z)^{-1} Z' used to solve for beta. This is constant
        # throughout and the variable name stands for "((Z' Z) Inverse) Z"
        ZZiZ <- solve(crossprod(Z, Z), t(Z))
    }

    # Least squares solution for beta given B; without covariates the passed
    # beta is kept as is.
    optimal.beta <- function(B, beta.fixed) {
        if (is.null(ZZiZ)) beta.fixed else ZZiZ %*% (y - X %*% c(B))
    }

    if (lambda == Inf) {
        # Application of Corollary 1 (only nuclear norm supported): the
        # maximum lambda is the spectral norm of the loss gradient at B = 0
        # (B0 is ignored).
        B.zero <- array(0, dim = shape)
        grad.zero <- array(
            grad.loss(B.zero, optimal.beta(B.zero, beta), X, Z, y),
            dim = shape
        )
        return(max(La.svd(grad.zero, 0, 0)$d))
    }

    ### Set initial values
    # Note: Naming convention; a name ending with 1 is the current iterate while
    # names ending with 0 are the previous iterate value.
    # Init singular values of B1 (require only current point, not previous B0)
    b1 <- if (missing(B0)) rep(0, min(shape)) else La.svd(B0, 0, 0)$d
    # Init current to previous (start position)
    B1 <- B0
    beta1 <- beta   # beta belonging to the accepted iterate B1
    a0 <- 0
    a1 <- 1
    loss1 <- loss(B1, beta1, X, Z, y)

    # Start without, the nesterov momentum is zero anyway
    no.nesterov <- TRUE

    ### Repeat until convergence
    for (iter in seq_len(max.iter)) {
        # Extrapolation with Nesterov Momentum
        S <- if (no.nesterov) B1 else B1 + ((a0 - 1) / a1) * (B1 - B0)

        # Solve for beta at extrapolation point, then evaluate the Nesterov
        # gradient and the loss at S (both constant during the line search)
        beta.S <- optimal.beta(S, beta1)
        grad <- array(grad.loss(S, beta.S, X, Z, y), dim = shape)
        loss.S <- loss(S, beta.S, X, Z, y)

        # Line Search (executed at least once)
        for (delta in step.size * 0.5^(seq_len(max(1, max.line.iter)) - 1)) {
            # (potential) next step with delta as stepsize for gradient update
            A <- S - delta * grad

            if (lambda > 0) {
                # Proximal step (nuclear norm version only): soft-threshold
                # the singular values of A by delta * lambda
                svdA <- La.svd(A)
                b.temp <- pmax(0, svdA$d - delta * lambda) # Singular values of B.temp
                B.temp <- svdA$u %*% (b.temp * svdA$vt)
            } else {
                # in case of no penalization (pure least squares solution)
                b.temp <- La.svd(A, 0, 0)$d
                B.temp <- A
            }

            # Solve for beta at (potential) next step
            beta.temp <- optimal.beta(B.temp, beta.S)
            loss.temp <- loss(B.temp, beta.temp, X, Z, y)

            # Check line search break condition
            #   h(B.temp) <= g(B.temp | S, delta)
            # where g(B.temp | S, delta) is the first order approx. of the loss
            #   l(S) + <grad l(S), B.temp - S> + | B.temp - S |_F^2 / 2 delta + J(B.temp)
            # (the penalty J(B.temp) appears on both sides and cancels)
            D <- B.temp - S
            if (loss.temp <= loss.S + sum(grad * D) + sum(D^2) / (2 * delta)) {
                break
            }
        }

        # logging callback
        if (is.function(logger)) {
            logger(iter, loss.temp, penalty(b.temp), grad, B1, beta.temp, delta)
        }

        # After gradient update enforce descent (stop if not decreasing)
        if (loss.temp + penalty(b.temp) <= loss1 + penalty(b1)) {
            no.nesterov <- FALSE
            loss1 <- loss.temp
            B0 <- B1
            B1 <- B.temp
            b1 <- b.temp
            beta1 <- beta.temp
        } else if (!no.nesterov) {
            # Retry without Nesterov extrapolation
            no.nesterov <- TRUE
            next
        } else {
            break
        }

        # If estimate is zero, stop algorithm
        if (all(b.temp < .Machine$double.eps)) {
            B1 <- array(0, dim = shape)
            b1 <- rep(0, min(shape))
            break
        }

        # Update momentum scaling
        a0 <- a1
        a1 <- alpha(a1, iter)
    }

    ### Degrees of Freedom estimate (TODO: this is like in `matrix_sparsereg.m`)
    # Uses `A` and `delta` of the last executed line search step.
    sigma <- c(La.svd(A, 0, 0)$d, rep(0, max(shape) - min(shape)))
    # NCOL (not ncol) so that a covariate vector such as rep(1, n) counts as 1
    df <- if (is.null(ZZiZ)) 0 else NCOL(Z)
    for (i in seq_len(sum(b1 > 0))) {
        others1 <- setdiff(seq_len(shape[1]), i)
        others2 <- setdiff(seq_len(shape[2]), i)
        df <- df + 1 + sigma[i] * (sigma[i] - delta * lambda) * (
            sum(1 / (sigma[i]^2 - sigma[others1]^2)) +
            sum(1 / (sigma[i]^2 - sigma[others2]^2)))
    }

    # return estimates and some additional stats
    list(
        B = B1,
        beta = if (is.null(ZZiZ)) NULL else beta1,
        singular.values = b1,
        iter = iter,
        df = df,
        loss = loss1,
        lambda = lambda,
        AIC = loss1 / var(y) + 2 * df,
        BIC = loss1 / var(y) + log(nrow(X)) * df,
        call = match.call() # invoking function call, collects params like lambda
    )
}
|