# 121 lines, 3.9 KiB, R
# Draw the given strings as left-aligned lines of text on an otherwise
# empty plot panel.
#
# Everything in `...` is flattened with `unlist()` into one vector and drawn
# one entry per line, starting near the top edge and moving down by 1/20 of
# the panel height per line. If more than 20 entries are given, only the
# first 17 and the last 2 are kept, separated by a "skipped" marker line, so
# at most 20 lines are drawn.
#
# Called for its side effect on the current graphics device; returns NULL.
textplot <- function(...) {
    lines <- unlist(list(...))
    n <- length(lines)

    # Too long: keep head (17) + marker + tail (2) = 20 lines.
    if (n > 20) {
        lines <- c(lines[seq_len(17)],
                   ' ...... (skipped, text too long) ......',
                   lines[n - 1:0])
        n <- length(lines)
    }

    # Blank canvas over the unit square: no box, no axes, no labels.
    plot(NA, xlim = c(0, 1), ylim = c(0, 1),
         bty = 'n', xaxt = 'n', yaxt = 'n', xlab = '', ylab = '')

    # text() is vectorised over positions and labels; pos = 4 places each
    # label to the right of x = 0 (left-aligned).
    if (n > 0) {
        text(0, 1 - seq_len(n) / 20, lines, pos = 4)
    }
}
# Command-line interface:
#   Rscript <this script> [method] [momentum]
#   method   -- passed as `cve(method = ...)`; defaults to "simple".
#   momentum -- numeric, passed as `cve(momentum = ...)`; defaults to 0.
# Any further trailing arguments are ignored.
args <- commandArgs(TRUE)

method <- if (length(args) > 0L) args[1] else "simple"

if (length(args) > 1L) {
    # Fail fast on a non-numeric momentum instead of silently handing NA
    # (as.double() only warns) to `cve()`.
    momentum <- suppressWarnings(as.double(args[2]))
    if (!is.finite(momentum)) {
        stop("momentum must be a finite number, got '", args[2], "'",
             call. = FALSE)
    }
} else {
    momentum <- 0.0
}

seed <- 42          # passed to set.seed() before each dataset is generated
max.iter <- 50L     # passed as `cve(max.iter = ...)`
attempts <- 25L     # passed as `cve(attempts = ...)`
library(CVE)

# Output PDF for the diagnostic plots; the file name encodes the CVE method.
path <- sprintf('~/Projects/CVE/tmp/logger_%s.C.pdf', method)
# Define logger for `cve()` method.
#
# Passed as `cve(logger = logger)`. For each call it records, in row
# `iter + 1` and column `attempt`, the values `data$loss`, `data$err` and
# `data$tau` plus a "true error" derived from `data$V`. The history matrices
# (`loss.history`, `error.history`, `tau.history`, `true.error.history`) and
# the true projection `P` / dimension `k` are NOT arguments: they are read
# from / written to the environment this function was defined in (the
# script's top level).
# NOTE(review): row `iter + 1` suggests `iter` is 0-based -- confirm with cve().
logger <- function(attempt, iter, data) {
    row <- iter + 1

    # Note the `<<-` assignment! It updates the history matrices in the
    # enclosing environment rather than creating local copies.
    loss.history[row, attempt] <<- data$loss
    error.history[row, attempt] <<- data$err
    tau.history[row, attempt] <<- data$tau

    # True error: Frobenius distance between the projection onto the
    # estimated subspace and the true projection `P`, scaled by sqrt(2 * k).
    basis <- null(data$V)  # `null()` is provided by CVE
    proj <- basis %*% solve(t(basis) %*% basis) %*% t(basis)
    true.error.history[row, attempt] <<- norm(P - proj, 'F') / sqrt(2 * k)
}
# Open an A4-sized PDF device and split each page into five panels: a wide
# panel on top (metadata text) and a 2 x 2 grid below (history plots). With
# five plots per dataset, each dataset fills exactly one page.
#
# pdf() cannot open a file inside a directory that does not exist, so make
# sure the output directory is there first (dir.create expands `~`).
dir.create(dirname(path), recursive = TRUE, showWarnings = FALSE)
pdf(path, width = 8.27, height = 11.7) # width, height unit is inches -> A4
layout(matrix(c(1, 1,
                2, 3,
                4, 5), nrow = 3, byrow = TRUE))
