
71 lines
2.1 KiB
Raw Normal View History

# Rosenbrock function for x in R^2:
#   f(x) = (a - x[1])^2 + b * (x[2] - x[1]^2)^2
# It is non-negative and equals 0 exactly at x = (a, a^2), i.e. (1, 1) for
# the default a = 1.
fun <- function(x, a = 1, b = 100) {
  (a - x[1])^2 + b * (x[2] - x[1]^2)^2
}
# Gradient of the Rosenbrock function:
#   df/dx1 = 2 * (x1 - a) - 4 * b * x1 * (x2 - x1^2)
#   df/dx2 = 2 * b * (x2 - x1^2)
# The factor 2 * b in the first component comes from the chain rule on
# b * (x2 - x1^2)^2 (inner derivative -2 * x1).
grad <- function(x, a = 1, b = 100) {
  r <- x[2] - x[1]^2
  2 * c(x[1] - a - 2 * b * x[1] * r, b * r)
}
# Call NAGD on the vector-valued Rosenbrock problem with initial values
# (x, y) = (-1, 1), next to the known minimizer (1, 1) for comparison.
# NOTE(review): the first line below ends with a comma, so both expressions
# look like arguments of an enclosing call whose opening and closing lines
# are not present here -- confirm against the original file.
NAGD(fun, grad, c(-1, 1), max.iter = 500L),
c(1, 1) # known minimum
# Equivalent to the vector version, but the parameters are passed as a list
# with components `x` and `y`:
#   f(params) = (a - x)^2 + b * (y - x^2)^2
fun <- function(params, a = 1, b = 100) {
  (a - params$x)^2 + b * (params$y - params$x^2)^2
}
# Gradient of the list-parameter Rosenbrock function, returned as a list with
# the same `x`/`y` components as `params`. The x-component carries the factor
# 2 * b from differentiating b * (y - x^2)^2 with respect to x.
grad <- function(params, a = 1, b = 100) {
  r <- params$y - params$x^2
  list(
    x = 2 * (params$x - a - 2 * b * params$x * r),
    y = 2 * b * r
  )
}
# NAGD needs to be told how to combine list-valued parameters:
# lincomb(a, LHS, b, RHS) returns the component-wise linear combination
# a * LHS + b * RHS for parameters with components `x` and `y`.
lincomb <- function(a, LHS, b, RHS) {
  list(
    x = a * LHS$x + b * RHS$x,
    y = a * LHS$y + b * RHS$y
  )
}
# ... and how to compute their squared (Euclidean) norm: the sum of squares
# of all entries of the `x` and `y` components (summing also handles
# vector-valued components).
norm2 <- function(params) {
  sum(params$x^2) + sum(params$y^2)
}
# Callback invoked for each update: prints the iteration number, the current
# parameters and the objective value. Looks up `fun` at call time, so it uses
# whichever `fun` is defined when NAGD runs.
callback <- function(iter, params) {
  cat(sprintf("%3d - fun(%7.4f, %7.4f) = %6.4f\n",
              iter, params$x, params$y, fun(params)))
}
# Minimize the list-parameter problem starting at (x, y) = (-1, 1); the
# `lincomb` and `norm2` helpers tell NAGD how to do arithmetic on the list
# parameters, and `callback` reports progress on every update.
fit <- NAGD(
  fun, grad, list(x = -1, y = 1),
  fun.lincomb = lincomb,
  fun.norm2 = norm2,
  callback = callback
)
# Weighted Least Squares for Heteroskedastic Data
# Predictors
x <- rnorm(500)
# "True" parameters
beta <- c(intercept = 1, slope = 0.5)
# Model matrix: a column of ones (intercept) next to the predictor
X <- cbind(1, x)
# Response + heteroskedastic noise: the noise standard deviation
# sqrt(x - min(x) + 0.1) grows with x and is strictly positive.
y <- X %*% beta + sqrt(x - min(x) + 0.1) * rnorm(length(x))
# Weighted residual sum of squares, sum(w * (y - X %*% beta)^2).
# Reads the response `y` and model matrix `X` from the calling script's
# (global) environment.
loss <- function(beta, w) {
  sum((y - X %*% beta)^2 * w)
}
# Recompute the weights from the current coefficients:
#   w_i = 1 / max(|y_i - (X %*% beta)_i|, delta)
# The previous weights `w` are accepted but not used; `delta` floors the
# absolute residuals so the weights stay finite. Passed to NAGD as
# `fun.more.params` (presumably called to refresh the weights between
# updates -- confirm against NAGD). Reads `y` and `X` from the global env.
weights <- function(beta, w, delta = 1e-3) {
  1 / pmax(abs(y - X %*% beta), delta)
}
# Gradient of the weighted loss with respect to beta:
#   -2 * t(X) %*% (w * (y - X %*% beta))
# returned as a column matrix (via crossprod). Reads `y` and `X` from the
# global environment.
grad <- function(beta, w) {
  -2 * crossprod(X, (y - X %*% beta) * w)
}
# Fit the weighted model starting from the ordinary least squares estimate.
# `more.params = 1` is presumably the initial weight and `weights` the rule
# that recomputes it from the current beta -- confirm against NAGD.
fit <- NAGD(
  loss, grad, coef(lm(y ~ x)),
  more.params = 1,
  fun.more.params = weights
)
# # plot the data
# plot(x, y)
# abline(beta[1], beta[2], col = "black", lty = 2, lwd = 2)
# beta.hat.lm <- coef(lm(y ~ x))
# abline(beta.hat.lm[1], beta.hat.lm[2], col = "red", lwd = 2)
# beta.hat.wls <- fit$params
# abline(beta.hat.wls[1], beta.hat.wls[2], col = "blue", lwd = 2)