104 lines
3.1 KiB
R
104 lines
3.1 KiB
R
#' Numeric differentiation
#'
#' Approximates first derivatives by central finite differences,
#' \eqn{(f(x + h) - f(x - h)) / (2h)}, perturbing one element of the
#' evaluation point at a time.
#'
#' \code{num.deriv()} accepts \code{expr} as a function, the name of a
#' function, or an expression in the differentiation variable, e.g.
#' \code{num.deriv(x^2 + 1, x = 1:3)}. The evaluation point is given as the
#' single \code{...} argument.
#'
#' @param expr A function (literal or already-evaluated object), the name of a
#'   function, or an expression containing the differentiation variable.
#' @param ... Exactly one argument: the point at which to differentiate. Its
#'   name gives the variable; if it is unnamed, the argument itself is used as
#'   the variable when it is a bare symbol (\code{num.deriv(f(y), y)}),
#'   otherwise the variable is \code{x}.
#' @param h Step size of the central difference.
#' @param sym If \code{TRUE}, the evaluation point must be a symmetric matrix
#'   and each off-diagonal pair \eqn{(i, j)}, \eqn{(j, i)} is perturbed
#'   together.
#' @return The partial derivatives, one per element of the evaluation point,
#'   simplified as by \code{sapply()}: a vector when the function returns a
#'   scalar, or a matrix with one column per element when it returns a vector.
#' @example inst/examples/num_deriv.R
#'
#' @export
num.deriv <- function(expr, ..., h = 1e-6, sym = FALSE) {
  sexpr <- substitute(expr)
  if (...length() != 1L) {
    stop("Expected exactly one '...' variable")
  }

  # Name of the differentiation variable. ...names() can report an unnamed
  # argument as NULL, NA or "" depending on the R version, so all three are
  # treated as "no name".
  var <- ...names()[1L]
  if (is.null(var) || is.na(var) || !nzchar(var)) {
    # substitute(...) yields the unevaluated expression of the sole `...` arg.
    arg <- substitute(...)
    var <- if (is.symbol(arg)) as.character(arg) else "x"
  }

  # identical() instead of `==`: `==` on language objects compares deparsed
  # text, which can be multi-line and break the scalar `&&`.
  is_fun_literal <- is.call(sexpr) &&
    identical(sexpr[[1L]], as.symbol("function"))

  if (is.function(sexpr)) {
    # expr was supplied already evaluated, e.g. through do.call().
    func <- sexpr
  } else if (is_fun_literal) {
    # A function literal: evaluate it in the caller's environment.
    func <- expr
  } else {
    if (is.name(sexpr) && !identical(as.character(sexpr), var)) {
      # A bare function name such as `sin` is applied to the variable.
      fbody <- call(as.character(sexpr), as.name(var))
    } else if (is.language(sexpr) && var %in% all.vars(sexpr)) {
      # An expression in the variable (including the variable itself).
      fbody <- sexpr
    } else {
      stop("'expr' must be a function or expression containing '", var, "'")
    }
    # Build function(<var>) <fbody>, enclosed by the caller's environment so
    # other free variables in the expression resolve there.
    fargs <- alist(x = )
    names(fargs) <- var
    func <- as.function(c(fargs, fbody), envir = parent.frame())
  }

  num.deriv.function(func, ..1, h = h, sym = sym)
}
|
|
|
|
#' @param FUN A function of one argument to differentiate.
#' @param X Point at which the derivative of \code{FUN} is evaluated; it must
#'   be a symmetric matrix when \code{sym = TRUE}.
#' @rdname num.deriv
#' @export
num.deriv.function <- function(FUN, X, h = 1e-6, sym = FALSE) {
  # Central difference of FUN at X along the perturbation vector `step`.
  central_diff <- function(step) (FUN(X + step) - FUN(X - step)) / (2 * h)
  positions <- seq_along(X)

  if (sym) {
    stopifnot(isSymmetric(X))
    p <- nrow(X)
    partials <- lapply(positions, function(pos) {
      # `pos` is a column-major index; `mirror` is the index of the transposed
      # entry, so each off-diagonal pair is perturbed together (on the
      # diagonal the two coincide and only one entry moves).
      row0 <- (pos - 1L) %% p
      col0 <- (pos - 1L) %/% p
      mirror <- col0 + row0 * p + 1L
      central_diff(h * (positions %in% c(pos, mirror)))
    })
  } else {
    partials <- lapply(positions, function(pos) {
      central_diff(h * (positions == pos))
    })
  }

  # Same simplification as sapply(): a vector for scalar results, a matrix
  # with one column per element of X for vector results; empty input stays an
  # empty list.
  if (length(partials)) simplify2array(partials, higher = FALSE) else partials
}
|
|
|
|
#' @param Y Optional second argument of \code{FUN}. When supplied,
#'   \code{FUN} is called as \code{FUN(x, y)} and \code{num.deriv2()} returns
#'   the mixed second derivative with respect to \code{X} and \code{Y}; when
#'   missing, it returns the second derivative of \code{FUN} with respect to
#'   \code{X}.
#' @param symX,symY Passed as \code{sym} to \code{num.deriv.function()} for
#'   the differences taken in \code{X} and in \code{Y}, respectively.
#' @rdname num.deriv
#' @export
num.deriv2 <- function(FUN, X, Y, h = 1e-6, symX = FALSE, symY = FALSE) {
  # NOTE(review): both nested central differences use the same step `h`, so
  # round-off in FUN is amplified roughly by 1/h^2; a larger `h` than the
  # first-derivative default may be needed for accurate second derivatives.
  if (!missing(Y)) {
    # Gradient in X of FUN(., y), then differentiated once more in Y.
    grad_x_at <- function(y) {
      num.deriv.function(function(x) FUN(x, y), X, h = h, sym = symX)
    }
    return(num.deriv.function(grad_x_at, Y, h = h, sym = symY))
  }

  # Gradient of FUN in X, then differentiated once more in X.
  grad_x <- function(x) num.deriv.function(FUN, x, h = h, sym = symX)
  num.deriv.function(grad_x, X, h = h, sym = symX)
}
|
|
|
|
|
|
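### Interface idea / demo (inactive, commented-out sketch): a curried
### num.deriv2(expr, var) that returns a function of the evaluation point.
### Its signature conflicts with the exported num.deriv2(FUN, X, Y, ...) above.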
|
|
# num.deriv2.function
|
|
# num.deriv2 <- function(expr, var) {
|
|
# sexpr <- substitute(expr)
|
|
# svar <- substitute(var)
|
|
#
|
|
# if (is.language(sexpr) && !is.symbol(sexpr) && sexpr[[1]] == as.symbol("function")) {
|
|
# func <- expr
|
|
# } else {
|
|
# if (is.name(sexpr)) {
|
|
# expr <- call(as.character(sexpr), as.name(svar))
|
|
# } else {
|
|
# if ((!is.call(sexpr) && !is.expression(sexpr))
|
|
# || !(as.character(svar) %in% all.vars(sexpr))) {
|
|
# stop("'expr' must be a function or expression containing '",
|
|
# as.character(svar), "'")
|
|
# }
|
|
# expr <- sexpr
|
|
# }
|
|
#
|
|
# args <- as.pairlist(structure(list(alist(x = )[[1]]), names = as.character(svar)))
|
|
# func <- as.function(c(args, expr), envir = parent.frame())
|
|
# }
|
|
#
|
|
# num.deriv2.function(func, var)
|
|
# }
|
|
# y <- c(pi, exp(1), (sqrt(5) + 1) / 2)
|
|
# num.deriv2(function(x) x^2 + y, x)(1:3)
|
|
# num.deriv2(x^2 + y, x)(1:3)
|
|
# func <- function(x) x^2 + y
|
|
# num.deriv2(func, x)(1:3)
|
|
# func2 <- function(z = y, x) x^2 + z
|
|
# num.deriv2(func2, x)(1:3)
|