266 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
			
		
		
	
	
			266 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
#' Conditional Variance Estimator (CVE)
 | 
						||
#'
 | 
						||
#' Conditional Variance Estimator for Sufficient Dimension
 | 
						||
#'  Reduction
 | 
						||
#'
 | 
						||
#' TODO: And some details
 | 
						||
#'
 | 
						||
#'
 | 
						||
#' @references Fertl Lukas, Bura Efstathia. Conditional Variance Estimation for Sufficient Dimension Reduction, 2019
 | 
						||
#'
 | 
						||
#' @docType package
 | 
						||
#' @author Loki
 | 
						||
#' @useDynLib CVE, .registration = TRUE
 | 
						||
"_PACKAGE"
 | 
						||
 | 
						||
#' Implementation of the CVE method.
#'
#' Conditional Variance Estimator (CVE) is a novel sufficient dimension
#' reduction (SDR) method assuming a model
#' \deqn{Y \sim g(B'X) + \epsilon}{Y ~ g(B'X) + epsilon}
#' where B'X is a lower dimensional projection of the predictors.
#'
#' @param formula Formula for the regression model defining `X`, `Y`.
#'  See: \code{\link{formula}}.
#' @param data data.frame holding the data for `formula`. If missing, the
#'  variables are looked up in the environment of `formula`.
#' @param method The methods only differ in the used optimization.
#'  All of them are gradient based optimizations on a Stiefel manifold.
#' \itemize{
#'      \item "simple" Simple reduction of stepsize.
#'      \item TODO: further methods (e.g. "sgd", stochastic gradient
#'          descent) are currently disabled.
#' }
#' @param max.dim Upper bound for the SDR dimensions to fit; passed on to
#'  `cve.call()`, which does not use it when `k` is supplied.
#' @param ... Further parameters passed on to `cve.call()`.
#'
#' @return Object of class \code{cve} as created by `cve.call()`, with its
#'  `call` entry replaced by the call to `cve()`.
#'
#' @examples
#' library(CVE)
#'
#' # sample dataset
#' ds <- dataset("M5")
#'
#' # call `cve` with default method (aka "simple")
#' dr.simple <- cve(ds$Y ~ ds$X, k = ncol(ds$B))
#' # plot the loss of the fitted SDR dimension
#' plot(dr.simple, main = "CVE M5 simple")
#'
#' # call `cve` using a `data.frame` as data.
#' data <- data.frame(Y = ds$Y, X = ds$X)
#' # Note: `Y, X` are NOT defined, they are extracted from `data`.
#' dr.data <- cve(Y ~ ., data, k = ncol(ds$B))
#' plot(dr.data, main = "CVE M5 simple (data.frame)")
#'
#' @references Fertl L., Bura E. Conditional Variance Estimation for Sufficient Dimension Reduction, 2019
#'
#' @seealso \code{\link{formula}}. The complete parameter list is the one of
#'  `cve.call()`.
#' @import stats
#' @importFrom stats model.frame
#' @export
cve <- function(formula, data, method = "simple", max.dim = 10L, ...) {
    # check for type of `data` if supplied and set default; without `data`
    # the variables are taken from where the formula was created.
    if (missing(data)) {
        data <- environment(formula)
    } else if (!is.data.frame(data)) {
        stop("Parameter 'data' must be a 'data.frame' or missing.")
    }

    # extract `X`, `Y` from `formula` with `data`: the first column of the
    # model frame is used as response, all remaining columns as predictors.
    model <- stats::model.frame(formula, data)
    X <- as.matrix(model[, -1L, drop = FALSE])
    Y <- as.double(model[, 1L])

    # pass extracted data on to [cve.call()]
    dr <- cve.call(X, Y, method = method, max.dim = max.dim, ...)

    # overwrite `call` property from [cve.call()] such that the result
    # reports the call of `cve()` itself.
    dr$call <- match.call()
    return(dr)
}
 | 
						||
 | 
						||
