tensor_predictors/tensorPredictors/R/NAGD.R

165 lines
6.6 KiB
R
Raw Normal View History

#' Nesterov Accelerated Gradient Descent
#'
#' Minimizes \code{fun.loss} given its gradient \code{fun.grad} from initial
#' position \code{params}. This generic implementation allows for structured
#' parameters provided that the function \code{fun.lincomb} and \code{fun.norm2}
#' can handle the parameters appropriately.
#'
#' @param fun.loss Scalar loss function (minimization objective), its signature
#' is \code{function(params)} or \code{function(params, more.params)} if
#' \code{more.params} is not missing and its return is assumed to be a scalar.
#' @param fun.grad Gradient of \code{fun.loss} with signature
#' \code{function(params)} or \code{function(params, more.params)} if
#' \code{more.params} is not missing and its return is assumed to be \code{params}.
#' @param params initial parameters, a.k.a. start position
#' @param more.params further parameters not subject to optimization. They might
#' change during optimization as result of a call to \code{fun.more.params}.
#' @param fun.more.params function of signature
#' \code{function(params, more.params)} if \code{more.params} is not missing.
#' This is called whenever \code{params} were updated if \code{more.params}
#' are not missing.
#' @param fun.lincomb linear combination of parameters, see examples.
#' @param fun.norm2 squared norm of parameters, applied to \code{fun.grad} output
#' @param max.iter maximum number of gradient updates
#' @param max.line.iter maximum number of line search iterations
#' @param step.size initial step size, used in the first iterate as initial
#' value in the backtracking line search. Gets adapted during runtime.
#' @param armijo constant for Armijo condition in the line search
#' @param gamma line search step size reduction in the open (0, 1) interval
#' @param eps small constant used in break conditions
#' @param callback function invoked for each iteration (including iteration 0)
#' with the signature \code{function(iter, params)} or
#' \code{function(iter, params, more.params)}.
#'
#' @return Either the final parameter estimates \code{params} or a list with
#' parameters and more parameters \code{list(params, more.params)} in case
#' of non missing \code{more.params}.
#'
#' @example inst/examples/NAGD.R
#'
#' @export
NAGD <- function(fun.loss, fun.grad, params, more.params = NULL,
    fun.more.params = function(params, more.params) more.params,
    fun.lincomb = function(a, params1, b, params2) a * params1 + b * params2,
    fun.norm2 = function(params) sum(params^2),
    max.iter = 50L, max.line.iter = 50L, step.size = 1e-2,
    armijo = 0.1, gamma = 2 / (1 + sqrt(5)),
    eps = sqrt(.Machine$double.eps),
    callback = NULL
) {
    # Decide once whether auxiliary parameters are in play: `missing()` is only
    # reliable before `more.params` gets reassigned further down.
    has.more <- !missing(more.params)

    # Report the current state to the (optional) callback.
    notify <- function(iter) {
        if (is.function(callback)) {
            if (has.more) {
                callback(iter, params, more.params)
            } else {
                callback(iter, params)
            }
        }
    }
    # Assemble the return value from the current state: bare `params`, or a
    # list of `params` and `more.params` when auxiliary parameters were given.
    result <- function() {
        if (has.more) {
            list(params = params, more.params = more.params)
        } else {
            params
        }
    }

    # momentum extrapolation weights
    m <- c(0, 1)
    # Compute initial loss
    loss <- if (has.more) fun.loss(params, more.params) else fun.loss(params)
    if (!is.finite(loss)) {
        stop("Initial loss is non-finite (", loss, ")")
    }
    # initialize "previous" iterate parameters
    prev.params <- params
    # Line search state: FALSE = not (yet) successful, TRUE = step accepted,
    # NA = the previous iteration failed and the current one is the retry.
    line.search.tag <- FALSE
    # Gradient Descent Loop
    iter <- 0L
    for (iter in seq_len(max.iter)) {
        # Call callback for previous iterate
        notify(iter - 1L)
        # Extrapolation from previous position (momentum)
        # `params.moment <- (1 + moment) * params - moment * prev.params`
        moment <- (m[1] - 1) / m[2]
        params.moment <- fun.lincomb(1 + moment, params, -moment, prev.params)
        # Compute gradient at extrapolated position
        if (has.more) {
            more.params <- fun.more.params(params.moment, more.params)
            gradients <- fun.grad(params.moment, more.params)
        } else {
            gradients <- fun.grad(params.moment)
        }
        # gradient inner product (with itself), aka squared norm
        grad.inner.prod <- fun.norm2(gradients)
        if (!is.finite(grad.inner.prod)) {
            stop("Encountered non-finite gradient (", grad.inner.prod,
                 ") with loss (", loss, ")")
        }
        # Backtracking like Line Search; the first candidate is the previous
        # step size enlarged by `1 / gamma`, subsequent ones shrink by `gamma`.
        deltas <- step.size * gamma^seq.int(-1L, length.out = max.line.iter)
        for (delta in deltas) {
            # Gradient Update with current step size
            params.temp <- fun.lincomb(1, params.moment, -delta, gradients)
            # compute loss at temporary position
            if (has.more) {
                more.params.temp <- fun.more.params(params.temp, more.params)
                loss.temp <- fun.loss(params.temp, more.params.temp)
            } else {
                loss.temp <- fun.loss(params.temp)
            }
            # treat NA / NaN / -Inf losses as "infinitely bad"
            if (!is.finite(loss.temp)) {
                loss.temp <- Inf
            }
            # check Armijo condition at temporary position
            if (loss.temp <= loss - armijo * delta * grad.inner.prod) {
                line.search.tag <- TRUE
                break
            }
        }
        # keep track of previous parameters
        prev.params <- params
        # check line search outcome
        if (is.na(line.search.tag)) {
            # the retry failed as well -> line search hopeless, give up
            return(result())
        } else if (isTRUE(line.search.tag)) {
            # line search success -> check break condition (relative change of
            # the loss; `abs(loss)` keeps this meaningful for negative losses)
            if (abs(loss - loss.temp) < eps * abs(loss)) {
                break
            }
            # update loss and parameters
            loss <- loss.temp
            params <- params.temp
            if (has.more) {
                more.params <- more.params.temp
            }
            # momentum extrapolation weights
            m <- c(m[2], (1 + sqrt(1 + (2 * m[2])^2)) / 2)
            # and the step size
            step.size <- delta
            # set line search tag to false for next step
            line.search.tag <- FALSE
        } else {
            # line search failure -> retry without momentum: `prev.params` now
            # equals `params`, so the next extrapolation combines `params`
            # with itself
            line.search.tag <- NA
            next
        }
    }
    # A `for` over an empty sequence (`max.iter = 0`) leaves `iter` as NULL;
    # report iteration 0 to the callback in that case.
    if (is.null(iter)) {
        iter <- 0L
    }
    # Call callback with final result
    notify(iter)
    # return estimated parameters
    result()
}