tensor_predictors/tensorPredictors/R/num_deriv.R

104 lines
3.1 KiB
R
Raw Normal View History

#' Numeric differentiation
#'
#' Approximates first derivatives by central finite differences, i.e. each
#' partial derivative is \eqn{(f(x + h e_i) - f(x - h e_i)) / (2h)}.
#'
#' \code{num.deriv()} accepts a function literal, the name of a function, or
#' an expression in the differentiation variable. The variable and the point
#' of evaluation are given by the single \code{...} argument, as in
#' \code{num.deriv(x^2 + y, x = 1:3)}. If that argument is unnamed, its
#' expression is used as the variable name when it is a bare symbol, otherwise
#' \code{"x"} is used. A function given by name is called with the variable as
#' its first (positional) argument. Expressions are turned into a function
#' enclosed in the calling environment.
#'
#' @param expr a function, the name of a function, or an expression (call)
#'   containing the differentiation variable.
#' @param ... exactly one argument: the point at which to differentiate. Its
#'   name (or its symbol, if unnamed) determines the differentiation variable.
#' @param h step size of the central differences.
#' @param sym logical; if \code{TRUE} the point must be a symmetric matrix and
#'   each off-diagonal entry is perturbed together with its mirrored entry, so
#'   the perturbed argument stays symmetric.
#'
#' @return The partial derivatives, one per entry of the evaluation point,
#'   simplified as by \code{sapply()}: a vector for a scalar valued function,
#'   otherwise (for equal-length results) a matrix with one column per entry.
#'
#' @example inst/examples/num_deriv.R
#'
#' @export
num.deriv <- function(expr, ..., h = 1e-6, sym = FALSE) {
  sexpr <- substitute(expr)
  if (...length() != 1L) {
    stop("Expected exactly one '...' variable")
  }

  # Differentiation variable: the name of the `...` argument if given, else
  # its expression when that is a bare symbol, else "x". `...names()` has
  # reported unnamed arguments differently across R versions, so `NULL`, `NA`
  # and `""` are all treated as "no name".
  var <- ...names()[1L]
  if (is.null(var) || is.na(var) || !nzchar(var)) {
    arg <- substitute(list(...))[[2L]]
    var <- if (is.symbol(arg)) as.character(arg) else "x"
  }

  # `identical()` instead of `==`: `==` deparses language objects, which can
  # give a length > 1 result for a multi-line call head (an error in `&&`).
  if (is.call(sexpr) && identical(sexpr[[1L]], quote(`function`))) {
    # Function literal: force it (evaluated in the caller's environment)
    func <- expr
  } else {
    if (is.name(sexpr)) {
      # Name of a function: call it with the variable as first argument
      fbody <- call(as.character(sexpr), as.name(var))
    } else {
      if ((!is.call(sexpr) && !is.expression(sexpr)) ||
          !(var %in% all.vars(sexpr))) {
        stop("'expr' must be a function or expression containing '", var, "'")
      }
      fbody <- sexpr
    }
    # Build `function(<var>) <fbody>` enclosed in the caller's environment
    formal <- alist(x = )
    names(formal) <- var
    func <- as.function(c(formal, fbody), envir = parent.frame())
  }

  num.deriv.function(func, ..1, h = h, sym = sym)
}
#' @rdname num.deriv
#' @param FUN function to differentiate; it is called as \code{FUN(X + dx)}
#'   and \code{FUN(X - dx)} for every perturbation \code{dx}.
#' @param X point (vector, matrix or array) at which \code{FUN} is
#'   differentiated.
#' @export
num.deriv.function <- function(FUN, X, h = 1e-6, sym = FALSE) {
  # Central difference quotient of FUN at X along the perturbation `dx`
  central_diff <- function(dx) (FUN(X + dx) - FUN(X - dx)) / (2 * h)

  # Zero-based column-major positions of the entries of X
  pos <- seq_along(X) - 1

  if (!sym) {
    # Perturb a single entry per partial derivative
    partials <- lapply(pos, function(l) central_diff(h * (pos == l)))
    return(simplify2array(partials, higher = FALSE))
  }

  stopifnot(isSymmetric(X))
  p <- nrow(X)
  # Entry (row0, col0) is perturbed jointly with its mirror (col0, row0), so
  # X + dx stays symmetric; a diagonal entry is its own mirror.
  partials <- lapply(pos, function(l) {
    row0 <- l %% p
    col0 <- l %/% p
    central_diff(h * (pos == row0 * p + col0 | pos == col0 * p + row0))
  })
  simplify2array(partials, higher = FALSE)
}
#' @rdname num.deriv
#' @param Y optional second point. If supplied, \code{FUN} is called as
#'   \code{FUN(x, y)} and the mixed second derivative with respect to \code{X}
#'   and \code{Y} is approximated; if missing, \code{FUN} takes \code{X} only
#'   and the second derivative with respect to \code{X} is approximated.
#' @param symX,symY passed as \code{sym} to the differentiation with respect
#'   to \code{X} and \code{Y}, respectively.
#' @export
num.deriv2 <- function(FUN, X, Y, h = 1e-6, symX = FALSE, symY = FALSE) {
  mixed <- !missing(Y)

  # Inner pass: first derivative with respect to X, as a function of the
  # point the outer pass perturbs (Y for mixed derivatives, X otherwise).
  inner <- if (mixed) {
    function(y) {
      num.deriv.function(function(x) FUN(x, y), X, h = h, sym = symX)
    }
  } else {
    function(x) num.deriv.function(FUN, x, h = h, sym = symX)
  }

  # Outer pass: differentiate the inner derivative once more
  if (mixed) {
    num.deriv.function(inner, Y, h = h, sym = symY)
  } else {
    num.deriv.function(inner, X, h = h, sym = symX)
  }
}
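### Interface idea / demo (commented out, not active): an expression-based
### front end for `num.deriv2`, analogous to `num.deriv`, plus usage examples.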
# num.deriv2.function
# num.deriv2 <- function(expr, var) {
# sexpr <- substitute(expr)
# svar <- substitute(var)
#
# if (is.language(sexpr) && !is.symbol(sexpr) && sexpr[[1]] == as.symbol("function")) {
# func <- expr
# } else {
# if (is.name(sexpr)) {
# expr <- call(as.character(sexpr), as.name(svar))
# } else {
# if ((!is.call(sexpr) && !is.expression(sexpr))
# || !(as.character(svar) %in% all.vars(sexpr))) {
# stop("'expr' must be a function or expression containing '",
# as.character(svar), "'")
# }
# expr <- sexpr
# }
#
# args <- as.pairlist(structure(list(alist(x = )[[1]]), names = as.character(svar)))
# func <- as.function(c(args, expr), envir = parent.frame())
# }
#
# num.deriv2.function(func, var)
# }
# y <- c(pi, exp(1), (sqrt(5) + 1) / 2)
# num.deriv2(function(x) x^2 + y, x)(1:3)
# num.deriv2(x^2 + y, x)(1:3)
# func <- function(x) x^2 + y
# num.deriv2(func, x)(1:3)
# func2 <- function(z = y, x) x^2 + z
# num.deriv2(func2, x)(1:3)