# Rosenbrock function for x in R^2
fun <- function(x, a = 1, b = 100) {
  (a - x[1])^2 + b * (x[2] - x[1]^2)^2
}
# Gradient of the Rosenbrock function
#   d/dx1 = 2 (x1 - a) - 4 b x1 (x2 - x1^2)
#   d/dx2 = 2 b (x2 - x1^2)
grad <- function(x, a = 1, b = 100) {
  2 * c(x[1] - a - 2 * b * x[1] * (x[2] - x[1]^2), b * (x[2] - x[1]^2))
}
# call with initial values (x, y) = (-1, 1)
stopifnot(all.equal(
  NAGD(fun, grad, c(-1, 1), max.iter = 500L),
  c(1, 1) # known minimum
))

# Equivalent to above, but the parameters are in a list
fun <- function(params, a = 1, b = 100) {
  (a - params$x)^2 + b * (params$y - params$x^2)^2
}
grad <- function(params, a = 1, b = 100) list(
  x = 2 * (params$x - a - 2 * b * params$x * (params$y - params$x^2)),
  y = 2 * b * (params$y - params$x^2)
)
# NAGD needs to be told how to combine parameters, i.e. how to compute the
# linear combination `a * LHS + b * RHS` of two parameter lists
lincomb <- function(a, LHS, b, RHS) list(
  x = a * LHS$x + b * RHS$x,
  y = a * LHS$y + b * RHS$y
)
# and how to compute their (squared) norm
norm2 <- function(params) {
  sum(unlist(params)^2)
}
# callback invoked for each update
callback <- function(iter, params) {
  cat(sprintf("%3d - fun(%7.4f, %7.4f) = %6.4f\n",
              iter, params$x, params$y, fun(params)))
}
# call with initial values (x, y) = (-1, 1)
fit <- NAGD(fun, grad, list(x = -1, y = 1),
            fun.lincomb = lincomb, fun.norm2 = norm2,
            callback = callback)

# Weighted Least Squares for Heteroscedastic Data
# Predictors
x <- rnorm(500)
# "True" parameters
beta <- c(intercept = 1, slope = 0.5)
# Model matrix
X <- cbind(1, x)
# response + heteroscedastic noise (noise standard deviation increases with x)
y <- X %*% beta + sqrt(x - min(x) + 0.1) * rnorm(length(x))

# weighted residual sum of squares
loss <- function(beta, w) {
  sum((y - X %*% beta)^2 * w)
}
# weights 1 / max(|residual|, delta); `delta` keeps the weights finite
weights <- function(beta, w, delta = 1e-3) {
  1 / pmax(abs(y - X %*% beta), delta)
}
# gradient of `loss` with respect to `beta` for fixed weights `w`
grad <- function(beta, w) {
  -2 * crossprod(X, (y - X %*% beta) * w)
}
# start from the ordinary least squares estimate
fit <- NAGD(loss, grad, coef(lm(y ~ x)),
            more.params = 1, fun.more.params = weights)

# # plot the data
# plot(x, y)
# abline(beta[1], beta[2], col = "black", lty = 2, lwd = 2)
# beta.hat.lm <- coef(lm(y ~ x))
# abline(beta.hat.lm[1], beta.hat.lm[2], col = "red", lwd = 2)
# beta.hat.wls <- fit$params
# abline(beta.hat.wls[1], beta.hat.wls[2], col = "blue", lwd = 2)