#' @param X Numeric data matrix (one row per observation, at least two
#'  columns).
#' @param Y Numeric responses, one per row of `X`.
#' @param nObs Passed on to the bandwidth estimation, like in the paper.
#' @param h Bandwidth. Either `NULL` (the bandwidth is estimated separately
#'  for every fitted dimension), a single positive number (used for all
#'  dimensions) or a function called as `h(X, k, nObs)` to estimate it.
#' @param min.dim,max.dim Range of SDR dimensions to fit when `k` is not
#'  given. `max.dim` is capped at `ncol(X) - 1`.
#' @param k guess for SDR dimension. If supplied, only this dimension is
#'  fitted and `min.dim`, `max.dim` are not used.
#' @param tau Initial step-width, a positive number.
#' @param tol Break condition tolerance, a non-negative number.
#' @param epochs,attempts Integers (at least 1) passed on to the optimization
#'  routine; presumably the number of epochs per attempt and the number of
#'  optimization attempts (TODO: confirm against the C implementation).
#' @param logger Optional function passed on to the optimization routine
#'  together with its environment.
#' @rdname cve
#' @export
cve.call <- function(X, Y, method = "simple",
                     nObs = sqrt(nrow(X)), h = NULL,
                     min.dim = 1L, max.dim = 10L, k = NULL,
                     tau = 1.0, tol = 1e-3,
                     epochs = 50L, attempts = 10L,
                     logger = NULL) {

    # parameter checking
    if (!(is.matrix(X) && is.numeric(X))) {
        stop("Parameter 'X' should be a numeric matrices.")
    }
    if (!is.numeric(Y)) {
        stop("Parameter 'Y' must be numeric.")
    }
    if (is.matrix(Y) || !is.double(Y)) {
        Y <- as.double(Y)
    }
    if (nrow(X) != length(Y)) {
        stop("Rows of 'X' and 'Y' elements are not compatible.")
    }
    if (ncol(X) < 2) {
        stop("'X' is one dimensional, no need for dimension reduction.")
    }

    # range of SDR dimensions to fit; an explicit `k` pins both bounds.
    if (is.null(k)) {
        min.dim <- as.integer(min.dim)
        max.dim <- as.integer(min(max.dim, ncol(X) - 1L))
    } else {
        min.dim <- as.integer(k)
        max.dim <- as.integer(k)
    }
    if (length(min.dim) != 1L || length(max.dim) != 1L ||
        is.na(min.dim) || is.na(max.dim)) {
        stop("'min.dim', 'max.dim' and 'k' must be single integers.")
    }
    if (min.dim < 1L) {
        stop("'min.dim' (or 'k') must be at least 1.")
    }
    if (min.dim > max.dim) {
        stop("'min.dim' bigger 'max.dim'.")
    }
    if (max.dim >= ncol(X)) {
        stop("'max.dim' (or 'k') must be smaller than 'ncol(X)'.")
    }

    # a function valued `h` replaces the bandwidth estimator, a numeric `h`
    # is a fixed bandwidth and is validated once up front.
    if (is.function(h)) {
        estimate.bandwidth <- h
        h <- NULL
    } else if (!is.null(h)) {
        if (!is.numeric(h) || length(h) != 1L || is.na(h) || h <= 0.0) {
            stop("Bandwidth 'h' must be positive numeric.")
        }
        h <- as.double(h)
    }

    if (!is.numeric(tau) || length(tau) != 1L || is.na(tau) || tau <= 0.0) {
        stop("Initial step-width 'tau' must be positive number.")
    } else {
        tau <- as.double(tau)
    }
    if (!is.numeric(tol) || length(tol) != 1L || is.na(tol) || tol < 0.0) {
        stop("Break condition tolerance 'tol' must be not negative number.")
    } else {
        tol <- as.double(tol)
    }

    if (!is.numeric(epochs) || length(epochs) != 1L) {
        stop("Parameter 'epochs' must be positive integer.")
    } else if (!is.integer(epochs)) {
        epochs <- as.integer(epochs)
    }
    if (is.na(epochs) || epochs < 1L) {
        stop("Parameter 'epochs' must be at least 1L.")
    }
    if (!is.numeric(attempts) || length(attempts) != 1L) {
        stop("Parameter 'attempts' must be positive integer.")
    } else if (!is.integer(attempts)) {
        attempts <- as.integer(attempts)
    }
    if (is.na(attempts) || attempts < 1L) {
        stop("Parameter 'attempts' must be at least 1L.")
    }

    if (is.function(logger)) {
        loggerEnv <- environment(logger)
    } else {
        loggerEnv <- NULL
    }

    # Only "simple" is wired up at the moment; the former R implementations
    # 'linesearch', 'rcg', 'momentum', 'rmsprob', 'sgdrmsprob' and 'sgd' are
    # disabled. Reject anything else before doing any work.
    method <- tolower(method)
    if (!identical(method, 'simple')) {
        stop('Got unknown method.')
    }

    # Call specified method.
    call <- match.call()
    dr <- list()
    for (k in min.dim:max.dim) {

        # The estimated bandwidth depends on `k`, so it is computed for every
        # dimension separately (and never written back to `h`, which would
        # freeze the estimate of the first dimension for all later ones).
        if (is.null(h)) {
            h.k <- as.double(estimate.bandwidth(X, k, nObs))
        } else {
            h.k <- h
        }

        dr.k <- .Call('cve_simple', PACKAGE = 'CVE',
                      X, Y, k, h.k,
                      tau, tol,
                      epochs, attempts,
                      logger, loggerEnv)

        dr.k$B <- null(dr.k$V)
        dr.k$loss <- mean(dr.k$L)
        dr.k$h <- h.k
        dr.k$k <- k
        class(dr.k) <- "cve.k"
        # results are stored at index `k`
        dr[[k]] <- dr.k
    }

    # augment result information
    dr$method <- method
    dr$call <- call
    class(dr) <- "cve"
    return(dr)
}
 | 
						||
 | 
						||
