From dad521ee45e2bcb1cda5c1b2a0fe427c312b908a Mon Sep 17 00:00:00 2001
From: daniel <daniel@kapla.at>
Date: Thu, 13 Jan 2022 11:21:58 +0100
Subject: [PATCH] wip: RMReg added logger

---
 tensorPredictors/R/RMReg.R | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/tensorPredictors/R/RMReg.R b/tensorPredictors/R/RMReg.R
index d1b40ee..6ab6dec 100644
--- a/tensorPredictors/R/RMReg.R
+++ b/tensorPredictors/R/RMReg.R
@@ -25,11 +25,11 @@
 #' @param y univariate response vector
 #' @param lambda penalty term, if set to \code{Inf} 
 #' @param loss loss function, part of the objective function
-#' @param grad.loss gradient with respect to \eqn{B} of the loss function
-#'  (required, there is no support for numerical gradients)
 #' @param penalty penalty function with a vector of the singular values if the
 #'  current iterate as arguments. The default function
 #'  \code{function(sigma) sum(sigma)} is the nuclear norm penalty.
+#' @param grad.loss gradient with respect to \eqn{B} of the loss function
+#'  (required, there is no support for numerical gradients)
 #' @param shape Shape of the matrix valued predictors. Required iff the
 #'  predictors \code{X} are provided in vectorized form, e.g. as a 2D matrix.
 #' @param step.size max. stepsize for gradient updates
@@ -38,19 +38,23 @@
 #' @param beta initial value of additional covatiates coefficient for \eqn{Z}
 #' @param max.iter maximum number of gadient updates
 #' @param max.line.iter maximum number of line search iterations
+#' @param logger logging callback invoced after every line search before break
+#'  condition checks. The expected function signature is of the form
+#'  \code{logger(iter, loss, penalty, grad, B, beta, step.size)}.
 #'
 #' @export
 RMReg <- function(X, Z, y, lambda = 0,
     loss = function(B, beta, X, Z, y) 0.5 * sum((y - Z %*% beta - X %*% c(B))^2),
-    grad.loss = function(B, beta, X, Z, y) crossprod(X %*% c(B) + Z %*% beta - y, X),
     penalty = function(sigma) sum(sigma),
+    grad.loss = function(B, beta, X, Z, y) crossprod(X %*% c(B) + Z %*% beta - y, X),
     shape = dim(X)[-1],
     step.size = 1e-3,
     alpha = function(a, t) { (1 + sqrt(1 + (2 * a)^2)) / 2 },
     B0 = array(0, dim = shape),
     beta = rep(0, NCOL(Z)),
     max.iter = 500,
-    max.line.iter = ceiling(log(step.size / sqrt(.Machine$double.eps), 2))
+    max.line.iter = ceiling(log(step.size / sqrt(.Machine$double.eps), 2)),
+    logger = NULL
 ) {
     ### Check (prepair) params
     stopifnot(nrow(X) == length(y))
@@ -145,8 +149,15 @@ RMReg <- function(X, Z, y, lambda = 0,
             }
         }
 
-        # After gradient update enforce descent (stop if not decreasing)
+        # Evaluate loss at (potential) new parameters
         loss.temp <- loss(B.temp, beta, X, Z, y)
+
+        # logging callback
+        if (is.function(logger)) {
+            logger(iter, loss.temp, penalty(b.temp), grad, B1, beta, delta)
+        }
+
+        # After gradient update enforce descent (stop if not decreasing)
         if (loss.temp + penalty(b.temp) <= loss1 + penalty(b1)) {
             no.nesterov <- FALSE
             loss1 <- loss.temp