165 lines
		
	
	
		
			6.6 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
			
		
		
	
	
			165 lines
		
	
	
		
			6.6 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
#' Nesterov Accelerated Gradient Descent
#'
#' Minimizes \code{fun.loss} given its gradient \code{fun.grad}, starting from
#' the initial position \code{params}. This generic implementation allows for
#' structured parameters provided that the functions \code{fun.lincomb} and
#' \code{fun.norm2} can handle the parameters appropriately.
#'
#' Each iteration extrapolates from the previous iterate (momentum), takes a
#' gradient step from the extrapolated position and chooses the step size by a
#' backtracking line search. The line search enforces an Armijo condition
#' relative to the loss at the current iterate. The step size accepted in one
#' iteration, enlarged by the factor \code{1 / gamma}, is the first candidate
#' of the next line search.
#'
#' The iteration stops when
#' \itemize{
#'   \item \code{max.iter} iterations are done,
#'   \item an accepted step changes the loss by a relative amount of at most
#'     \code{eps}, i.e. \code{abs(loss - loss.new) <= eps * abs(loss)}, or
#'   \item the line search fails twice in a row, where the second attempt is
#'     made without momentum.
#' }
#'
#' @param fun.loss Scalar loss function (minimization objective), its signature
#'  is \code{function(params)} or \code{function(params, more.params)} if
#'  \code{more.params} is not missing and its return is assumed to be a scalar.
#'  Non-finite values encountered during the line search are treated as
#'  \code{Inf}.
#' @param fun.grad Gradient of \code{fun.loss} with signature
#'  \code{function(params)} or \code{function(params, more.params)} if
#'  \code{more.params} is not missing. Its return is combined with
#'  \code{params} by \code{fun.lincomb} and is therefore assumed to have the
#'  same structure as \code{params}.
#' @param params initial parameters, a.k.a. start position
#' @param more.params further parameters not subject to optimization. They
#'  might change during optimization as a result of calls to
#'  \code{fun.more.params}.
#' @param fun.more.params function of signature
#'  \code{function(params, more.params)} returning the \code{more.params}
#'  belonging to (new) \code{params}. Only used if \code{more.params} is not
#'  missing. It is called for every evaluated position (the extrapolated
#'  position and each line search candidate).
#' @param fun.lincomb linear combination of parameters, called as
#'  \code{fun.lincomb(a, params1, b, params2)} and expected to compute
#'  \code{a * params1 + b * params2}, see examples.
#' @param fun.norm2 squared norm of parameters, applied to \code{fun.grad} output
#' @param max.iter maximum number of gradient updates
#' @param max.line.iter maximum number of line search iterations
#' @param step.size initial step size, used in the first iterate as initial
#'  value in the backtracking line search. Gets adapted during runtime.
#' @param armijo constant for the Armijo condition in the line search
#' @param gamma line search step size reduction factor in the open (0, 1)
#'  interval
#' @param eps small constant used in the relative convergence check
#' @param callback function invoked for each iteration (including iteration 0)
#'  with the signature \code{function(iter, params)} or
#'  \code{function(iter, params, more.params)}. When the optimization stops
#'  because of two consecutive line search failures, no additional final call
#'  is made (the parameters are unchanged since the last call).
#'
#' @return Either the final parameter estimates \code{params} or a list with
#'  parameters and more parameters \code{list(params, more.params)} in case
#'  of non missing \code{more.params}. The returned \code{more.params} are the
#'  ones computed for the returned \code{params}.
#'
#' @example inst/examples/NAGD.R
#'
#' @export
NAGD <- function(fun.loss, fun.grad, params, more.params = NULL,
    fun.more.params = function(params, more.params) more.params,
    fun.lincomb = function(a, params1, b, params2) a * params1 + b * params2,
    fun.norm2 = function(params) sum(params^2),
    max.iter = 50L, max.line.iter = 50L, step.size = 1e-2,
    armijo = 0.1, gamma = 2 / (1 + sqrt(5)),
    eps = sqrt(.Machine$double.eps),
    callback = NULL
) {
    # `missing()` is only reliable before an argument is reassigned, hence
    # decide once which calling convention (with or without `more.params`)
    # is in use.
    has.more <- !missing(more.params)

    # Helpers hiding the two calling conventions of the user functions
    eval.loss <- function(p, mp) if (has.more) fun.loss(p, mp) else fun.loss(p)
    eval.grad <- function(p, mp) if (has.more) fun.grad(p, mp) else fun.grad(p)
    report <- function(i, p, mp) {
        if (is.function(callback)) {
            if (has.more) callback(i, p, mp) else callback(i, p)
        }
    }
    pack.result <- function(p, mp) {
        if (has.more) list(params = p, more.params = mp) else p
    }

    # Compute initial loss
    loss <- eval.loss(params, more.params)
    if (!is.finite(loss)) {
        stop("Initial loss is non-finite (", loss, ")")
    }

    # "previous" iterate, equal to the start position for the first step
    prev.params <- params
    # momentum extrapolation weights (previous and current sequence value)
    m <- c(0, 1)
    # TRUE if the line search of the last iteration failed
    line.search.failed <- FALSE
    # defined even if the loop below does not run (`max.iter` of zero)
    iter <- 0L

    # Gradient Descent Loop
    for (iter in seq_len(max.iter)) {
        # Report the current iterate (result of the previous iteration)
        report(iter - 1L, params, more.params)

        # Extrapolation from previous position (momentum)
        # `params.moment <- (1 + moment) * params - moment * prev.params`
        moment <- (m[1] - 1) / m[2]
        params.moment <- fun.lincomb(1 + moment, params, -moment, prev.params)
        more.params.moment <- if (has.more) {
            fun.more.params(params.moment, more.params)
        }

        # Gradient at the extrapolated position and its squared norm
        gradients <- eval.grad(params.moment, more.params.moment)
        grad.inner.prod <- fun.norm2(gradients)
        if (!is.finite(grad.inner.prod)) {
            stop("Encountered non-finite gradient (", grad.inner.prod,
                 ") with loss (", loss, ")")
        }

        # Backtracking line search; the first candidate enlarges the last
        # accepted step size by the factor `1 / gamma`
        line.search.success <- FALSE
        deltas <- step.size * gamma^seq.int(-1L, length.out = max.line.iter)
        for (delta in deltas) {
            # Gradient update with current step size
            params.temp <- fun.lincomb(1, params.moment, -delta, gradients)
            more.params.temp <- if (has.more) {
                fun.more.params(params.temp, more.params.moment)
            }
            loss.temp <- eval.loss(params.temp, more.params.temp)
            # treat non-finite (including NA/NaN) losses as infinitely bad
            if (!is.finite(loss.temp)) {
                loss.temp <- Inf
            }
            # Armijo condition (relative to the loss at the current iterate)
            if (loss.temp <= loss - armijo * delta * grad.inner.prod) {
                line.search.success <- TRUE
                break
            }
        }

        # Keep track of previous parameters. After a failed line search this
        # makes the extrapolation of the next iteration a no-op (no momentum).
        prev.params <- params

        if (!line.search.success) {
            if (line.search.failed) {
                # failed again even without momentum -> line search hopeless
                return(pack.result(params, more.params))
            }
            # retry without momentum
            line.search.failed <- TRUE
            next
        }
        line.search.failed <- FALSE

        # Relative convergence check. `abs(loss)` keeps the tolerance positive
        # for negative losses and `<=` allows convergence at an exact zero loss.
        converged <- abs(loss - loss.temp) <= eps * abs(loss)

        # Accept the step (also on convergence, it satisfies the Armijo
        # condition); `more.params` always stay paired with `params`.
        loss <- loss.temp
        params <- params.temp
        if (has.more) {
            more.params <- more.params.temp
        }

        if (converged) {
            break
        }

        # update momentum extrapolation weights and the step size
        m <- c(m[2], (1 + sqrt(1 + (2 * m[2])^2)) / 2)
        step.size <- delta
    }

    # Call callback with final result
    report(iter, params, more.params)

    # return estimated parameters
    pack.result(params, more.params)
}
 |