#' Plotting helper for objects of class \code{cve}.
#'
#' Draws one boxplot of the loss values \code{L} per fitted SDR dimension.
#'
#' @param x Object of class \code{cve} (result of [cve()]).
#' @param ... Pass through parameters to [boxplot()]. They take precedence
#'  over the defaults used for \code{main}, \code{xlab}, \code{ylab} and
#'  \code{names}. A \code{content} argument is accepted for compatibility
#'  but currently has no effect (TODO: add further plot contents, if there
#'  are any).
#'
#' @return The value of the [boxplot()] call (invisibly).
#'
#' @seealso see \code{\link{par}} for graphical parameters to pass through
#'      as well as \code{\link{boxplot}} for the underlying plot utility.
#' @importFrom graphics plot lines points boxplot
#' @method plot cve
#' @export
plot.cve <- function(x, ...) {
    # Besides the per dimension results the object holds further entries
    # (e.g. `method`, `call`); only entries of class 'cve.k' are plotted.
    entries <- unclass(x)
    is.result <- vapply(entries, inherits, logical(1L), what = "cve.k")
    results <- entries[is.result]
    if (length(results) == 0L) {
        stop("No results of class 'cve.k' found to plot.")
    }

    # one column of loss values per fitted dimension
    k <- vapply(results, function(dr.k) paste0(dr.k$k), character(1L),
                USE.NAMES = FALSE)
    L <- unlist(lapply(results, function(dr.k) dr.k$L), use.names = FALSE)
    L <- matrix(L, ncol = length(k))

    # user supplied arguments override the defaults instead of clashing
    # with them as duplicated arguments.
    args <- list(...)
    args[["content"]] <- NULL
    defaults <- list(main = "Loss ...",
                     xlab = "SDR dimension k",
                     ylab = expression(L(V, X[i])),
                     names = k)
    args <- c(args, defaults[setdiff(names(defaults), names(args))])
    do.call(graphics::boxplot, c(list(L), args))
}
 | 
						||
 | 
						||
#' Prints a summary of a \code{cve} result.
#'
#' Prints the method, one section (SDR dimension, loss and bandwidth) per
#' fitted dimension and the call that created the result.
#'
#' @param object Instance of 'cve' as return of \code{cve}.
#' @param ... Ignored.
#' @return The stored call, as returned by \code{print}.
#' @method summary cve
#' @export
summary.cve <- function(object, ...) {
    cat('Summary of CVE result - Method: "', object[["method"]], '"\n',
        '\n',
        sep = '')

    # The data is only reported if the object carries it; results created
    # by `cve.call()` do not store `X`.
    X <- object[["X"]]
    if (is.matrix(X)) {
        cat('Dataset size:   ', nrow(X), '\n',
            'Data Dimension: ', ncol(X), '\n',
            '\n',
            sep = '')
    }

    # `k`, `loss` and `h` live on the per dimension entries of class
    # 'cve.k', not on the 'cve' object itself.
    entries <- unclass(object)
    is.result <- vapply(entries, inherits, logical(1L), what = "cve.k")
    for (dr.k in entries[is.result]) {
        cat('SDR Dimension:  ', dr.k[["k"]], '\n',
            'loss:           ', dr.k[["loss"]], '\n',
            'bandwidth:      ', dr.k[["h"]], '\n',
            '\n',
            sep = '')
    }

    cat('Called via:\n',
        '    ',
        sep = '')
    print(object[["call"]])
}
 |