add: TSIR returns intenrnal variables as attr,

fix: numeric stable loss cals also for loss initialization
This commit is contained in:
Daniel Kapla 2024-01-10 17:31:52 +01:00
parent daefd3e7d1
commit 636eebf720
4 changed files with 10 additions and 6 deletions

5
.gitignore vendored
View File

@ -111,9 +111,8 @@ simulations/
mlda_analysis/ mlda_analysis/
References/ References/
dataAnalysis/* dataAnalysis/chess/*.Rdata
!dataAnalysis/chess/ dataAnalysis/Classification of EEG/
dataAnalysis/chess/*.fen
*.csv *.csv
*.csv.log *.csv.log

View File

@ -58,7 +58,6 @@ export(kpir.momentum)
export(kpir.new) export(kpir.new)
export(kronperm) export(kronperm)
export(mat) export(mat)
export(matProj)
export(matpow) export(matpow)
export(matrixImage) export(matrixImage)
export(mcov) export(mcov)

View File

@ -109,5 +109,5 @@ TSIR <- function(X, y, d, sample.axis = 1L,
# reductions matrices `Omega_k^-1 Gamma_k` where there (reverse) kronecker # reductions matrices `Omega_k^-1 Gamma_k` where there (reverse) kronecker
# product spans the central tensor subspace (CTS) estimate # product spans the central tensor subspace (CTS) estimate
Map(solve, Omegas, Gammas) structure(Map(solve, Omegas, Gammas), mcov = Omegas, Gammas = Gammas)
} }

View File

@ -61,8 +61,14 @@ gmlm_tensor_normal <- function(X, F, sample.axis = length(dim(X)),
# Residuals # Residuals
R <- X - mlm(F, Map(`%*%`, Sigmas, betas)) R <- X - mlm(F, Map(`%*%`, Sigmas, betas))
# Numerically more stable version of `sum(log(mapply(det, Omegas)) / dimX)`
# which is itself equivalent to `log(det(Omega)) / prod(nrow(Omega))` where
# `Omega <- Reduce(kronecker, rev(Omegas))`.
det.Omega <- sum(mapply(function(Omega) {
sum(log(eigen(Omega, TRUE, TRUE)$values))
}, Omegas) / dimX)
# Initial value of the log-likelihood (scaled and constants dropped) # Initial value of the log-likelihood (scaled and constants dropped)
loss <- mean(R * mlm(R, Omegas)) - sum(log(mapply(det, Omegas)) / dimX) loss <- mean(R * mlm(R, Omegas)) - det.Omega
# invoke the logger # invoke the logger
if (is.function(logger)) do.call(logger, list( if (is.function(logger)) do.call(logger, list(