2
0
Fork 0
CVE/CVE/R/CVE.R

160 lines
4.8 KiB
R
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#' Implementation of the CVE method.
#'
#' Conditional Variance Estimator (CVE) is a novel sufficient dimension
#' reduction (SDR) method assuming a model
#' \deqn{Y \sim g(B'X) + \epsilon}{Y ~ g(B'X) + epsilon}
#' where B'X is a lower dimensional projection of the predictors.
#'
#' @param formula Formula for the regression model defining `X`, `Y`.
#' See: \code{\link{formula}}.
#' @param data data.frame holding data for formula.
#' @param method The methods differ only in the optimization they use.
#' All of them are gradient based optimizations on a Stiefel manifold.
#' \itemize{
#' \item "simple" Simple reduction of stepsize.
#' \item "linesearch" determines stepsize with backtracking linesearch
#' using Armijo-Wolfe conditions.
#' \item TODO: further
#' }
#' @param ... Further parameters depending on the used method.
#' TODO: See ...
#' @examples
#' library(CVE)
#'
#' # sample dataset
#' ds <- dataset("M5")
#'
#' # call `cve` with default method (aka "simple")
#' dr.simple <- cve(ds$Y ~ ds$X, k = ncol(ds$B))
#' # plot optimization history (loss via iteration)
#' plot(dr.simple, main = "CVE M5 simple")
#'
#' # call `cve` with method "linesearch" using `data.frame` as data.
#' data <- data.frame(Y = ds$Y, X = ds$X)
#' # Note: `Y, X` are NOT defined, they are extracted from `data`.
#' dr.linesearch <- cve(Y ~ ., data, method = "linesearch", k = ncol(ds$B))
#' plot(dr.linesearch, main = "CVE M5 linesearch")
#'
#' @references Fertl L., Bura E. Conditional Variance Estimation for Sufficient Dimension Reduction, 2019
#'
#' @import stats
#' @importFrom stats model.frame
#' @export
cve <- function(formula, data, method = "simple", ...) {
  # Formula front end: extracts the response and the predictors described by
  # `formula` (looked up in `data`, or in the formula's environment when
  # `data` is not supplied) and hands them to `cve.call()`.

  # Reject anything that is supplied but is not a data.frame.
  if (!missing(data) && !is.data.frame(data)) {
    stop("Parameter `data` must be a `data.frame` or missing.")
  }
  # Without `data`, variables are resolved where the formula was created.
  if (missing(data)) {
    data <- environment(formula)
  }

  # The first column of the model frame is the response, the rest are the
  # predictors; `drop = FALSE` keeps single columns two dimensional.
  frame <- stats::model.frame(formula, data)
  predictors <- as.matrix(frame[, -1, drop = FALSE])
  response <- as.matrix(frame[, 1, drop = FALSE])

  fit <- cve.call(predictors, response, method = method, ...)

  # Record the call to `cve()` instead of the internal `cve.call()` call.
  fit$call <- match.call()
  fit
}
#' @param nObs as described in the paper.
#' @rdname cve
#' @export
cve.call <- function(X, Y, method = "simple", nObs = nrow(X)^.5, k, ...) {
  # Matrix front end: validates `X`, `Y` and `k`, runs the compiled CVE
  # routine and returns its result tagged with class "cve".

  # TODO: replace default value of `k` by `max.dim` when fast enough
  if (missing(k)) {
    stop("TODO: parameter `k` (rank(B)) is required, replace by `max.dim`.")
  }

  # Input validation; the checks run in this order and the first failing
  # one raises the error.
  if (!is.matrix(X) || !is.matrix(Y)) {
    stop("X and Y should be matrices.")
  }
  if (nrow(X) != nrow(Y)) {
    stop("Rows of X and Y are not compatible.")
  }
  if (ncol(X) < 2) {
    stop("X is one dimensional, no need for dimension reduction.")
  }
  if (ncol(Y) > 1) {
    stop("Only one dimensional responces Y are supported.")
  }

  # Delegate to the C++ implementation; the method name is passed on in
  # lower case.
  result <- cve_cpp(X, Y, tolower(method), k = k, nObs = nObs, ...)

  # Attach the method name as given by the caller and the matched call, then
  # mark the result as an object of class "cve".
  result[["method"]] <- method
  result[["call"]] <- match.call()
  structure(result, class = "cve")
}
# TODO: write summary
# summary.cve <- function() {
# # code #
# }
#' Plotting helper for objects of class \code{cve}.
#'
#' @param x Object of class \code{cve} (result of [cve()]).
#' @param content Specifies what to plot:
#' \itemize{
#' \item "history" Plots the loss history from stiefel optimization
#' (default).
#' \item ... TODO: add (if there are any)
#' }
#' @param ... Pass through parameters to [plot()] and [lines()]
#'
#' @usage ## S3 method for class 'cve'
#' plot(x, content = "history", ...)
#' @seealso see \code{\link{par}} for graphical parameters to pass through
#' as well as \code{\link{plot}} for standard plot utility.
#' @importFrom graphics plot lines points
#' @method plot cve
#' @export
plot.cve <- function(x, content = "history", ...) {
  # Plot the loss history stored in `x$history`.
  #
  # `x$history` is indexed as a matrix: every column is drawn as one curve,
  # using only its strictly positive entries, and one point is added per
  # curve at (number of positive entries, smallest positive entry).
  #
  # Args:
  #   x:       object carrying a `history` matrix.
  #   content: what to plot; "history" is the only supported value.
  #   ...:     further arguments handed to `plot()`; they take precedence
  #            over the defaults set below.
  #
  # Returns `NULL` invisibly; called for its side effect of drawing.
  content <- match.arg(content, c("history"))

  H <- x$history
  positive <- H > 0
  H_1 <- H[positive[, 1], 1]

  defaults <- list(
    main = "History",
    xlab = "Iterations i",
    ylab = expression(loss == L[n](V^{(i)})),
    xlim = c(1, nrow(H)),
    ylim = c(0, max(H)),
    type = "l"
  )

  # Force the pass-through arguments here, in the caller's context, instead
  # of re-evaluating the matched call inside this function, where variables
  # local to the caller are not visible.
  user <- list(...)
  unset <- setdiff(names(defaults), names(user))
  args <- c(user, defaults[unset])
  # `seq_along()` stays empty for an empty curve, unlike `1:length()`.
  args$x <- seq_along(H_1)
  args$y <- H_1
  do.call(plot, args)

  # Remaining columns are added as additional curves.
  for (i in seq_len(ncol(H))[-1L]) {
    H_i <- H[positive[, i], i]
    lines(seq_along(H_i), H_i)
  }

  # Mark one point per curve; columns without any positive entry are
  # skipped since they have no minimum.
  n_pos <- colSums(positive)
  drawn <- which(n_pos > 0)
  y_ends <- vapply(
    drawn,
    function(j) min(H[positive[, j], j]),
    numeric(1)
  )
  points(n_pos[drawn], y_ends)

  invisible(NULL)
}