109 lines
3.5 KiB
R
109 lines
3.5 KiB
R
library(ISLR) # for Hitters dataset
|
|
library(MAVE)
|
|
library(CVE)
|
|
|
|
# Global parameters ----
name     <- "Hitters"  # dataset label used in plot titles and the metadata panel
seed     <- 21         # RNG seed; also printed in the metadata panel
attempts <- 25L        # value passed to cve() as `attempts` (restarts)
max.iter <- 100L       # value passed to cve() as `max.iter`
momentum <- 0.0        # reported in the metadata panel
# NOTE(review): `max.dim` is not referenced by the rest of this script
# (the fit below uses a fixed reduction dimension) - confirm it is needed.
max.dim  <- 5L

# Fix the RNG state so random initializations are reproducible.
set.seed(seed)
# Data preparation ----
# Response first, followed by the 16 numeric predictors in the order they
# appear in X.
cols <- c("Salary",
          "AtBat", "Hits", "HmRun", "Runs", "RBI", "Walks", "Years",
          "CAtBat", "CHits", "CHmRun", "CRuns", "CRBI", "CWalks",
          "PutOuts", "Assists", "Errors")

# Row positions (counted after incomplete rows are dropped) of the players
# excluded as outliers.
# NOTE(review): removal is positional, so it silently depends on the row
# order of `Hitters`; the names document which players are intended.
outliers <- c(
  92,   # Gary Pettis
  120,  # John Moses
  173,  # Milt Thompson
  189,  # Rick Burleson
  220,  # Scott Fletcher
  230,  # Tom Foley
  241   # Terry Puhl
)

# Keep the selected columns, drop rows with any missing value, then drop the
# outlier rows (still a data.frame at this stage).
ds <- na.omit(Hitters[, cols])
ds <- ds[-outliers, ]

# Log-transform the response.
ds[["Salary"]] <- log(ds[["Salary"]])

# Center and scale every column, response included; scale() returns a matrix.
ds <- scale(ds, center = TRUE, scale = TRUE)

# Split into the predictor matrix X and the response Y (an n x 1 matrix).
X <- as.matrix(ds[, setdiff(colnames(ds), "Salary")])
Y <- as.matrix(ds[, "Salary"])
# Output device ----
# All panels below are drawn into a single PDF report under ./results.
path <- file.path(getwd(), 'results', 'hitters_logger.pdf')

# pdf() does not create missing directories and fails with "cannot open file"
# if results/ is absent, so create it first (no-op when it already exists).
dir.create(dirname(path), showWarnings = FALSE, recursive = TRUE)

# pdf() width/height are in inches; 8.27 x 11.7 in is A4 portrait.
pdf(path, width = 8.27, height = 11.7)

# Five panels: panel 1 spans the whole top row, panels 2-5 form a 2 x 2 grid.
layout(matrix(c(1, 1,
                2, 3,
                4, 5), nrow = 3, byrow = TRUE))
# Optimization histories ----
# Four (max.iter + 1) x attempts matrices, initially all NA. Column `attempt`
# holds one optimization run and row `iter + 1` one iteration (presumably iter
# runs from 0 to max.iter - confirm against CVE). Cells never logged stay NA.
# The chained assignment is safe: copy-on-modify separates the four objects
# on their first write.
grad.norm.history <- tau.history <- error.history <- loss.history <-
  matrix(NA, nrow = max.iter + 1, ncol = attempts)

# Logger callback handed to `cve()`; records one row of diagnostics per call.
#   attempt - column index to write (which restart)
#   iter    - iteration counter; written to row iter + 1
#   data    - list providing `loss`, `err`, `tau` and the matrix `G` (field
#             meanings assumed from their names - confirm in the CVE docs)
logger <- function(attempt, iter, data) {
  i <- iter + 1
  # `<<-` is essential: a plain `<-` would only modify a local copy instead of
  # the global history matrices.
  loss.history[i, attempt] <<- data$loss
  error.history[i, attempt] <<- data$err
  tau.history[i, attempt] <<- data$tau
  # Store the Frobenius norm of G rather than the matrix itself.
  grad.norm.history[i, attempt] <<- norm(data$G, 'F')
}
# Fit CVE ----
# Reduction dimension, used both for the fit and for extracting its results,
# so the two can no longer drift apart.
k <- 2L

