wip: RMReg added logger

This commit is contained in:
Daniel Kapla 2022-01-13 11:21:58 +01:00
parent 74f1ce6ecf
commit dad521ee45
1 changed files with 16 additions and 5 deletions

View File

@ -25,11 +25,11 @@
#' @param y univariate response vector #' @param y univariate response vector
#' @param lambda penalty term, if set to \code{Inf} #' @param lambda penalty term, if set to \code{Inf}
#' @param loss loss function, part of the objective function #' @param loss loss function, part of the objective function
#' @param grad.loss gradient with respect to \eqn{B} of the loss function
#' (required, there is no support for numerical gradients)
#' @param penalty penalty function with a vector of the singular values if the #' @param penalty penalty function with a vector of the singular values if the
#' current iterate as arguments. The default function #' current iterate as arguments. The default function
#' \code{function(sigma) sum(sigma)} is the nuclear norm penalty. #' \code{function(sigma) sum(sigma)} is the nuclear norm penalty.
#' @param grad.loss gradient with respect to \eqn{B} of the loss function
#' (required, there is no support for numerical gradients)
#' @param shape Shape of the matrix valued predictors. Required iff the #' @param shape Shape of the matrix valued predictors. Required iff the
#' predictors \code{X} are provided in vectorized form, e.g. as a 2D matrix. #' predictors \code{X} are provided in vectorized form, e.g. as a 2D matrix.
#' @param step.size max. stepsize for gradient updates #' @param step.size max. stepsize for gradient updates
@ -38,19 +38,23 @@
#' @param beta initial value of additional covatiates coefficient for \eqn{Z} #' @param beta initial value of additional covatiates coefficient for \eqn{Z}
#' @param max.iter maximum number of gadient updates #' @param max.iter maximum number of gadient updates
#' @param max.line.iter maximum number of line search iterations #' @param max.line.iter maximum number of line search iterations
#' @param logger logging callback invoced after every line search before break
#' condition checks. The expected function signature is of the form
#' \code{logger(iter, loss, penalty, grad, B, beta, step.size)}.
#' #'
#' @export #' @export
RMReg <- function(X, Z, y, lambda = 0, RMReg <- function(X, Z, y, lambda = 0,
loss = function(B, beta, X, Z, y) 0.5 * sum((y - Z %*% beta - X %*% c(B))^2), loss = function(B, beta, X, Z, y) 0.5 * sum((y - Z %*% beta - X %*% c(B))^2),
grad.loss = function(B, beta, X, Z, y) crossprod(X %*% c(B) + Z %*% beta - y, X),
penalty = function(sigma) sum(sigma), penalty = function(sigma) sum(sigma),
grad.loss = function(B, beta, X, Z, y) crossprod(X %*% c(B) + Z %*% beta - y, X),
shape = dim(X)[-1], shape = dim(X)[-1],
step.size = 1e-3, step.size = 1e-3,
alpha = function(a, t) { (1 + sqrt(1 + (2 * a)^2)) / 2 }, alpha = function(a, t) { (1 + sqrt(1 + (2 * a)^2)) / 2 },
B0 = array(0, dim = shape), B0 = array(0, dim = shape),
beta = rep(0, NCOL(Z)), beta = rep(0, NCOL(Z)),
max.iter = 500, max.iter = 500,
max.line.iter = ceiling(log(step.size / sqrt(.Machine$double.eps), 2)) max.line.iter = ceiling(log(step.size / sqrt(.Machine$double.eps), 2)),
logger = NULL
) { ) {
### Check (prepair) params ### Check (prepair) params
stopifnot(nrow(X) == length(y)) stopifnot(nrow(X) == length(y))
@ -145,8 +149,15 @@ RMReg <- function(X, Z, y, lambda = 0,
} }
} }
# After gradient update enforce descent (stop if not decreasing) # Evaluate loss at (potential) new parameters
loss.temp <- loss(B.temp, beta, X, Z, y) loss.temp <- loss(B.temp, beta, X, Z, y)
# logging callback
if (is.function(logger)) {
logger(iter, loss.temp, penalty(b.temp), grad, B1, beta, delta)
}
# After gradient update enforce descent (stop if not decreasing)
if (loss.temp + penalty(b.temp) <= loss1 + penalty(b1)) { if (loss.temp + penalty(b.temp) <= loss1 + penalty(b1)) {
no.nesterov <- FALSE no.nesterov <- FALSE
loss1 <- loss.temp loss1 <- loss.temp