109 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
			
		
		
	
	
			109 lines
		
	
	
		
			3.5 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
library(ISLR) # for Hitters dataset
library(MAVE)
library(CVE)

# Global parameters -----------------------------------------------------------
# `max.iter` and `attempts` size the history matrices and are handed to
# `cve()`; all of these values are printed on the metadata page of the PDF.
# NOTE(review): `max.dim` is not referenced anywhere else in this script.
seed     <- 21
max.dim  <- 5L
attempts <- 25L
max.iter <- 100L
momentum <- 0.0
name     <- "Hitters"

# Make the random initializations of the attempts reproducible.
set.seed(seed)
# Prepare data for analysis ---------------------------------------------------
cols <- c("Salary", "AtBat", "Hits", "HmRun", "Runs", "RBI",
          "Walks", "Years", "CAtBat", "CHits", "CHmRun", "CRuns",
          "CRBI", "CWalks", "PutOuts", "Assists", "Errors")
# Row positions of the outliers, counted AFTER incomplete rows are dropped.
outliers <- c(92,  # Gary Pettis
              120, # John Moses
              173, # Milt Thompson
              189, # Rick Burleson
              220, # Scott Fletcher
              230, # Tom Foley
              241) # Terry Puhl
# Select (and reorder) the columns given by `cols` and drop rows with missing
# values.
ds <- na.omit(Hitters[, cols])
# Negative indexing silently ignores out-of-range positions, so make sure the
# outlier list still matches the data before removing those rows.
stopifnot(all(outliers >= 1 & outliers <= nrow(ds)))
ds <- ds[-outliers, ]
# Log-transform the response.
ds$Salary <- log(ds$Salary)
# Center and scale every column; `scale()` returns a numeric matrix.
ds <- scale(ds, center = TRUE, scale = TRUE)
# Split into predictors and response.
X <- as.matrix(ds[, colnames(ds) != "Salary"])
Y <- as.matrix(ds[, "Salary"])
# Output device: a single A4 page (pdf() width/height are in inches) split into
# one full-width panel on top and a 2 x 2 grid of panels below.
out.dir <- file.path(getwd(), 'results')
# `pdf()` fails if the target directory does not exist, so create it first.
dir.create(out.dir, showWarnings = FALSE, recursive = TRUE)
path <- file.path(out.dir, 'hitters_logger.pdf')
pdf(path, width = 8.27, height = 11.7)
layout(matrix(c(1, 1,
                2, 3,
                4, 5), nrow = 3, byrow = TRUE))
# Setup histories: one row per iteration (row 1 holds iteration 0) and one
# column per attempt. Cells the logger never reaches stay `NA`.
loss.history <- error.history <- tau.history <- grad.norm.history <-
    matrix(NA, nrow = max.iter + 1, ncol = attempts)

# Logger callback for `cve()`.
#
# Records, for the given attempt and iteration, the loss, the error, the
# learning rate `tau` and the Frobenius norm of the gradient `G`.
# NOTE: writes into the global history matrices via superassignment (`<<-`).
logger <- function(attempt, iter, data) {
    i <- iter + 1  # iteration 0 is stored in row 1
    loss.history[i, attempt]      <<- data$loss
    error.history[i, attempt]     <<- data$err
    tau.history[i, attempt]       <<- data$tau
    grad.norm.history[i, attempt] <<- norm(data$G, 'F')
}
# Fit CVE for a fixed reduced dimension `k`. The logger fills the history
# matrices. `momentum` is forwarded so the value printed on the metadata page
# is the one actually used by the optimizer.
k <- 2L
dr <- cve(Y ~ X, k = k, max.iter = max.iter, attempts = attempts,
          momentum = momentum, logger = logger)
B <- coef(dr, k)
# Per-dimension results are stored under the dimension's name, e.g. '2'.
loss <- dr$res[[as.character(k)]]$loss
# Draw lines of text, left aligned, on an otherwise empty plot panel.
#
# All arguments are flattened with `unlist(list(...))` into one vector, one
# element per line. The panel holds 20 lines: longer input is shortened to the
# first 17 lines, a "skipped" marker and the last 2 lines. Line i is drawn at
# height 1 - i / 20.
#
# Returns NULL invisibly (called for its plotting side effect).
textplot <- function(...) {
    # Deliberately not named `text`, which would shadow `graphics::text()`.
    lines <- unlist(list(...))
    if (length(lines) > 20) {
        lines <- c(head(lines, 17),
                   '   ...... (skipped, text too long) ......',
                   tail(lines, 2))
    }

    plot(NA, xlim = c(0, 1), ylim = c(0, 1),
         bty = 'n', xaxt = 'n', yaxt = 'n', xlab = '', ylab = '')

    n <- length(lines)
    if (n > 0) {
        # One vectorized call; x is repeated because x and y must match.
        text(rep(0, n), 1 - seq_len(n) / 20, lines, pos = 4)
    }
    invisible(NULL)
}
# Write metadata into the first plot panel.
textplot(
    # Run setup.
    paste0("Seed value: ", seed),
    "",
    # Data and estimated basis shapes.
    paste0("Dataset Name: ", name),
    paste0("dim(X) = (", nrow(X), ", ", ncol(X), ")"),
    paste0("dim(B) = (", nrow(B), ", ", ncol(B), ")"),
    "",
    # Optimizer configuration and the recorded call.
    paste0("CVE method: ", dr$method),
    paste0("Max Iterations: ", max.iter),
    paste0("Attempts: ", attempts),
    paste0("Momentum: ", momentum),
    "CVE call:",
    paste0("  > ", format(dr$call)),
    "",
    # Results (no ground truth is available for this dataset).
    paste0("True Error: ", NA),
    paste0("loss: ", round(loss, 3))
)
# Plot histories: one curve per column (attempt), iterations on the x-axis,
# log-scaled y-axis.
matplot(loss.history, type = 'l', log = 'y', xlab = 'i (iteration)',
        main = paste('loss', name),
        ylab = expression(L(V[i])))
matplot(grad.norm.history, type = 'l', log = 'y', xlab = 'i (iteration)',
        main = paste('gradient norm', name),
        ylab = expression(group('|', paste(nabla, L(V)), '|')[F]))
matplot(error.history, type = 'l', log = 'y', xlab = 'i (iteration)',
        main = paste('error', name),
        ylab = expression(group('|', V[i-1]*V[i-1]^T - V[i]*V[i]^T, '|')[F]))
matplot(tau.history, type = 'l', log = 'y', xlab = 'i (iteration)',
        main = paste('learning rate', name),
        ylab = expression(tau[i]))

# Close the PDF device opened above. Without this the file stays open (and may
# be incomplete) when the script is `source()`d in an interactive session.
invisible(dev.off())