2019-08-30 19:16:52 +00:00
|
|
|
|
#' Implementation of the CVE method.
#'
#' Conditional Variance Estimator (CVE) is a novel sufficient dimension
#' reduction (SDR) method assuming a model
#' \deqn{Y \sim g(B'X) + \epsilon}{Y ~ g(B'X) + epsilon}
#' where B'X is a lower dimensional projection of the predictors.
#'
#' @param formula Formula for the regression model defining `X`, `Y`.
#'  See: \code{\link{formula}}.
#' @param data Optional \code{data.frame} holding the variables used in
#'  \code{formula}. If missing, the variables are looked up in the environment
#'  of \code{formula}.
#' @param method Name of the optimization method (matched case insensitively).
#'  All methods are gradient based optimizations on a Stiefel manifold; they
#'  differ only in the optimization strategy used.
#'  \itemize{
#'    \item "simple" Simple reduction of stepsize (default).
#'    \item "linesearch" Line search variant, implemented by
#'      \code{cve_linesearch}.
#'    \item "sgd" Stochastic gradient descent.
#'  }
#' @param ... Further parameters depending on the used method; forwarded to
#'  \code{cve.call} and from there to the method implementation.
#' @return An object of class \code{cve} (as built by \code{cve.call}) whose
#'  \code{call} entry records the call to \code{cve}.
#' @examples
#' library(CVE)
#'
#' # sample dataset
#' ds <- dataset("M5")
#'
#' # call `cve` with default method (aka "simple")
#' dr.simple <- cve(ds$Y ~ ds$X, k = ncol(ds$B))
#' # plot optimization history (loss per iteration)
#' plot(dr.simple, main = "CVE M5 simple")
#'
#' # call `cve` with method "linesearch" using a `data.frame` as data.
#' data <- data.frame(Y = ds$Y, X = ds$X)
#' # Note: `Y`, `X` are NOT defined, they are extracted from `data`.
#' dr.linesearch <- cve(Y ~ ., data, method = "linesearch", k = ncol(ds$B))
#' plot(dr.linesearch, main = "CVE M5 linesearch")
#'
#' @references Fertl L., Bura E. Conditional Variance Estimation for
#'  Sufficient Dimension Reduction, 2019
#'
#' @seealso \code{\link{formula}}. For a complete parameters list (dependent on
#'  the method) see \code{\link{cve_simple}}, \code{\link{cve_sgd}} and
#'  \code{cve_linesearch}.
#' @import stats
#' @importFrom stats model.frame
#' @export
cve <- function(formula, data, method = "simple", ...) {
  # Reject any explicitly supplied data source that is not a data.frame.
  if (!missing(data) && !is.data.frame(data)) {
    stop('Parameter `data` must be a `data.frame` or missing.')
  }
  # Without `data`, resolve the variables where the formula was created.
  frame_source <- if (missing(data)) environment(formula) else data

  # First model-frame column is the response, all remaining ones predictors.
  mf <- stats::model.frame(formula, frame_source)
  response <- as.matrix(mf[, 1, drop = FALSE])
  predictors <- as.matrix(mf[, -1, drop = FALSE])

  result <- cve.call(predictors, response, method = method, ...)
  # Report the user-facing call instead of the internal `cve.call()` one.
  result$call <- match.call()
  result
}
|
|
|
|
|
|
|
|
|
|
#' @param X Matrix of predictors, one observation per row.
#' @param Y Responses; a vector or a single column matrix with one entry per
#'  row of \code{X}.
#' @param nObs Parameter \code{nObs} as described in the paper; defaults to
#'  \code{sqrt(nrow(X))}.
#' @param k Guess for the SDR dimension (rank of \code{B}). Currently required.
#' @rdname cve
#' @export
cve.call <- function(X, Y, method = "simple", nObs = nrow(X)^.5, k, ...) {
  # TODO: replace default value of `k` by `max.dim` when fast enough
  if (missing(k)) {
    stop("TODO: parameter `k` (rank(B)) is required, replace by `max.dim`.")
  }

  # parameter checking
  if (!is.matrix(X)) {
    stop('X should be a matrices.')
  }
  if (is.matrix(Y)) {
    # Flattening a multi-column matrix could silently satisfy the length
    # check below even though the row counts differ, so reject it up front.
    if (ncol(Y) != 1L) {
      stop('Y must be a vector or a single column matrix.')
    }
    Y <- as.vector(Y)
  }
  if (nrow(X) != length(Y)) {
    stop('Rows of X and number of Y elements are not compatible.')
  }
  if (ncol(X) < 2) {
    stop('X is one dimensional, no need for dimension reduction.')
  }

  # Call specified method (matched case insensitively). A non-scalar or NA
  # method would otherwise crash `switch()` with an unhelpful message.
  method <- tolower(method)
  if (length(method) != 1L || is.na(method)) {
    stop('Got unknown method.')
  }
  dr <- switch(method,
    simple     = cve_simple(X, Y, k, nObs = nObs, ...),
    linesearch = cve_linesearch(X, Y, k, nObs = nObs, ...),
    sgd        = cve_sgd(X, Y, k, nObs = nObs, ...),
    stop('Got unknown method.')
  )

  # augment result information
  dr$method <- method
  dr$call <- match.call()
  class(dr) <- "cve"
  dr
}
|
|
|
|
|
|
|
|
|
|
# NOTE: `summary.cve()` is implemented further below in this file.
|
|
|
|
|
|
|
|
|
|
#' Plotting helper for objects of class \code{cve}.
#'
#' Draws the loss history stored in \code{x$history}. The first column is
#' drawn with \code{plot()}, every further column is added with
#' \code{lines()}, and the last recorded value of each column is marked with
#' a point. Missing (\code{NA}) entries of a column are skipped.
#'
#' @param x Object of class \code{cve} (result of \code{\link{cve}}).
#' @param content Specifies what to plot. Currently only \code{"history"}
#'  (default) is supported: the loss history from the Stiefel optimization.
#' @param ... Graphical parameters passed to \code{\link{plot}} (not to
#'  \code{lines()}). They override the defaults for \code{main}, \code{xlab},
#'  \code{ylab}, \code{xlim}, \code{ylim} and \code{type}.
#' @return \code{x}, invisibly.
#' @seealso see \code{\link{par}} for graphical parameters to pass through
#'  as well as \code{\link{plot}} for standard plot utility.
#' @importFrom graphics plot lines points
#' @method plot cve
#' @export
plot.cve <- function(x, content = "history", ...) {
  content <- match.arg(content, c("history"))

  H <- x$history
  # One loss trace per history column, with NA entries dropped.
  traces <- lapply(seq_len(ncol(H)), function(j) H[!is.na(H[, j]), j])
  first <- traces[[1L]]

  defaults <- list(
    main = "History",
    xlab = "Iterations i",
    ylab = expression(loss == L[n](V^{(i)})),
    xlim = c(1, nrow(H)),
    # `na.rm = TRUE` is required: NA entries would otherwise make ylim NA.
    ylim = c(0, max(H, na.rm = TRUE)),
    type = "l"
  )

  # `list(...)` evaluates user arguments in the caller's frame; they take
  # precedence over the defaults. x/y are always the first trace.
  dots <- list(...)
  plot_args <- c(dots, defaults[setdiff(names(defaults), names(dots))])
  plot_args$x <- seq_along(first)
  plot_args$y <- first
  do.call(plot, plot_args)

  for (tr in traces[-1L]) {
    lines(seq_along(tr), tr)
  }

  # Mark the final recorded loss of every trace.
  x_ends <- vapply(traces, length, integer(1L))
  y_ends <- vapply(traces, function(tr) {
    if (length(tr) > 0L) tr[[length(tr)]] else NA_real_
  }, numeric(1L))
  points(x_ends, y_ends)

  invisible(x)
}
|
|
|
|
|
|
|
|
|
|
#' Prints a summary of a \code{cve} result.
#'
#' Writes the method, the dimensions of the stored data, the SDR dimension,
#' the loss and the originating call to the console.
#'
#' @param object Instance of class \code{cve}, as returned by \code{cve}.
#' @param ... Ignored; present for compatibility with the \code{summary}
#'  generic.
#' @return The \code{call} stored in \code{object}, invisibly (the value
#'  returned by \code{print}).
#' @method summary cve
#' @export
summary.cve <- function(object, ...) {
  # Local helper: write all pieces back-to-back without any separator.
  emit <- function(...) cat(..., sep = '')

  emit('Summary of CVE result - Method: "', object$method, '"\n')
  emit('\n')
  emit('Dataset size: ', nrow(object$X), '\n')
  emit('Data Dimension: ', ncol(object$X), '\n')
  emit('SDR Dimension: ', object$k, '\n')
  emit('loss: ', object$loss, '\n')
  emit('\n', 'Called via:\n', ' ')
  print(object$call)
}
|