57 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
			
		
		
	
	
			57 lines
		
	
	
		
			1.9 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
#' Iterative Cyclic (Coordinate) Update
#'
#' Minimizes \code{fun.loss} by repeatedly sweeping over the parameter blocks
#' selected by \code{indices}, replacing each block with the value returned by
#' \code{fun.update}. Within one sweep (update cycle) the blocks are visited in
#' the order produced by \code{fun.sample} (a fresh random permutation per
#' cycle by default). Iteration stops after \code{max.iter} cycles or as soon
#' as the absolute change of the loss over one cycle drops below \code{eps}.
#'
#' @param fun.loss Scalar loss function (minimization objective); its signature
#'  is \code{function(params)} and it returns a scalar.
#' @param fun.update computes a new parameter (block) for the \code{index}
#'  parameter (block) in \code{params}; the function signature is
#'  \code{function(params, index)} and it returns a parameter block
#'  corresponding to \code{params[[index]]}.
#' @param params initial parameters, a.k.a. start position
#' @param indices parameter index set used with \code{[[<-} and passed to
#'  \code{fun.update}
#' @param fun.sample computes a permutation of \code{indices} for each update
#'  cycle. If the parameters should not be permuted use \code{identity}. The
#'  default shuffles \code{indices} via \code{sample.int}, which (unlike
#'  \code{sample}) also behaves correctly for a single numeric index.
#' @param max.iter maximum number of parameter update cycles
#' @param eps small constant used in the break condition: iteration stops once
#'  the absolute change of the loss over one full update cycle is below
#'  \code{eps}.
#' @param callback function invoked for each iteration (including iteration 0)
#'  with the signature \code{function(iter, params)}.
#'
#' @return The parameters after the last performed update cycle (the initial
#'  \code{params} if \code{max.iter} is zero).
#'
#' @example inst/examples/ICU.R
#'
#' @export
ICU <- function(fun.loss, fun.update, params,
    indices = base::seq_along(params),
    # `sample(x)` on a length-1 numeric `x >= 1` permutes `1:x` instead of
    # returning `x`; indexing via `sample.int()` avoids that pitfall while
    # drawing the same permutation as `sample()` for longer inputs.
    fun.sample = function(x) x[base::sample.int(length(x))],
    max.iter = 50L,
    eps = .Machine$double.eps,
    callback = NULL
) {
    # Compute initial loss
    loss <- fun.loss(params)

    # Call callback with the initial parameters
    if (is.function(callback)) callback(0L, params)

    # iteration loop of parameter update cycles
    for (iter in seq_len(max.iter)) {
        # Update every selected parameter block once, in sampled order;
        # later updates within the same cycle already see earlier ones.
        for (index in fun.sample(indices)) {
            params[[index]] <- fun.update(params, index)
        }

        # Call callback after each update cycle
        if (is.function(callback)) callback(iter, params)

        # recompute loss for the break condition
        loss.last <- loss
        loss <- fun.loss(params)

        # and check the break condition. An NA/NaN change (e.g. a NaN loss or
        # two consecutive infinite losses) cannot be compared, so fail with an
        # informative error instead of `if()`'s generic one.
        delta <- abs(loss.last - loss)
        if (is.na(delta)) {
            stop("loss change is NA/NaN after update cycle ", iter,
                 call. = FALSE)
        }
        if (delta < eps) {
            break
        }
    }

    params
}