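# CVE/test.R -- R script (78 lines, 2.4 KiB), captured from a repository
# web view ("Raw | Normal View | History"; counters: 2, 0, Fork 0).
# Blame revisions: 2019-09-16 09:28:06 +00:00 and 2019-09-25 11:53:45 +00:00.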
# Command line: `Rscript test.R [method]`.
# `method` is forwarded to `cve()`; it defaults to "simple" when no
# trailing argument is given.
args <- commandArgs(trailingOnly = TRUE)
method <- if (length(args) > 0) args[1] else "simple"

# Optimization budget handed to `cve()`. These values also size the history
# matrices filled by `logger()`: `epochs + 1` rows by `attempts` columns.
epochs <- 50L
attempts <- 25L
# Alternative configuration: swap in the `CVEpureR` package and tag the
# output with ".R" instead of ".C" (presumably its pure-R variant -- confirm).
# library(CVEpureR)
# path <- paste0('~/Projects/CVE/tmp/logger_', method, '.R.pdf')
library(CVE)
# Output PDF, named after the chosen `method`.
path <- paste0('~/Projects/CVE/tmp/logger_', method, '.C.pdf')
# Logger callback for `cve()`, called as `logger(epoch, attempt, L, V, tau)`.
#
# For each (epoch, attempt) cell it records into matrices that the main loop
# creates at top level (row `epoch + 1`, column `attempt`):
#   loss.history        mean(L)
#   error.history       Frobenius norm of V V^T - V_last V_last^T, i.e. the
#                       change since the previous call (NA at epoch 0)
#   tau.history         tau
#   true.error.history  Frobenius norm of P - P.est, divided by sqrt(2 * k)
# It also reads the globals `P` (true projection) and `k`, and overwrites the
# global `V_last` with the current `V`.
# NOTE(review): the `epoch + 1` row index assumes `cve()` counts epochs from
# 0 -- confirm against the CVE documentation.
logger <- function(epoch, attempt, L, V, tau) {
  # Note the `<<-` assignments: results are collected by updating the
  # existing top-level bindings, not by returning them.
  loss.history[epoch + 1, attempt] <<- mean(L)
  # There is no previous iterate at epoch 0, so the step error is undefined.
  if (epoch == 0) {
    error <- NA
  } else {
    error <- norm(V %*% t(V) - V_last %*% t(V_last), type = 'F')
  }
  V_last <<- V
  error.history[epoch + 1, attempt] <<- error
  tau.history[epoch + 1, attempt] <<- tau

  # Compute the true error by comparing to the projection onto the true `B`.
  # `null()` is provided by CVE; presumably it returns a basis of the
  # orthogonal complement of span(V) -- confirm.
  B.est <- null(V)
  P.est <- B.est %*% solve(t(B.est) %*% B.est) %*% t(B.est)
  true.error <- norm(P - P.est, 'F') / sqrt(2 * k)
  true.error.history[epoch + 1, attempt] <<- true.error
}
# Run `cve()` with the chosen `method` on the datasets "M1" ... "M5". For
# each dataset, draw four diagnostic plots (one 2x2 page) into the PDF at
# `path`.
#
# `P`, `k`, `V_last` and the `*.history` matrices are deliberately assigned
# at top level, because `logger()` reads `P` and `k` and updates the rest via
# `<<-`. The `tryCatch()` expression is evaluated in this (top-level)
# environment, so wrapping the loop does not change where they live.
pdf(path)
invisible(tryCatch({
  par(mfrow = c(2, 2))
  for (name in paste0("M", seq_len(5L))) {
    # Reseed for every dataset so each run is reproducible on its own.
    set.seed(42)
    # Create a dataset.
    ds <- dataset(name)
    X <- ds$X
    Y <- ds$Y
    B <- ds$B # the true `B`
    k <- ncol(B)
    # True projection matrix onto span(B).
    P <- B %*% solve(t(B) %*% B) %*% t(B)
    # Reset the logger state: one row per logged epoch (0 .. epochs) and
    # one column per attempt.
    V_last <- NULL
    loss.history <- matrix(NA, epochs + 1, attempts)
    error.history <- matrix(NA, epochs + 1, attempts)
    tau.history <- matrix(NA, epochs + 1, attempts)
    true.error.history <- matrix(NA, epochs + 1, attempts)

    dr <- cve(Y ~ X, k = k, method = method,
              epochs = epochs, attempts = attempts,
              logger = logger)

    # Plot the histories; `matplot()` draws one line per attempt (column).
    matplot(loss.history, type = 'l', log = 'y', xlab = 'i (iteration)',
            main = paste('loss', name),
            ylab = expression(L(V[i])))
    matplot(true.error.history, type = 'l', log = 'y', xlab = 'i (iteration)',
            main = paste('true error', name),
            ylab = expression(group('|', B*B^T - B[i]*B[i]^T, '|')[F] / sqrt(2*k)))
    matplot(error.history, type = 'l', log = 'y', xlab = 'i (iteration)',
            main = paste('error', name),
            ylab = expression(group('|', V[i-1]*V[i-1]^T - V[i]*V[i]^T, '|')[F]))
    matplot(tau.history, type = 'l', log = 'y', xlab = 'i (iteration)',
            main = paste('learning rate', name),
            ylab = expression(tau[i]))
  }
}, finally = {
  # Always close the PDF device, even if a run fails. This finalizes the
  # file and keeps later plots in the session out of it.
  dev.off()
}))
cat("Created plot:", path, "\n")