#' Conditional Variance Estimator (CVE)
#'
#' Conditional Variance Estimator for Sufficient Dimension Reduction
#'
#' TODO: And some details
#'
#'
#' @references Fertl L., Bura E. Conditional Variance Estimation for
#'   Sufficient Dimension Reduction, 2019
#'
#' @docType package
#' @author Loki
#' @useDynLib CVE, .registration = TRUE
"_PACKAGE"

#' Implementation of the CVE method.
#'
#' Conditional Variance Estimator (CVE) is a novel sufficient dimension
#' reduction (SDR) method assuming a model
#' \deqn{Y \sim g(B'X) + \epsilon}{Y ~ g(B'X) + epsilon}
#' where B'X is a lower dimensional projection of the predictors.
#'
#' @param formula Formula for the regression model defining `X`, `Y`.
#'   See: \code{\link{formula}}.
#' @param data data.frame holding data for formula. If missing, the
#'   variables are looked up in the environment of `formula`.
#' @param method Name of the optimization method (case-insensitive). The
#'   methods only differ in the used optimization; all of them are gradient
#'   based optimizations on a Stiefel manifold. Currently only
#'   \code{"simple"} (simple reduction of the step size) is implemented;
#'   any other value raises an error.
#' @param ... Further parameters passed on to \code{cve.call}.
#' @return A list of class \code{"cve"}. For every fitted SDR dimension
#'   \code{k} the entry \code{[[k]]} is a list of class \code{"cve.k"}
#'   (entries below the smallest fitted dimension are \code{NULL}). It holds
#'   the result of the native optimizer, extended by \code{B}
#'   (\code{null(V)}), \code{loss} (mean of \code{L}), the used bandwidth
#'   \code{h} and \code{k}. The entries \code{method} and \code{call} are
#'   set as well.
#' @examples
#' library(CVE)
#'
#' # sample dataset
#' ds <- dataset("M5")
#'
#' # call 'cve' with default method (aka "simple")
#' dr.simple <- cve(ds$Y ~ ds$X, k = ncol(ds$B))
#' # boxplot of the losses 'L' per fitted SDR dimension
#' plot(dr.simple, main = "CVE M5 simple")
#'
#' # call 'cve' using a 'data.frame' as data.
#' data <- data.frame(Y = ds$Y, X = ds$X)
#' # Note: 'Y', 'X' are NOT defined, they are extracted from 'data'.
#' dr.df <- cve(Y ~ ., data, k = ncol(ds$B))
#' plot(dr.df, main = "CVE M5 (data.frame)")
#'
#' @references Fertl L., Bura E. Conditional Variance Estimation for
#'   Sufficient Dimension Reduction, 2019
#'
#' @seealso \code{\link{formula}}. For a complete parameter list see the
#'   arguments of \code{cve.call} documented on this page.
#' @import stats
#' @importFrom stats model.frame
#' @export
cve <- function(formula, data, method = "simple", max.dim = 10L, ...) {
  # check for type of `data` if supplied and set default
  if (missing(data)) {
    data <- environment(formula)
  } else if (!is.data.frame(data)) {
    stop("Parameter 'data' must be a 'data.frame' or missing.")
  }

  # extract `X`, `Y` from `formula` with `data`
  model <- stats::model.frame(formula, data)
  X <- as.matrix(model[, -1L, drop = FALSE])
  Y <- as.double(model[, 1L])

  # pass extracted data on to [cve.call()]
  dr <- cve.call(X, Y, method = method, max.dim = max.dim, ...)

  # overwrite `call` property from [cve.call()]
  dr$call <- match.call()

  dr
}

