165 lines
6.6 KiB
R
165 lines
6.6 KiB
R
#' Nesterov Accelerated Gradient Descent
#'
#' Minimizes \code{fun.loss} given its gradient \code{fun.grad}, starting from
#' the initial position \code{params}. This generic implementation allows for
#' structured parameters provided that the functions \code{fun.lincomb} and
#' \code{fun.norm2} can handle the parameters appropriately.
#'
#' Every iteration extrapolates from the current and the previous iterate
#' (momentum), takes a gradient step from the extrapolated position and picks
#' the step size by a backtracking line search using the Armijo condition.
#' If the line search fails, the next iteration retries from the current
#' iterate without momentum; if that retry fails as well, the optimization
#' stops and returns the current estimate.
#'
#' @param fun.loss Scalar loss function (minimization objective), its signature
#' is \code{function(params)} or \code{function(params, more.params)} if
#' \code{more.params} is not missing and its return is assumed to be a scalar.
#' @param fun.grad Gradient of \code{fun.loss} with signature
#' \code{function(params)} or \code{function(params, more.params)} if
#' \code{more.params} is not missing and its return is assumed to have the
#' same structure as \code{params}.
#' @param params initial parameters, a.k.a. start position
#' @param more.params further parameters not subject to optimization. They might
#' change during optimization as result of a call to \code{fun.more.params}.
#' @param fun.more.params function of signature
#' \code{function(params, more.params)} if \code{more.params} is not missing.
#' It is called on every new position (the extrapolated position and each
#' line search trial position) if \code{more.params} is not missing.
#' @param fun.lincomb linear combination of parameters, see examples.
#' @param fun.norm2 squared norm of parameters, applied to \code{fun.grad} output
#' @param max.iter maximum number of gradient updates
#' @param max.line.iter maximum number of line search iterations
#' @param step.size initial step size, used in the first iterate as initial
#' value in the backtracking line search. Gets adapted during runtime.
#' @param armijo constant for Armijo condition in the line search
#' @param gamma line search step size reduction in the open (0, 1) interval
#' @param eps small constant used in break conditions: the iteration stops once
#' an accepted step changes the loss by at most \code{eps * abs(loss)}.
#' @param callback function invoked for each iteration (including iteration 0)
#' with the signature \code{function(iter, params)} or
#' \code{function(iter, params, more.params)}. It is invoked once more with the
#' final result, except when the optimization stops due to line search failure.
#'
#' @return Either the final parameter estimates \code{params} or a list with
#' parameters and more parameters \code{list(params, more.params)} in case
#' of non missing \code{more.params}.
#'
#' @example inst/examples/NAGD.R
#'
#' @export
NAGD <- function(fun.loss, fun.grad, params, more.params = NULL,
    fun.more.params = function(params, more.params) more.params,
    fun.lincomb = function(a, params1, b, params2) a * params1 + b * params2,
    fun.norm2 = function(params) sum(params^2),
    max.iter = 50L, max.line.iter = 50L, step.size = 1e-2,
    armijo = 0.1, gamma = 2 / (1 + sqrt(5)),
    eps = sqrt(.Machine$double.eps),
    callback = NULL
) {
  # `more.params` is only "in use" if the caller supplied it. It is never
  # assigned below unless supplied, so a single `missing()` check suffices.
  has.more <- !missing(more.params)

  # Dispatch helpers hiding the optional `more.params` argument. When
  # `more.params` is not in use, `mp` is never evaluated.
  eval.loss <- function(p, mp) if (has.more) fun.loss(p, mp) else fun.loss(p)
  eval.grad <- function(p, mp) if (has.more) fun.grad(p, mp) else fun.grad(p)
  report <- function(i, p, mp) {
    if (is.function(callback)) {
      if (has.more) callback(i, p, mp) else callback(i, p)
    }
  }
  pack <- function(p, mp) {
    if (has.more) list(params = p, more.params = mp) else p
  }

  # Momentum weights (t_{k-1}, t_k). The extrapolation factor
  # (t_{k-1} - 1) / t_k is -1 initially, which together with
  # `params.last == params` makes the first extrapolation a no-op.
  m <- c(0, 1)

  # Compute initial loss
  loss <- eval.loss(params, more.params)
  if (!is.finite(loss)) {
    stop("Initial loss is non-finite (", loss, ")")
  }

  # initialize "previous" iterate parameters
  params.last <- params

  # Line search state: FALSE = fresh (with momentum), NA = the previous line
  # search failed (this iteration runs without momentum), TRUE = success.
  line.search.tag <- FALSE

  # Defined even if the loop never runs (`max.iter < 1`); the final callback
  # reports it.
  iter <- 0L

  for (iter in seq_len(max.iter)) {
    # Call callback for previous iterate
    report(iter - 1L, params, more.params)

    # Extrapolation from previous position (momentum)
    # `params.moment <- (1 + moment) * params - moment * params.last`
    moment <- (m[1] - 1) / m[2]
    params.moment <- fun.lincomb(1 + moment, params, -moment, params.last)

    # Compute gradient at extrapolated position
    if (has.more) {
      more.params <- fun.more.params(params.moment, more.params)
    }
    gradients <- eval.grad(params.moment, more.params)

    # gradient inner product (with itself), aka squared norm
    grad.inner.prod <- fun.norm2(gradients)
    if (!is.finite(grad.inner.prod)) {
      stop("Encountered non-finite gradient (", grad.inner.prod,
           ") with loss (", loss, ")")
    }

    # Backtracking line search. The first trial is `step.size / gamma`, i.e.
    # larger than the last accepted step, so the step size can grow again.
    trial.steps <- step.size * gamma^seq.int(-1L, length.out = max.line.iter)
    for (delta in trial.steps) {
      # Gradient update with current step size
      params.temp <- fun.lincomb(1, params.moment, -delta, gradients)

      # NULL (and unused) if `more.params` is not in use
      more.params.temp <- if (has.more) {
        fun.more.params(params.temp, more.params)
      }
      loss.temp <- eval.loss(params.temp, more.params.temp)
      loss.temp <- if (is.finite(loss.temp)) loss.temp else Inf

      # Armijo condition, relative to the loss at the current iterate
      # `params` (not at the extrapolated position).
      if (loss.temp <= loss - armijo * delta * grad.inner.prod) {
        line.search.tag <- TRUE
        break
      }
    }

    # keep track of previous parameters
    params.last <- params

    if (is.na(line.search.tag)) {
      # The momentum-free retry failed as well -> line search hopeless.
      return(pack(params, more.params))
    }
    if (!line.search.tag) {
      # Failure -> retry without momentum: `params.last` now equals `params`,
      # so the next extrapolation leaves `params` unchanged.
      line.search.tag <- NA
      next
    }

    # Success. Relative convergence check; `abs(loss)` keeps it meaningful
    # for zero or negative losses.
    converged <- abs(loss - loss.temp) <= eps * abs(loss)

    # Accept the step (also on convergence, so `more.params` matches `params`)
    loss <- loss.temp
    params <- params.temp
    if (has.more) {
      more.params <- more.params.temp
    }
    # update momentum weights: t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2
    m <- c(m[2], (1 + sqrt(1 + (2 * m[2])^2)) / 2)
    # and the step size
    step.size <- delta
    line.search.tag <- FALSE

    if (converged) {
      break
    }
  }

  # Call callback with final result
  report(iter, params, more.params)

  pack(params, more.params)
}