2
0
Fork 0
CVE/simulations/hitters_logging.R

109 lines
3.5 KiB
R

library(ISLR) # for Hitters dataset
library(MAVE)
library(CVE)
# Global parameters of this run.
name <- "Hitters"   # dataset label used in plot titles and the metadata page
seed <- 21
max.dim <- 5L
attempts <- 25L     # number of restarts handed to `cve()`
max.iter <- 100L    # iteration cap handed to `cve()` and size of the histories
momentum <- 0       # reported on the metadata page
# Fix the RNG state so repeated runs are reproducible.
set.seed(seed)
# Prepare data for analysis.
# Column order: the response `Salary` first, then the 16 predictors.
cols <- c("Salary",
          "AtBat", "Hits", "HmRun", "Runs", "RBI", "Walks", "Years",
          "CAtBat", "CHits", "CHmRun", "CRuns", "CRBI", "CWalks",
          "PutOuts", "Assists", "Errors")
# Rows dropped as outliers. The numbers are row positions in the data *after*
# incomplete rows have been removed by `na.omit()`, not positions in `Hitters`.
outliers <- c(
  92,  # Gary Pettis
  120, # John Moses
  173, # Milt Thompson
  189, # Rick Burleson
  220, # Scott Fletcher
  230, # Tom Foley
  241  # Terry Puhl
)
# Keep the selected columns, drop incomplete rows, then drop the outliers.
ds <- na.omit(Hitters[, cols])
ds <- ds[-outliers, ]
# Salary is modelled on the log scale; afterwards every column (response
# included) is centred and scaled to unit variance, which yields a matrix.
ds[["Salary"]] <- log(ds[["Salary"]])
ds <- scale(ds)
# Split into predictor matrix `X` and response `Y` (an n x 1 matrix).
X <- as.matrix(ds[, setdiff(colnames(ds), "Salary")])
Y <- as.matrix(ds[, "Salary", drop = TRUE])
# Output file for all plots of this run.
path <- file.path(getwd(), 'results', 'hitters_logger.pdf')
# `pdf()` does not create missing directories and would fail with
# "cannot open file"; make sure `results/` exists first (no-op if it does).
dir.create(dirname(path), showWarnings = FALSE, recursive = TRUE)
pdf(path, width = 8.27, height = 11.7) # width, height unit is inches -> A4
# Page layout: one full-width panel on top, a 2 x 2 grid of panels below.
layout(matrix(c(1, 1,
                2, 3,
                4, 5), nrow = 3, byrow = TRUE))
# Set up histories: one (max.iter + 1) x attempts matrix per logged quantity.
# Row `iter + 1` holds the value logged at iteration `iter`, column `attempt`
# the restart it belongs to; cells that are never logged stay NA. The chained
# assignment is safe: copy-on-modify gives each name its own matrix on the
# first write.
loss.history <- error.history <- tau.history <- grad.norm.history <-
  matrix(NA, nrow = max.iter + 1, ncol = attempts)
# Logger callback for `cve()`: records one iteration of one attempt into the
# global history matrices.
#
# attempt: column index (restart number).
# iter:    iteration counter; shifted by one so iteration 0 lands in row 1.
# data:    list providing `loss`, `err`, `tau` and the gradient matrix `G`.
#
# Returns (invisibly) the Frobenius norm of `data$G`.
logger <- function(attempt, iter, data) {
  row <- iter + 1
  # `<<-` is required: plain `<-` would only modify local copies.
  loss.history[row, attempt] <<- data$loss
  error.history[row, attempt] <<- data$err
  tau.history[row, attempt] <<- data$tau
  grad.norm.history[row, attempt] <<- norm(data$G, 'F')
}
# Fit CVE with a fixed reduced dimension `k`, logging every iteration of every
# attempt through `logger`. `momentum` is forwarded explicitly so the value
# reported on the metadata page is the one the optimizer actually used.
k <- 2L
dr <- cve(Y ~ X, k = k, max.iter = max.iter, attempts = attempts,
          momentum = momentum, logger = logger)
# Estimated basis and final loss for dimension `k`; results in `dr$res` are
# keyed by the dimension as a string.
B <- coef(dr, k)
loss <- dr$res[[as.character(k)]]$loss
# Draw lines of text as a text-only "plot" filling the current panel.
#
# All arguments are flattened (via `unlist`) into one vector, one drawn line
# per element. At most 20 lines fit the panel; longer input keeps the first
# 17 and the last 2 lines around a "skipped" marker line.
#
# Returns (invisibly) the lines that were actually drawn.
textplot <- function(...) {
  # Named `txt`, not `text`/`lines`, to avoid shadowing graphics functions.
  txt <- unlist(list(...))
  if (length(txt) > 20) {
    txt <- c(txt[1:17],
             ' ...... (skipped, text too long) ......',
             txt[length(txt) - c(1, 0)])
  }
  # Empty canvas with unit coordinates and no axes, box or labels.
  plot(NA, xlim = c(0, 1), ylim = c(0, 1),
       bty = 'n', xaxt = 'n', yaxt = 'n', xlab = '', ylab = '')
  for (i in seq_along(txt)) {
    # Left-aligned (pos = 4), top to bottom in steps of 1/20.
    text(0, 1 - (i / 20), txt[[i]], pos = 4)
  }
  invisible(txt)
}
# Write run metadata into the top panel, one string per line.
textplot(c(
  # Reproducibility.
  paste0("Seed value: ", seed),
  "",
  # Data and result dimensions.
  paste0("Dataset Name: ", name),
  paste0("dim(X) = (", nrow(X), ", ", ncol(X), ")"),
  paste0("dim(B) = (", nrow(B), ", ", ncol(B), ")"),
  "",
  # Optimizer configuration.
  paste0("CVE method: ", dr$method),
  paste0("Max Iterations: ", max.iter),
  paste0("Attempts: ", attempts),
  paste0("Momentum: ", momentum),
  "CVE call:",
  paste0(" > ", format(dr$call)),
  "",
  # Outcome; no ground-truth basis is available for real data.
  "True Error: NA",
  paste0("loss: ", round(loss, 3))
))
# Plot histories: one panel per logged quantity, one line per attempt,
# iterations on the x-axis and a log-scaled y-axis. Wrapped in `local()` so
# the panel specs do not leak into the global environment.
local({
  panels <- list(
    list(history = loss.history, title = 'loss',
         ylab = expression(L(V[i]))),
    list(history = grad.norm.history, title = 'gradient norm',
         ylab = expression(group('|', paste(nabla, L(V)), '|')[F])),
    list(history = error.history, title = 'error',
         ylab = expression(group('|', V[i-1]*V[i-1]^T - V[i]*V[i]^T, '|')[F])),
    list(history = tau.history, title = 'learning rate',
         ylab = expression(tau[i]))
  )
  for (panel in panels) {
    matplot(panel$history, type = 'l', log = 'y', xlab = 'i (iteration)',
            main = paste(panel$title, name),
            ylab = panel$ylab)
  }
})