2
0
Fork 0
CVE/CVE_R/R/CVE.R

223 lines
6.6 KiB
R
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#' Conditional Variance Estimator (CVE)
#'
#' Conditional Variance Estimator for Sufficient Dimension
#' Reduction
#'
#' The main entry points are \code{cve} (formula interface) and
#' \code{cve.call} (matrix interface).
#'
#' TODO: And some details
#'
#'
#' @references Fertl Lukas, Bura Efstathia. Conditional Variance Estimation for Sufficient Dimension Reduction, 2019
#'
#' @docType package
#' @author Loki
"_PACKAGE"
#' Implementation of the CVE method.
#'
#' Conditional Variance Estimator (CVE) is a novel sufficient dimension
#' reduction (SDR) method assuming a model
#' \deqn{Y \sim g(B'X) + \epsilon}{Y ~ g(B'X) + epsilon}
#' where B'X is a lower dimensional projection of the predictors.
#'
#' @param formula Formula for the regression model defining `X`, `Y`.
#' See: \code{\link{formula}}.
#' @param data data.frame holding data for formula. If missing, the variables
#' of `formula` are looked up in the environment of `formula`.
#' @param method The methods differ only in the optimization they use.
#' All of them are gradient based optimizations on a Stiefel manifold.
#' \itemize{
#'   \item "simple" Simple reduction of stepsize (default).
#'   \item "linesearch" Dispatches to \code{\link{cve_linesearch}}.
#'   \item "sgd" Stochastic gradient descent.
#'   \item TODO: further
#' }
#' @param max.dim Upper bound for the SDR dimensions estimated when `k` is
#' not given; forwarded to \code{cve.call}.
#' @param ... Further parameters depending on the used method (for example
#' `k`), forwarded to \code{cve.call}.
#' @return An object of class \code{cve} as returned by \code{cve.call},
#' whose `call` element is replaced by the call to \code{cve}.
#' @examples
#' library(CVE)
#'
#' # sample dataset
#' ds <- dataset("M5")
#'
#' # call `cve` with default method (aka "simple")
#' dr.simple <- cve(ds$Y ~ ds$X, k = ncol(ds$B))
#' # plot optimization history (loss via iteration)
#' plot(dr.simple, main = "CVE M5 simple")
#'
#' # call `cve` with method "linesearch" using `data.frame` as data.
#' data <- data.frame(Y = ds$Y, X = ds$X)
#' # Note: `Y, X` are NOT defined, they are extracted from `data`.
#' dr.linesearch <- cve(Y ~ ., data, method = "linesearch", k = ncol(ds$B))
#' plot(dr.linesearch, main = "CVE M5 linesearch")
#'
#' @references Fertl L., Bura E. Conditional Variance Estimation for Sufficient Dimension Reduction, 2019
#'
#' @seealso \code{\link{formula}}. For a complete parameters list (dependent on
#' the method) see \code{\link{cve_simple}}, \code{\link{cve_sgd}}
#' @import stats
#' @importFrom stats model.frame
#' @export
cve <- function(formula, data, method = "simple", max.dim = 10, ...) {
  # Check the type of `data` if supplied; otherwise resolve the formula's
  # variables in the environment the formula was created in.
  if (missing(data)) {
    data <- environment(formula)
  } else if (!is.data.frame(data)) {
    stop('Parameter `data` must be a `data.frame` or missing.')
  }
  # Extract `X`, `Y` from `formula` with `data`; model.frame puts the
  # response in the first column.
  model <- stats::model.frame(formula, data)
  X <- as.matrix(model[, -1, drop = FALSE])
  Y <- as.matrix(model[, 1, drop = FALSE])
  # `max.dim` is a formal of `cve`, so it is NOT part of `...`; forward it
  # explicitly or a user-supplied value would be silently dropped.
  dr <- cve.call(X, Y, method = method, max.dim = max.dim, ...)
  # Report the user's call to `cve` instead of the internal `cve.call` one.
  dr$call <- match.call()
  dr
}
#' @param X Data matrix, one observation per row.
#' @param Y Responses; a vector (or single-column matrix) with one entry per
#' row of `X`.
#' @param nObs As described in the paper; defaults to \code{sqrt(nrow(X))}.
#' @param min.dim Lower bound for the SDR dimensions estimated when `k` is
#' not given.
#' @param k Guess for the SDR dimension. If given, only this dimension is
#' estimated and `min.dim`/`max.dim` are ignored.
#' @param ... Method specific parameters.
#' @return A list of class \code{cve}: entry \code{[[k]]} holds the result
#' (class \code{cve.k}) for SDR dimension \code{k}; additionally the elements
#' `method` and `call`.
#' @rdname cve
#' @export
cve.call <- function(X, Y, method = "simple", nObs = nrow(X)^.5,
                     min.dim = 1, max.dim = 10, k, ...) {
  # parameter checking
  if (!is.matrix(X)) {
    stop('X should be a matrices.')
  }
  if (is.matrix(Y)) {
    Y <- as.vector(Y)
  }
  if (nrow(X) != length(Y)) {
    stop('Rows of X and number of Y elements are not compatible.')
  }
  if (ncol(X) < 2) {
    stop('X is one dimensional, no need for dimension reduction.')
  }
  if (!missing(k)) {
    min.dim <- as.integer(k)
    max.dim <- as.integer(k)
  } else {
    min.dim <- as.integer(min.dim)
    max.dim <- as.integer(min(max.dim, ncol(X) - 1L))
  }
  # Validate before comparing: a non-scalar or NA bound would break the `if`
  # conditions below, and a dimension below 1 cannot index the result list
  # (`dr[[0]]` errors).
  if (length(min.dim) != 1L || length(max.dim) != 1L ||
      is.na(min.dim) || is.na(max.dim) || min.dim < 1L) {
    stop('`min.dim` (or `k`) must be a positive integer.')
  }
  if (min.dim > max.dim) {
    stop('`min.dim` bigger `max.dim`.')
  }
  if (max.dim >= ncol(X)) {
    stop('`max.dim` must be smaller than `ncol(X)`.')
  }
  # Resolve the method up front so an invalid one fails before any work.
  method <- tolower(method)
  if (length(method) != 1L ||
      !(method %in% c('simple', 'linesearch', 'sgd'))) {
    stop('Got unknown method.')
  }
  call <- match.call()
  dr <- list()
  # Results are stored at index `k`, so consumers can use `dr[[k]]`.
  for (k in seq.int(min.dim, max.dim)) {
    dr.k <- switch(method,
      simple     = cve_simple(X, Y, k, nObs = nObs, ...),
      linesearch = cve_linesearch(X, Y, k, nObs = nObs, ...),
      sgd        = cve_sgd(X, Y, k, nObs = nObs, ...)
    )
    dr.k$k <- k
    class(dr.k) <- "cve.k"
    dr[[k]] <- dr.k
  }
  # augment result information
  dr$method <- method
  dr$call <- call
  class(dr) <- "cve"
  dr
}
# NOTE: superseded by the `summary.cve()` method defined further below; this
# commented-out stub is unused.
# summary.cve <- function() {
# # code #
# }
#' Plotting helper for objects of class \code{cve}.
#'
#' NOTE(review): the plotting code in the body is currently commented out, so
#' this method draws nothing and returns \code{NULL}. The disabled code reads
#' \code{x$history}, whereas \code{cve.call} stores its per-dimension results
#' as \code{x[[k]]}; it must be adapted before being re-enabled.
#'
#' @param x Object of class \code{cve} (result of [cve()]).
#' @param ... Pass through parameters intended for [plot()] and [lines()];
#' currently unused.
#'
#' @details A planned (not yet implemented) \code{content} option would
#' select what to plot:
#' \itemize{
#'   \item "history" Plots the loss history from Stiefel optimization
#'   (default).
#'   \item ... TODO: add (if there are any)
#' }
#'
#' @usage ## S3 method for class 'cve'
#' plot(x, ...)
#' @seealso see \code{\link{par}} for graphical parameters to pass through
#' as well as \code{\link{plot}} for standard plot utility.
#' @importFrom graphics plot lines points
#' @method plot cve
#' @export
plot.cve <- function(x, ...) {
# Disabled implementation, kept for reference. It expects `x$history` to be
# a matrix with one loss-history column per run (NA-padded), which does not
# match the per-dimension layout produced by `cve.call`.
# H <- x$history
# H_1 <- H[!is.na(H[, 1]), 1]
# defaults <- list(
# main = "History",
# xlab = "Iterations i",
# ylab = expression(loss == L[n](V^{(i)})),
# xlim = c(1, nrow(H)),
# ylim = c(0, max(H)),
# type = "l"
# )
# call.plot <- match.call()
# keys <- names(defaults)
# keys <- keys[match(keys, names(call.plot)[-1], nomatch = 0) == 0]
# for (key in keys) {
# call.plot[[key]] <- defaults[[key]]
# }
# call.plot[[1L]] <- quote(plot)
# call.plot$x <- quote(1:length(H_1))
# call.plot$y <- quote(H_1)
# eval(call.plot)
# if (ncol(H) > 1) {
# for (i in 2:ncol(H)) {
# H_i <- H[H[, i] > 0, i]
# lines(1:length(H_i), H_i)
# }
# }
# x.ends <- apply(H, 2, function(h) { length(h[!is.na(h)]) })
# y.ends <- apply(H, 2, function(h) { tail(h[!is.na(h)], n=1) })
# points(x.ends, y.ends)
}
#' Prints a summary of a \code{cve} result.
#'
#' Prints the method, the data dimensions and, for every estimated SDR
#' dimension \code{k}, the final loss, followed by the originating call.
#'
#' @param object Instance of 'cve' as return of \code{cve}.
#' @param ... Ignored; present for compatibility with the
#' \code{\link{summary}} generic.
#' @return The printed call, invisibly.
#' @method summary cve
#' @export
summary.cve <- function(object, ...) {
  # `cve.call` stores one result (class "cve.k") per SDR dimension as list
  # entries; only `method` and `call` live at the top level.
  fits <- Filter(function(entry) inherits(entry, "cve.k"), unclass(object))
  if (length(fits) == 0L) {
    # Legacy flat layout: the object itself carries `X`, `k` and `loss`.
    fits <- list(object)
  }
  # NOTE(review): assumes each per-dimension result carries the data as `X`
  # (as the flat layout did); otherwise the size lines print empty.
  X <- fits[[1L]][["X"]]
  cat('Summary of CVE result - Method: "', object$method, '"\n',
      '\n',
      'Dataset size: ', nrow(X), '\n',
      'Data Dimension: ', ncol(X), '\n',
      sep = '')
  for (fit in fits) {
    cat('SDR Dimension: ', fit[["k"]], '\n',
        'loss: ', fit[["loss"]], '\n',
        sep = '')
  }
  cat('\n',
      'Called via:\n',
      ' ',
      sep = '')
  print(object$call)
}