diff --git a/tensorPredictors/R/RMReg.R b/tensorPredictors/R/RMReg.R
index d1b40ee..6ab6dec 100644
--- a/tensorPredictors/R/RMReg.R
+++ b/tensorPredictors/R/RMReg.R
@@ -25,11 +25,11 @@
 #' @param y univariate response vector
 #' @param lambda penalty term, if set to \code{Inf}
 #' @param loss loss function, part of the objective function
-#' @param grad.loss gradient with respect to \eqn{B} of the loss function
-#'      (required, there is no support for numerical gradients)
 #' @param penalty penalty function taking the vector of singular values of the
 #'      current iterate as argument. The default function
 #'      \code{function(sigma) sum(sigma)} is the nuclear norm penalty.
+#' @param grad.loss gradient with respect to \eqn{B} of the loss function
+#'      (required, there is no support for numerical gradients)
 #' @param shape shape of the matrix-valued predictors. Required iff the
 #'      predictors \code{X} are provided in vectorized form, e.g. as a 2D matrix.
 #' @param step.size max. step size for gradient updates
@@ -38,19 +38,23 @@
 #' @param beta initial value of the additional covariates coefficient for \eqn{Z}
 #' @param max.iter maximum number of gradient updates
 #' @param max.line.iter maximum number of line search iterations
+#' @param logger logging callback invoked after every line search, before the
+#'      break condition checks. The expected function signature is of the form
+#'      \code{logger(iter, loss, penalty, grad, B, beta, step.size)}.
 #'
 #' @export
 RMReg <- function(X, Z, y, lambda = 0,
     loss = function(B, beta, X, Z, y) 0.5 * sum((y - Z %*% beta - X %*% c(B))^2),
-    grad.loss = function(B, beta, X, Z, y) crossprod(X %*% c(B) + Z %*% beta - y, X),
     penalty = function(sigma) sum(sigma),
+    grad.loss = function(B, beta, X, Z, y) crossprod(X %*% c(B) + Z %*% beta - y, X),
     shape = dim(X)[-1],
     step.size = 1e-3,
     alpha = function(a, t) { (1 + sqrt(1 + (2 * a)^2)) / 2 },
     B0 = array(0, dim = shape),
     beta = rep(0, NCOL(Z)),
     max.iter = 500,
-    max.line.iter = ceiling(log(step.size / sqrt(.Machine$double.eps), 2))
+    max.line.iter = ceiling(log(step.size / sqrt(.Machine$double.eps), 2)),
+    logger = NULL
 ) {
     ### Check (prepare) params
     stopifnot(nrow(X) == length(y))
@@ -145,8 +149,15 @@ RMReg <- function(X, Z, y, lambda = 0,
             }
         }
 
-        # After gradient update enforce descent (stop if not decreasing)
+        # Evaluate loss at (potential) new parameters
        loss.temp <- loss(B.temp, beta, X, Z, y)
+
+        # logging callback
+        if (is.function(logger)) {
+            logger(iter, loss.temp, penalty(b.temp), grad, B1, beta, delta)
+        }
+
+        # After gradient update enforce descent (stop if not decreasing)
         if (loss.temp + penalty(b.temp) <= loss1 + penalty(b1)) {
             no.nesterov <- FALSE
             loss1 <- loss.temp
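
A remark on the defaults visible in the reordered signature: the `alpha` default is the classical FISTA/Nesterov momentum sequence t_{k+1} = (1 + sqrt(1 + 4*t_k^2)) / 2, consistent with the `no.nesterov` flag in the last hunk, and the `max.line.iter` default reads as the number of times the step size can be halved before it drops below square-root machine precision. The base-2 logarithm strongly suggests a step-halving backtracking line search, but the halving itself is outside these hunks, so treat this as an inference. A quick sanity check in R:

```r
## Worked out for the default step.size = 1e-3 (an inference from the
## signature, not code from this patch):
sqrt(.Machine$double.eps)                          # ~ 1.49e-08
ceiling(log(1e-3 / sqrt(.Machine$double.eps), 2))  # 17 halvings until below it
```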
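
Since the new `logger` hook is the point of this patch, here is a minimal usage sketch. Everything below except the callback signature is an assumption for illustration: the simulated data, `lambda = 1`, and the `history` accumulator are hypothetical, and `X` is passed in the documented vectorized (2D) form together with `shape`.

```r
## Minimal usage sketch of the new `logger` hook (hypothetical setup; only
## the signature logger(iter, loss, penalty, grad, B, beta, step.size)
## is taken from the patch).
set.seed(42)
n <- 200; p <- 10; q <- 10
B.true <- outer(seq(0, 1, length.out = p), seq(0, 1, length.out = q))  # rank-1 signal
X <- matrix(rnorm(n * p * q), n, p * q)  # predictors in vectorized (2D) form
Z <- matrix(rnorm(n), n, 1)              # one additional covariate
y <- drop(X %*% c(B.true)) + Z[, 1] + rnorm(n, sd = 0.1)

history <- data.frame()  # filled by the callback, one row per iteration
fit <- RMReg(X, Z, y, lambda = 1, shape = c(p, q),
    logger = function(iter, loss, penalty, grad, B, beta, step.size) {
        # `penalty` arrives pre-evaluated as a scalar (penalty(b.temp) at the
        # call site), so the objective is simply the sum of the two values
        history <<- rbind(history, data.frame(
            iter = iter, loss = loss, penalty = penalty,
            objective = loss + penalty, step.size = step.size))
    })

with(history, plot(iter, objective, type = "l",
    xlab = "iteration", ylab = "loss + penalty"))
```

Note that, judging from the last hunk, the callback fires before the descent check, so a logged row can correspond to a candidate step that is subsequently rejected.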