44 lines
1.6 KiB
R
44 lines
1.6 KiB
R
|
library(CVEpureR)
|
||
|
|
||
|
# Setup histories.
# Each history is an (epochs + 1) x attempts matrix: one column per attempt,
# and one row per epoch. The logger below writes to row `epoch + 1`, so row 1
# presumably holds the state before the first epoch (epoch 0).
(epochs <- 50)
(attempts <- 10)
# All four start as the same all-NA template; R's copy-on-modify semantics
# make them independent as soon as one of them is written to.
loss.history <- error.history <- tau.history <- true.error.history <-
  matrix(NA, nrow = epochs + 1, ncol = attempts)
# Create a dataset.
# `dataset()` (presumably from CVEpureR) returns a list holding the predictors
# `X`, the response `Y` (used as `Y ~ X` below) and the true basis `B`.
ds <- dataset("M1")
X <- ds$X
Y <- ds$Y
B <- ds$B  # the true `B`
# Dimension of the true subspace: the number of columns of `B`.
(k <- ncol(B))
# True projection matrix onto span(B): P = B (B'B)^{-1} B'.
# `solve(A, b)` solves the linear system directly instead of forming the
# explicit inverse and then multiplying, which is cheaper and numerically
# preferable; `crossprod(B)` is the idiomatic form of t(B) %*% B.
P <- B %*% solve(crossprod(B), t(B))
# Define the logger for the `cve()` method.
#
# `cve()` invokes `logger(env)` as a callback (presumably once per epoch and
# attempt -- confirm against the CVE documentation). `env` must provide
# `epoch`, `attempt`, `loss`, `error`, `tau` and the current estimate `V`.
# Each call stores loss/error/tau in row `epoch + 1`, column `attempt` of the
# history matrices, plus the "true" error: the Frobenius distance between the
# projection onto the true `B` (global `P`) and the projection onto the
# estimated basis, scaled by 1 / sqrt(2 * k).
logger <- function(env) {
  i <- env$epoch + 1  # row index (epochs are stored with a +1 offset)
  j <- env$attempt    # column index (one column per attempt)

  # Note the `<<-` assignment! The histories live outside this function, so a
  # plain `<-` would only modify a local copy that is discarded on return.
  loss.history[i, j] <<- env$loss
  error.history[i, j] <<- env$error
  tau.history[i, j] <<- env$tau

  # Compute true error by comparing to the true `B`.
  B.est <- null(env$V)  # Function provided by CVE
  # Projection onto span(B.est); solve the system rather than invert.
  P.est <- B.est %*% solve(crossprod(B.est), t(B.est))
  true.error <- norm(P - P.est, "F") / sqrt(2 * k)
  true.error.history[i, j] <<- true.error
}
# Perform SDR for ONE `k`.
# `logger` is passed as a callback; it fills the history matrices while
# `cve()` runs `attempts` attempts of `epochs` epochs each.
dr <- cve(
  Y ~ X,
  k = k,
  logger = logger,
  epochs = epochs,
  attempts = attempts
)
# Plot histories.
# One panel per tracked quantity in a 2 x 2 layout; `matplot` draws each
# column (attempt) as its own line, with a log-scaled y-axis.
# Row r of every history holds epoch r - 1 (the logger writes to row
# `epoch + 1`), so pass those epoch numbers as x instead of matplot's default
# 1:nrow, which would shift the axis by one.
iters <- seq_len(nrow(loss.history)) - 1L
# Keep the previous graphics settings so the 2 x 2 layout does not leak into
# later plots.
old.par <- par(mfrow = c(2, 2))
matplot(iters, loss.history, type = "l", log = "y", xlab = "iter",
        main = "loss", ylab = expression(L(V[iter])))
matplot(iters, error.history, type = "l", log = "y", xlab = "iter",
        main = "error", ylab = "error")
matplot(iters, tau.history, type = "l", log = "y", xlab = "iter",
        main = "tau", ylab = "tau")
matplot(iters, true.error.history, type = "l", log = "y", xlab = "iter",
        main = "true error", ylab = "true error")
par(old.par)