57 lines
1.9 KiB
R
57 lines
1.9 KiB
R
|
#' Iterative Cyclic (Coordinate) Update
#'
#' Repeatedly sweeps over a set of parameter indices and replaces each indexed
#' parameter (block) by the result of \code{fun.update}. After every sweep the
#' loss is re-evaluated; iteration stops as soon as the absolute change of the
#' loss between two consecutive sweeps drops below \code{eps}, or after
#' \code{max.iter} sweeps.
#'
#' @param fun.loss Scalar loss function (minimization objective), its signature
#'  is \code{function(params)} and it returns a scalar.
#' @param fun.update computes the new parameter (parameter block) for the
#'  \code{index} parameter (block) in \code{params}, the function signature is
#'  \code{function(params, index)} and it returns a parameter block
#'  corresponding to \code{params[[index]]}.
#' @param params initial parameters, a.k.a. start position
#' @param indices parameter index set used with \code{[[<-} and passed to
#'  \code{fun.update}
#' @param fun.sample computes a permutation of indices. If the parameters
#'  should not be permuted use \code{identity}. It is only called when
#'  \code{indices} holds more than one element (a single index has only one
#'  ordering, and \code{base::sample} would misinterpret a single number
#'  \code{k} as \code{1:k}).
#' @param max.iter maximum number of parameter update cycles
#' @param eps small constant used in the break condition: iteration stops once
#'  \code{abs(loss.last - loss) < eps}. A non-finite or missing loss difference
#'  never triggers the break.
#' @param callback function invoked for each iteration (including iteration 0,
#'  i.e. before the first update cycle) with the signature
#'  \code{function(iter, params)}. Anything that is not a function (e.g. the
#'  default \code{NULL}) disables the callback.
#'
#' @return The parameters \code{params} after the last performed update cycle.
#'
#' @example inst/examples/ICU.R
#'
#' @export
ICU <- function(fun.loss, fun.update, params,
    indices = base::seq_along(params),
    fun.sample = base::sample,
    max.iter = 50L,
    eps = .Machine$double.eps,
    callback = NULL
) {
    # Compute initial loss (reference value for the first break check)
    loss <- fun.loss(params)

    # Call callback with the initial parameters
    if (is.function(callback)) callback(0L, params)

    # With fewer than two indices there is nothing to permute. Skipping the
    # sampler in that case avoids the `base::sample(k)` convention of drawing
    # from `1:k`, which would silently update the wrong parameters.
    permute <- length(indices) > 1L

    # Iteration loop of parameter update cycles
    for (iter in seq_len(max.iter)) {
        # (Random order) parameter update cycle
        cycle <- if (permute) fun.sample(indices) else indices
        for (index in cycle) {
            params[[index]] <- fun.update(params, index)
        }

        # Call callback after each update cycle
        if (is.function(callback)) callback(iter, params)

        # Recompute loss for the break condition
        loss.last <- loss
        loss <- fun.loss(params)

        # Check break condition. `isTRUE` treats an NA/NaN comparison (e.g.
        # `Inf - Inf`) as "not converged" instead of raising
        # "missing value where TRUE/FALSE needed".
        if (isTRUE(abs(loss.last - loss) < eps)) {
            break
        }
    }

    params
}
|