#' Nesterov Accelerated Gradient Descent
#'
#' Minimizes \code{fun.loss} given its gradient \code{fun.grad} from the initial
#' position \code{params}. This generic implementation allows for structured
#' parameters, provided that the functions \code{fun.lincomb} and
#' \code{fun.norm2} can handle the parameters appropriately.
#'
#' @param fun.loss Scalar loss function (minimization objective). Its signature
#'  is \code{function(params)}, or \code{function(params, more.params)} if
#'  \code{more.params} is not missing, and its return value is assumed to be a
#'  scalar.
#' @param fun.grad Gradient of \code{fun.loss} with signature
#'  \code{function(params)}, or \code{function(params, more.params)} if
#'  \code{more.params} is not missing. Its return value is assumed to have the
#'  same structure as \code{params}.
#' @param params initial parameters, a.k.a. the start position
#' @param more.params further parameters not subject to optimization. They may
#'  change during optimization as a result of a call to \code{fun.more.params}.
#' @param fun.more.params function with signature
#'  \code{function(params, more.params)} if \code{more.params} is not missing.
#'  It is called whenever \code{params} is updated, if \code{more.params} is
#'  not missing.
#' @param fun.lincomb linear combination of parameters, see examples.
#' @param fun.norm2 squared norm of parameters, applied to the output of
#'  \code{fun.grad}
#' @param max.iter maximum number of gradient updates
#' @param max.line.iter maximum number of line search iterations
#' @param step.size initial step size, used in the first iteration as the
#'  initial value of the backtracking line search. It is adapted at runtime.
#' @param armijo constant for the Armijo condition in the line search
#' @param gamma line search step size reduction factor in the open interval
#'  (0, 1)
#' @param eps small constant used in break conditions
#' @param callback function invoked for each iteration (including iteration 0)
#'  with the signature \code{function(iter, params)} or
#'  \code{function(iter, params, more.params)}.
#'
#' @return Either the final parameter estimates \code{params}, or a list
#'  \code{list(params, more.params)} of parameters and more parameters if
#'  \code{more.params} is not missing.
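#'
#' @details
#' For structured parameters, e.g. a list of matrices, \code{fun.lincomb} and
#' \code{fun.norm2} need to be supplied by the caller. A minimal sketch
#' (illustrative only, the packaged example lives in
#' \code{inst/examples/NAGD.R}) might look like:
#' \preformatted{
#' fun.lincomb <- function(a, params1, b, params2)
#'     Map(function(p1, p2) a * p1 + b * p2, params1, params2)
#' fun.norm2 <- function(params)
#'     sum(vapply(params, function(p) sum(p^2), numeric(1)))
#' }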
#'
#' @example inst/examples/NAGD.R
#'
#' @export
NAGD <- function(fun.loss, fun.grad, params, more.params = NULL,
    fun.more.params = function(params, more.params) more.params,
    fun.lincomb = function(a, params1, b, params2) a * params1 + b * params2,
    fun.norm2 = function(params) sum(params^2),
    max.iter = 50L, max.line.iter = 50L,
    step.size = 1e-2, armijo = 0.1,
    gamma = 2 / (1 + sqrt(5)),
    eps = sqrt(.Machine$double.eps),
    callback = NULL
) {
    # momentum extrapolation weights
    m <- c(0, 1)

    # Compute initial loss
    if (missing(more.params)) {
        loss <- fun.loss(params)
    } else {
        loss <- fun.loss(params, more.params)
    }
    if (!is.finite(loss)) {
        stop("Initial loss is non-finite (", loss, ")")
    }

    # initialize "previous" iterate parameters
    prev.params <- params

    # Gradient Descent Loop
    line.search.tag <- FALSE    # init line search state as "failure"
    for (iter in seq_len(max.iter)) {
        # Call callback for previous iterate
        if (missing(more.params)) {
            if (is.function(callback)) callback(iter - 1L, params)
        } else {
            if (is.function(callback)) callback(iter - 1L, params, more.params)
        }

        # Extrapolation from previous position (momentum), that is
        # `params.moment <- (1 + moment) * params - moment * prev.params`
        moment <- (m[1] - 1) / m[2]
        params.moment <- fun.lincomb(1 + moment, params, -moment, prev.params)

        # Compute gradient at extrapolated position
        if (missing(more.params)) {
            gradients <- fun.grad(params.moment)
        } else {
            more.params <- fun.more.params(params.moment, more.params)
            gradients <- fun.grad(params.moment, more.params)
        }

        # gradient inner product (with itself), a.k.a. squared norm
        grad.inner.prod <- fun.norm2(gradients)
        if (!is.finite(grad.inner.prod)) {
            stop("Encountered non-finite gradient (", grad.inner.prod,
                 ") with loss (", loss, ")")
        }

        # Backtracking line search: start slightly above the current step
        # size and shrink geometrically by `gamma` until the Armijo
        # condition holds
        for (delta in step.size * gamma^seq.int(-1L, length.out = max.line.iter)) {
            # Gradient update with current step size
            params.temp <- fun.lincomb(1, params.moment, -delta, gradients)

            # compute loss at temporary position
            if (missing(more.params)) {
                loss.temp <- fun.loss(params.temp)
            } else {
                more.params.temp <- fun.more.params(params.temp, more.params)
                loss.temp <- fun.loss(params.temp, more.params.temp)
            }
            loss.temp <- if (is.finite(loss.temp)) loss.temp else Inf

            # check Armijo condition at temporary position
            if (loss.temp <= loss - armijo * delta * grad.inner.prod) {
                line.search.tag <- TRUE
                break
            }
        }

        # keep track of previous parameters
        prev.params <- params

        # check line search outcome
        if (is.na(line.search.tag)) {
            # line search failed even without momentum -> give up and
            # return the current estimate
            if (missing(more.params)) {
                return(params)
            } else {
                return(list(params = params, more.params = more.params))
            }
        } else if (line.search.tag == TRUE) {
            # line search success -> check break condition
            if (abs(loss - loss.temp) < eps * loss) {
                break
            }

            # update loss and parameters
            loss <- loss.temp
            params <- params.temp
            if (!missing(more.params)) {
                more.params <- more.params.temp
            }

            # update momentum extrapolation weights
            m <- c(m[2], (1 + sqrt(1 + (2 * m[2])^2)) / 2)

            # and the step size
            step.size <- delta

            # reset line search tag to "failure" for the next iteration
            line.search.tag <- FALSE
        } else {
            # line search failure -> retry without momentum; since
            # `prev.params == params` at this point, the next
            # extrapolation collapses to `params` itself
            line.search.tag <- NA
            next
        }
    }

    # Call callback with final result
    if (missing(more.params)) {
        if (is.function(callback)) callback(iter, params)
    } else {
        if (is.function(callback)) callback(iter, params, more.params)
    }

    # return estimated parameters
    if (missing(more.params)) {
        return(params)
    } else {
        return(list(params = params, more.params = more.params))
    }
}
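
# A minimal usage sketch (illustrative only; the packaged example lives in
# inst/examples/NAGD.R): minimize the strictly convex quadratic
# f(x) = sum((x - 1)^2), whose gradient is 2 * (x - 1), starting from the
# origin. All defaults are kept, so plain numeric vectors work with the
# default `fun.lincomb` and `fun.norm2`.
#
#     fit <- NAGD(
#         fun.loss = function(params) sum((params - 1)^2),
#         fun.grad = function(params) 2 * (params - 1),
#         params = numeric(3)    # start at (0, 0, 0)
#     )
#     # `fit` should be close to the minimizer (1, 1, 1)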