# tensor_predictors/tensorPredictors/R/kpir_ls.R
# (83 lines, 2.8 KiB, R; 2022-05-11 15:26:37 +00:00)

#' Per mode (axis) alternating least squares estimate
#'
#' Alternating least squares (ALS) estimate of the per mode coefficient
#' matrices `alpha_j` relating the predictors `X` to the responses `Fy`.
#' Each `alpha_j` is solved by least squares given all other `alpha_i`,
#' cycling over the non-sample modes. This is followed by moment estimates of
#' the per mode matrices `Delta_j` computed from the residuals.
#'
#' @param X predictor array, observations are stacked along `sample.axis`.
#' @param Fy response array with the same number of axes as `X` and the same
#'  extent along `sample.axis`. A non-array `Fy` (scalar response) gets the
#'  dimension `1` on every non-sample axis.
#' @param max.iter maximum number of ALS sweeps over all non-sample modes.
#' @param sample.axis index of the sample mode, a.k.a. observation axis index
#' @param eps relative convergence tolerance. Iteration stops early as soon as
#'  the Frobenius norm of the change of the fitted values over one sweep is
#'  at most `eps` times the norm of the previous fitted values.
#' @param center logical, if `TRUE` both `X` and `Fy` are centered by their
#'  means over the sample axis before estimation.
#' @param logger optional function. It is called as `logger("ls", iter, alphas)`
#'  before every sweep (`iter` = number of completed sweeps). It is called once
#'  more as `logger("ls", iter, alphas, Deltas)` with the final estimates.
#'
#' @return list with components `alphas` and `Deltas` (lists named by the
#'  mode indices of the non-sample axes), `meanX` and `meanFy` (the means
#'  over the sample axis used for centering, `NA` if `center = FALSE`).
#'
#' @export
kpir.ls <- function(X, Fy, max.iter = 20L, sample.axis = 1L,
    eps = sqrt(.Machine$double.eps), center = TRUE, logger = NULL
) {
    ### Step 0: Setup/Initialization
    if (!is.array(Fy)) {
        # scalar response case (add new axis of size 1)
        dim(Fy) <- ifelse(seq_along(dim(X)) == sample.axis, dim(X)[sample.axis], 1L)
    }
    # Check dimensions and matching of axis (tensor order)
    stopifnot(exprs = {
        length(dim(X)) == length(dim(Fy))
        dim(X)[sample.axis] == dim(Fy)[sample.axis]
    })
    # warn about occurrence of an axis without reduction
    if (any(dim(Fy)[-sample.axis] >= dim(X)[-sample.axis])) {
        warning("Degenerate case 'any(dim(Fy)[-sample.axis] >= dim(X)[-sample.axis])'")
    }
    # mode index sequence (exclude sample mode, a.k.a. observation axis)
    modes <- seq_along(dim(X))[-sample.axis]
    n <- dim(X)[sample.axis]    # observation count (scalar)
    p <- dim(X)[-sample.axis]   # predictor dimensions (vector)
    if (center) {
        # Means for X and Fy (mean over the sample axis, per remaining entry)
        meanX <- apply(X, modes, mean, simplify = TRUE)
        meanFy <- apply(Fy, modes, mean, simplify = TRUE)
        # Center both X and Fy
        X <- sweep(X, modes, meanX)
        Fy <- sweep(Fy, modes, meanFy)
    } else {
        meanX <- meanFy <- NA
    }

    ### Step 1: initial per mode estimates
    # leading `dim(Fy)[mode]` left singular vectors of `mcrossprod(X, mode)`
    alphas <- Map(function(mode, ncol) {
        La.svd(mcrossprod(X, mode = mode), ncol)$u
    }, modes, dim(Fy)[modes])
    # fitted values of the current estimates (used by the break condition
    # and, after the iteration, for the residuals)
    fit <- mlm(Fy, alphas, modes = modes)

    ### Step 2: iterate per mode (axis) least squares estimates
    # `iter` counts completed sweeps. It is tracked explicitly (not as a `for`
    # loop variable) so it is a proper count even when `max.iter` is zero.
    iter <- 0L
    while (iter < max.iter) {
        # Invoke logger for previous iterate
        if (is.function(logger)) {
            logger("ls", iter, alphas)
        }
        iter <- iter + 1L

        # cyclic iterate over modes
        for (j in seq_along(modes)) {
            # least squares solution for `alpha_j | alpha_i, i != j`
            Z <- mlm(Fy, alphas[-j], modes = modes[-j])
            alphas[[j]] <- t(solve(
                mcrossprod(Z, Z, modes[j]), mcrossprod(Z, X, modes[j])
            ))
        }

        # Break condition: relative change of the fitted values. The fit is
        # compared instead of the alphas themselves because (assuming `mlm`
        # is multi-linear in the alphas) rescaling one alpha and inversely
        # rescaling another leaves the fit unchanged.
        fit.old <- fit
        fit <- mlm(Fy, alphas, modes = modes)
        if (sqrt(sum((fit - fit.old)^2)) <= eps * sqrt(sum(fit.old^2))) {
            break
        }
    }

    ### Step 3: Moment estimates for `Delta_i`
    # Residuals (`fit` always corresponds to the final `alphas`)
    R <- X - fit
    # Moment estimates for `Delta_i`s: `R_(i) R_(i)^T` divided by the number
    # of columns of the mode `i` matricization, `n * prod(p) / p_i`
    Deltas <- Map(mcrossprod, list(R), mode = modes)
    Deltas <- Map(`*`, p / (n * prod(p)), Deltas)
    # Call logger with final results (including Deltas)
    if (is.function(logger)) {
        logger("ls", iter, alphas, Deltas)
    }

    list(alphas = structure(alphas, names = as.character(modes)),
         Deltas = structure(Deltas, names = as.character(modes)),
         meanX = meanX, meanFy = meanFy)
}