#' @param X Numeric predictor matrix with at least two columns.
#' @param Y Numeric response vector with one element per row of `X`.
#' @param nObs As described in the paper; passed to the bandwidth
#'   estimation (defaults to \code{sqrt(nrow(X))}).
#' @param h Bandwidth. Either a single positive number used for every
#'   dimension, \code{NULL} to estimate it separately for every dimension
#'   \code{k} via \code{estimate.bandwidth(X, k, nObs)}, or a function with
#'   that signature, which is then used instead of
#'   \code{estimate.bandwidth}.
#' @param min.dim,max.dim Range of SDR dimensions to fit. \code{max.dim} is
#'   capped at \code{ncol(X) - 1}. Both are ignored if \code{k} is given.
#' @param k Guess for the SDR dimension. If given, only this dimension is
#'   fitted.
#' @param tau Initial step-width, a positive number.
#' @param tol Break condition tolerance, a non-negative number.
#' @param epochs Positive integer passed to the optimizer (presumably the
#'   maximum number of optimization epochs).
#' @param attempts Positive integer passed to the optimizer (presumably the
#'   number of optimization attempts / restarts).
#' @param logger Optional function, passed together with its environment
#'   to the native optimizer, or \code{NULL}.
#' @rdname cve
#' @export
cve.call <- function(X, Y, method = "simple",
                     nObs = sqrt(nrow(X)), h = NULL,
                     min.dim = 1L, max.dim = 10L, k = NULL,
                     tau = 1.0, tol = 1e-3,
                     epochs = 50L, attempts = 10L,
                     logger = NULL) {
  # parameter checking
  if (!(is.matrix(X) && is.numeric(X))) {
    stop("Parameter 'X' should be a numeric matrices.")
  }
  if (!is.numeric(Y)) {
    stop("Parameter 'Y' must be numeric.")
  }
  if (is.matrix(Y) || !is.double(Y)) {
    Y <- as.double(Y)
  }
  if (nrow(X) != length(Y)) {
    stop("Rows of 'X' and 'Y' elements are not compatible.")
  }
  if (ncol(X) < 2) {
    stop("'X' is one dimensional, no need for dimension reduction.")
  }

  # Range of SDR dimensions to fit (a given `k` fixes a single dimension).
  if (is.null(k)) {
    min.dim <- as.integer(min.dim)
    max.dim <- as.integer(min(max.dim, ncol(X) - 1L))
  } else {
    min.dim <- as.integer(k)
    max.dim <- as.integer(k)
  }
  # Guard scalar-ness/NA first: `dr[[0]]` and NA comparisons would fail
  # with obscure errors further down.
  if (length(min.dim) != 1L || is.na(min.dim) || min.dim < 1L) {
    stop("'min.dim' (or 'k') must be a positive integer.")
  }
  if (length(max.dim) != 1L || is.na(max.dim)) {
    stop("'max.dim' must be an integer.")
  }
  if (min.dim > max.dim) {
    stop("'min.dim' bigger 'max.dim'.")
  }
  if (max.dim >= ncol(X)) {
    stop("'max.dim' (or 'k') must be smaller than 'ncol(X)'.")
  }

  # A function passed as `h` replaces the default bandwidth estimator.
  if (is.function(h)) {
    estimate.bandwidth <- h
    h <- NULL
  }

  # Scalar optimizer parameters; `length(.) != 1L` and `is.na()` make sure
  # the comparisons below are well defined.
  if (!is.numeric(tau) || length(tau) != 1L || is.na(tau) || tau <= 0.0) {
    stop("Initial step-width 'tau' must be positive number.")
  }
  tau <- as.double(tau)

  if (!is.numeric(tol) || length(tol) != 1L || is.na(tol) || tol < 0.0) {
    stop("Break condition tolerance 'tol' must be not negative number.")
  }
  tol <- as.double(tol)

  if (!is.numeric(epochs) || length(epochs) != 1L || is.na(epochs)) {
    stop("Parameter 'epochs' must be positive integer.")
  }
  epochs <- as.integer(epochs)
  if (is.na(epochs) || epochs < 1L) {
    stop("Parameter 'epochs' must be at least 1L.")
  }

  if (!is.numeric(attempts) || length(attempts) != 1L || is.na(attempts)) {
    stop("Parameter 'attempts' must be positive integer.")
  }
  attempts <- as.integer(attempts)
  if (is.na(attempts) || attempts < 1L) {
    stop("Parameter 'attempts' must be at least 1L.")
  }

  loggerEnv <- if (is.function(logger)) environment(logger) else NULL

  # A fixed bandwidth is validated once instead of on every iteration.
  if (!is.null(h)) {
    if (!(is.numeric(h) && length(h) == 1L && !is.na(h) && h > 0.0)) {
      stop("Bandwidth 'h' must be positive numeric.")
    }
    h <- as.double(h)
  }

  # Only the "simple" method is implemented (native routine `cve_simple`).
  method <- tolower(method)
  if (!isTRUE(method == "simple")) {
    stop("Got unknown method.")
  }

  call <- match.call()
  dr <- list()
  for (k in min.dim:max.dim) {
    # The estimated bandwidth depends on `k`, so it is computed for every
    # dimension separately (the supplied `h` itself is never overwritten).
    h.k <- if (is.null(h)) estimate.bandwidth(X, k, nObs) else h

    dr.k <- .Call("cve_simple", PACKAGE = "CVE",
                  X, Y, k, h.k, tau, tol, epochs, attempts,
                  logger, loggerEnv)
    dr.k$B <- null(dr.k$V)
    dr.k$loss <- mean(dr.k$L)
    dr.k$h <- h.k
    dr.k$k <- k
    class(dr.k) <- "cve.k"
    dr[[k]] <- dr.k
  }

  # augment result information
  dr$method <- method
  dr$call <- call
  class(dr) <- "cve"
  dr
}

