tensor_predictors/tensorPredictors/R/ICU.R

57 lines
1.9 KiB
R
Raw Normal View History

#' Iterative Cyclic (Coordinate) Update
#'
#' Repeatedly cycles over the parameter blocks of \code{params}, replacing each
#' block \code{params[[index]]} by \code{fun.update(params, index)}, until the
#' loss changes by less than \code{eps} over one full cycle or \code{max.iter}
#' cycles have been performed.
#'
#' @param fun.loss Scalar loss function (minimization objective), its signature
#' is \code{function(params)} and returns a scalar.
#' @param fun.update computes a new parameter (parameter block) for the
#' \code{index} parameter (block) in \code{params}, the function signature is
#' \code{function(params, index)} and returns a parameter block corresponding
#' to \code{params[[index]]}.
#' @param params initial parameters, a.k.a. start position
#' @param indices parameter index set used with \code{[[<-} and passed to
#' \code{fun.update}
#' @param fun.sample computes a permutation of \code{indices}, called once at
#' the start of every update cycle. The default returns a random permutation
#' of \code{indices}; unlike \code{base::sample} it also does so when
#' \code{indices} is a single number. If the parameters should not be
#' permuted use \code{identity}.
#' @param max.iter maximum number of parameter update cycles
#' @param eps small constant used in the break condition; iteration stops as
#' soon as the absolute change of the loss over one update cycle is below
#' \code{eps}
#' @param callback function invoked for each iteration (including iteration 0)
#' with the signature \code{function(iter, params)}.
#'
#' @return the parameters \code{params} after the last performed update cycle
#'
#' @example inst/examples/ICU.R
#'
#' @export
ICU <- function(fun.loss, fun.update, params,
  indices = base::seq_along(params),
  # `base::sample(x)` returns a permutation of `1:x` when `x` is a single
  # number >= 1; index into `x` directly so a single index stays itself
  # (for length(x) > 1 this is exactly what `base::sample` computes).
  fun.sample = function(x) x[base::sample.int(length(x))],
  max.iter = 50L,
  eps = .Machine$double.eps,
  callback = NULL
) {
  # Compute initial loss
  loss <- fun.loss(params)
  # Call callback with the initial parameters
  if (is.function(callback)) callback(0L, params)
  # iteration loop of parameter update cycles
  for (iter in seq_len(max.iter)) {
    # (Randomly) ordered parameter update cycle
    for (index in fun.sample(indices)) {
      params[[index]] <- fun.update(params, index)
    }
    # Call callback after each update cycle
    if (is.function(callback)) callback(iter, params)
    # recompute loss for break condition
    loss.last <- loss
    loss <- fun.loss(params)
    # and check break condition (NOTE: a NA/NaN loss makes this condition NA,
    # which `if` reports as an error)
    if (abs(loss.last - loss) < eps) {
      break
    }
  }
  params
}