Compare commits

No commits in common. "8da1950d02290632c1753612607df0713251f36c" and "938e6bd3bad73323561a5c9c2a6418820c600ece" have entirely different histories.

8da1950d02 ... 938e6bd3ba

@@ -25,11 +25,11 @@
 #' @param y univariate response vector
 #' @param lambda penalty term, if set to \code{Inf} max lambda is computed.
 #' @param loss loss function, part of the objective function
+#' @param grad.loss gradient with respect to \eqn{B} of the loss function
+#' (required, there is no support for numerical gradients)
 #' @param penalty penalty function with a vector of the singular values of the
 #' current iterate as arguments. The default function
 #' \code{function(sigma) sum(sigma)} is the nuclear norm penalty.
-#' @param grad.loss gradient with respect to \eqn{B} of the loss function
-#' (required, there is no support for numerical gradients)
 #' @param shape Shape of the matrix valued predictors. Required iff the
 #' predictors \code{X} are provided in vectorized form, e.g. as a 2D matrix.
 #' @param step.size max. stepsize for gradient updates
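Note on the documented penalty interface: its single argument is the vector of singular values of the current iterate, and the default sum(sigma) is the nuclear norm. A minimal R sketch; the squared-Frobenius alternative is only an illustration, not part of the package:

# Default penalty: nuclear norm, i.e. the sum of singular values of B.
nuclear <- function(sigma) sum(sigma)
# Illustrative alternative (assumption, not from the package): squared Frobenius norm.
frobenius2 <- function(sigma) sum(sigma^2)

B <- matrix(rnorm(12), 3, 4)
sigma <- svd(B)$d
nuclear(sigma)     # nuclear norm of B
frobenius2(sigma)  # equals sum(B^2) by unitary invariance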
@@ -38,23 +38,19 @@
 #' @param beta initial value of additional covariates coefficient for \eqn{Z}
 #' @param max.iter maximum number of gradient updates
 #' @param max.line.iter maximum number of line search iterations
-#' @param logger logging callback invoked after every line search before break
-#' condition checks. The expected function signature is of the form
-#' \code{logger(iter, loss, penalty, grad, B, beta, step.size)}.
 #'
 #' @export
 RMReg <- function(X, Z, y, lambda = 0,
     loss = function(B, beta, X, Z, y) 0.5 * sum((y - Z %*% beta - X %*% c(B))^2),
-    penalty = function(sigma) sum(sigma),
     grad.loss = function(B, beta, X, Z, y) crossprod(X %*% c(B) + Z %*% beta - y, X),
+    penalty = function(sigma) sum(sigma),
     shape = dim(X)[-1],
     step.size = 1e-3,
     alpha = function(a, t) { (1 + sqrt(1 + (2 * a)^2)) / 2 },
     B0 = array(0, dim = shape),
     beta = rep(0, NCOL(Z)),
     max.iter = 500,
-    max.line.iter = ceiling(log(step.size / sqrt(.Machine$double.eps), 2)),
-    logger = NULL
+    max.line.iter = ceiling(log(step.size / sqrt(.Machine$double.eps), 2))
 ) {
     ### Check (prepare) params
     stopifnot(nrow(X) == length(y))
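For orientation, a minimal call sketch consistent with the signature above; the sample size, predictor dimensions, rank-1 coefficient, and intercept-only covariates are assumptions for illustration, not taken from the package:

set.seed(1)
n <- 100
shape <- c(5, 7)                                 # assumed matrix-predictor dimensions
X <- matrix(rnorm(n * prod(shape)), n)           # predictors in vectorized (2D) form
Z <- matrix(1, n, 1)                             # intercept-only additional covariates
B.true <- tcrossprod(rnorm(shape[1]), rnorm(shape[2]))  # rank-1 coefficient matrix
y <- X %*% c(B.true) + Z %*% 2 + rnorm(n)
# shape must be passed explicitly: X is 2D, so the default dim(X)[-1] would be wrong
fit <- RMReg(X, Z, y, lambda = 1, shape = shape)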
@@ -150,15 +146,8 @@ RMReg <- function(X, Z, y, lambda = 0,
         }
     }
 
-    # Evaluate loss at (potential) new parameters
-    loss.temp <- loss(B.temp, beta, X, Z, y)
-
-    # logging callback
-    if (is.function(logger)) {
-        logger(iter, loss.temp, penalty(b.temp), grad, B1, beta, delta)
-    }
-
     # After gradient update enforce descent (stop if not decreasing)
+    loss.temp <- loss(B.temp, beta, X, Z, y)
     if (loss.temp + penalty(b.temp) <= loss1 + penalty(b1)) {
         no.nesterov <- FALSE
         loss1 <- loss.temp
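Reading aid for the check above: a candidate step is accepted only if the composite objective, loss plus penalty of the singular values, does not increase. A hedged sketch, assuming b.temp and b1 hold the singular values of B.temp and B1 (the enclosing loop is outside this hunk):

objective <- function(loss.value, sigma) loss.value + penalty(sigma)   # composite objective
if (objective(loss.temp, b.temp) <= objective(loss1, b1)) {
    no.nesterov <- FALSE   # descent holds, keep the accelerated update
    loss1 <- loss.temp     # carry the accepted loss value forward
}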