#' Plotting helper for objects of class \code{cve}.
#'
#' Draws a boxplot of the losses \code{L} of every fitted SDR dimension
#' \code{k} contained in \code{x}.
#'
#' @param x Object of class \code{cve} (result of [cve()]).
#' @param ... Graphical parameters passed through to
#'   \code{\link[graphics]{boxplot}}; they override the defaults for
#'   \code{main}, \code{xlab}, \code{ylab} and \code{names}.
#' @return \code{x}, invisibly.
#' @seealso see \code{\link{par}} for graphical parameters to pass through
#'   as well as \code{\link{plot}} for standard plot utility.
#' @importFrom graphics plot lines points boxplot
#' @method plot cve
#' @export
plot.cve <- function(x, ...) {
  # `unclass()` so that the filtered result is a plain list of fits.
  fits <- unname(Filter(function(elem) inherits(elem, "cve.k"), unclass(x)))
  if (length(fits) == 0L) {
    stop("Object 'x' contains no fitted dimensions to plot.")
  }

  dims <- vapply(fits, function(fit) paste0(fit$k), character(1L))
  # one column of losses per fitted dimension
  L <- do.call(cbind, lapply(fits, function(fit) fit$L))

  # User supplied arguments take precedence over the defaults.
  defaults <- list(main = "Loss ...",
                   xlab = "SDR dimension k",
                   ylab = expression(L(V, X[i])),
                   names = dims)
  dots <- list(...)
  args <- c(list(L), defaults[setdiff(names(defaults), names(dots))], dots)
  do.call(graphics::boxplot, args)

  invisible(x)
}

#' Prints a summary of a \code{cve} result.
#'
#' Prints the used method, loss and bandwidth of every fitted SDR dimension,
#' and the originating call.
#'
#' @param object Instance of 'cve' as return of \code{cve}.
#' @param ... Ignored.
#' @return \code{object}, invisibly.
#' @method summary cve
#' @export
summary.cve <- function(object, ...) {
  fits <- Filter(function(elem) inherits(elem, "cve.k"), unclass(object))

  cat('Summary of CVE result - Method: "', object$method, '"\n', '\n',
      sep = '')
  for (fit in fits) {
    cat('SDR Dimension: ', fit$k, '\n',
        'loss: ', fit$loss, '\n',
        'h: ', fit$h, '\n', '\n',
        sep = '')
  }
  cat('Called via:\n', ' ', sep = '')
  print(object$call)

  invisible(object)
}