# `logger` is invoked by cve() during optimization and fills the *.history
# matrices. `momentum` is passed explicitly so the value shown in the metadata
# panel is the one actually used by the optimizer.
dr <- cve(Y ~ X, k = k, max.iter = max.iter, attempts = attempts,
          momentum = momentum, logger = logger)

# Estimated basis for dimension k (its dimensions are reported in the
# metadata panel) and the corresponding final loss.
B <- coef(dr, k)
loss <- dr$res[[as.character(k)]]$loss
# Render text as left-aligned lines on an otherwise empty plot panel.
#
# All arguments are flattened with unlist() into one vector of lines. At most
# 20 lines are drawn: longer input is shortened to its first 17 lines, a
# "skipped" marker and its last 2 lines. Line i is placed at height 1 - i/20.
# Returns NULL invisibly; called for its drawing side effect.
textplot <- function(...) {
  txt <- unlist(list(...))
  n <- length(txt)
  if (n > 20) {
    txt <- c(txt[seq_len(17)],
             ' ...... (skipped, text too long) ......',
             txt[n - c(1, 0)])
  }

  # Blank canvas on [0, 1] x [0, 1] without box, axes or labels.
  plot(NA, xlim = c(0, 1), ylim = c(0, 1),
       bty = 'n', xaxt = 'n', yaxt = 'n', xlab = '', ylab = '')

  # text() is vectorized; the guard keeps zero-line input a harmless no-op.
  if (length(txt) > 0) {
    text(0, 1 - seq_along(txt) / 20, txt, pos = 4)
  }
}
# Metadata panel ----
# First layout panel: plain-text summary of the run settings and the result.
# local() keeps the intermediate vectors out of the global environment.
local({
  settings <- c(paste0("Seed value: ", seed),
                "")
  data.info <- c(paste0("Dataset Name: ", name),
                 paste0("dim(X) = (", nrow(X), ", ", ncol(X), ")"),
                 paste0("dim(B) = (", nrow(B), ", ", ncol(B), ")"),
                 "")
  fit.info <- c(paste0("CVE method: ", dr$method),
                paste0("Max Iterations: ", max.iter),
                paste0("Attempts: ", attempts),
                paste0("Momentum: ", momentum),
                "CVE call:",
                paste0(" > ", format(dr$call)),
                "")
  # NOTE(review): the true error is hard-coded to NA, presumably because no
  # ground-truth basis exists for this real dataset.
  result.info <- c(paste0("True Error: ", NA),
                   paste0("loss: ", round(loss, 3)))
  textplot(settings, data.info, fit.info, result.info)
})
# Plot histories ----
# Panels 2-5: one line per attempt (matrix column), plotted against the row
# index (iteration + 1) with a log-scaled y-axis. Cells the logger never
# filled are NA and are simply not drawn.
matplot(loss.history, type = 'l', log = 'y', xlab = 'i (iteration)',
        main = paste('loss', name),
        ylab = expression(L(V[i])))
matplot(grad.norm.history, type = 'l', log = 'y', xlab = 'i (iteration)',
        main = paste('gradient norm', name),
        ylab = expression(group('|', paste(nabla, L(V)), '|')[F]))
matplot(error.history, type = 'l', log = 'y', xlab = 'i (iteration)',
        main = paste('error', name),
        ylab = expression(group('|', V[i-1]*V[i-1]^T - V[i]*V[i]^T, '|')[F]))
matplot(tau.history, type = 'l', log = 'y', xlab = 'i (iteration)',
        main = paste('learning rate', name),
        ylab = expression(tau[i]))

# Close the PDF device opened for this report so the file is flushed and
# finalized. Without this, the device stays open (and the file incomplete)
# when the script is sourced in an interactive session.
invisible(dev.off())
|