# For each CVE benchmark dataset M1..M7: fit `cve()` with the logger attached,
# then draw one page with the run metadata and four per-iteration histories
# (one line per attempt).
for (name in paste0("M", seq_len(7))) {
    # Seed random number generator so every dataset is reproducible.
    set.seed(seed)

    # Create a dataset.
    ds <- dataset(name)
    X <- ds$X
    Y <- ds$Y
    B <- ds$B   # the true `B`
    k <- ncol(B)
    # True projection matrix (correct even if `B` is not orthonormal).
    P <- B %*% solve(t(B) %*% B) %*% t(B)

    # Setup histories: `max.iter + 1` rows x `attempts` columns, filled in
    # by `logger()` through `<<-` while `cve()` runs.
    loss.history <- matrix(NA, max.iter + 1, attempts)
    error.history <- matrix(NA, max.iter + 1, attempts)
    tau.history <- matrix(NA, max.iter + 1, attempts)
    true.error.history <- matrix(NA, max.iter + 1, attempts)

    time <- system.time(
        dr <- cve(Y ~ X, k = k, method = method,
                  momentum = momentum,
                  max.iter = max.iter, attempts = attempts,
                  logger = logger)
    )["elapsed"]

    # Extract finally selected values. The true error uses projection
    # matrices -- the same formula `logger()` uses per iteration -- so the
    # summary number is comparable to the plotted history and does not rely
    # on `B` or `B.est` having orthonormal columns.
    B.est <- coef(dr, k)
    P.est <- B.est %*% solve(t(B.est) %*% B.est) %*% t(B.est)
    true.error <- norm(P - P.est, 'F') / sqrt(2 * k)
    loss <- dr$res[[as.character(k)]]$loss

    # Write metadata.
    textplot(
        paste0("Seed value: ", seed),
        "",
        paste0("Dataset Name: ", ds$name),
        paste0("dim(X) = (", nrow(X), ", ", ncol(X), ")"),
        paste0("dim(B) = (", nrow(B), ", ", ncol(B), ")"),
        "",
        paste0("CVE method: ", dr$method),
        paste0("Max Iterations: ", max.iter),
        paste0("Attempts: ", attempts),
        paste0("Momentum: ", momentum),
        "CVE call:",
        paste0(" > ", format(dr$call)),
        "",
        paste0("True Error: ", round(true.error, 3)),
        paste0("loss: ", round(loss, 3)),
        paste0("time: ", round(time, 3), " s")
    )

    # Plot histories (one line per attempt, i.e. per matrix column).
    matplot(loss.history, type = 'l', log = 'y', xlab = 'i (iteration)',
            main = paste('loss', name),
            ylab = expression(L(V[i])))
    matplot(true.error.history, type = 'l', log = 'y', xlab = 'i (iteration)',
            main = paste('true error', name),
            ylab = expression(group('|', B*B^T - B[i]*B[i]^T, '|')[F] / sqrt(2*k)))
    matplot(error.history, type = 'l', log = 'y', xlab = 'i (iteration)',
            main = paste('error', name),
            ylab = expression(group('|', V[i-1]*V[i-1]^T - V[i]*V[i]^T, '|')[F]))
    matplot(tau.history, type = 'l', log = 'y', xlab = 'i (iteration)',
            main = paste('learning rate', name),
            ylab = expression(tau[i]))
}
# Close the PDF device opened above so the file is completely written before
# reporting it (guarded so it never tries to close the null device).
if (dev.cur() > 1L) invisible(dev.off())
cat("Created plot:", path, "\n")