# tensor_predictors/tensorPredictors/R/mlm.R
#' Multi Linear Multiplication
#'
#' \deqn{C = A\times\{ B_1, \ldots, B_r \}}{%
#' C = A x { B1, ..., Br }}
#'
#' @param A tensor (multi-linear array)
#' @param Bs matrix or list of matrices
#' @param modes integer sequence of the same length as `Bs` specifying the
#' multiplication axis (defaults to `seq_along(Bs)`)
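#' @param transposed single boolean (recycled to all \code{Bs}) or boolean
#' vector of the same length as \code{modes} to transpose the \code{Bs} of
#' matching index before multiplication.
#'
#' @return tensor resulting from multiplying \code{A} along mode
#' \code{modes[i]} with \code{Bs[[i]]} (or its transpose) for every \code{i},
#' see \code{mlm_reference} in the examples.
#'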
#' @examples
#' # general usage
#' dimA <- c(3, 17, 19, 2)
#' dimC <- c(7, 11, 13, 5)
#' A <- array(rnorm(prod(dimA)), dim = dimA)
#' Bs <- Map(function(p, q) matrix(rnorm(p * q), p, q), dimC, dimA)
#' C1 <- mlm(A, Bs)
#' C2 <- mlm(A, Bs, modes = seq_along(Bs))
#' C3 <- mlm(A, Bs[c(3, 1, 2, 4)], modes = c(3, 1, 2, 4))
#' stopifnot(all.equal(C1, C2))
#' stopifnot(all.equal(C1, C3))
#'
#' # selected modes
#' stopifnot(all.equal(
#' mlm(A, Bs[2:3], modes = 2:3),
#' ttm(ttm(A, Bs[[2]], 2), Bs[[3]], 3)
#' ))
#'
#' # analog to matrix multiplication
#' A <- matrix(rnorm( 6), 2, 3)
#' B <- matrix(rnorm(12), 3, 4)
#' C <- matrix(rnorm(20), 5, 4)
#'
#' stopifnot(all.equal(
#' A %*% B %*% t(C),
#' mlm(B, list(A, C))
#' ))
#'
#' # usage of transposed
#' A <- matrix(rnorm( 6), 2, 3)
#' B <- matrix(rnorm(15), 3, 5)
#' C <- matrix(rnorm(35), 5, 7)
#'
#' stopifnot(all.equal(
#' A %*% B %*% C,
#' mlm(B, list(A, C), transposed = c(FALSE, TRUE))
#' ))
#'
#' # usage with repeated modes (non commutative)
#' dimA <- c(3, 17, 19, 2)
#' A <- array(rnorm(prod(dimA)), dim = dimA)
#' B1 <- matrix(rnorm(9), 3, 3)
#' B2 <- matrix(rnorm(9), 3, 3)
#' C <- matrix(rnorm(4), 2, 2)
#' # same modes do NOT commute
#' all.equal(
#' mlm(A, list(B1, B2, C), c(1, 1, 4)), # NOT equal!
#' mlm(A, list(B2, B1, C), c(1, 1, 4))
#' )
#' # but different modes do commute
#' P1 <- mlm(A, list(C, B1, B2), c(4, 1, 1))
#' P2 <- mlm(A, list(B1, C, B2), c(1, 4, 1))
#' P3 <- mlm(A, list(B1, B2, C), c(1, 1, 4))
#' stopifnot(all.equal(P1, P2))
#' stopifnot(all.equal(P1, P3))
#'
#' # Concatenation (composition) of MLMs is an MLM
#' dimX <- c(4, 8, 6, 3)
#' dimA <- c(3, 17, 19, 2)
#' dimB <- c(7, 11, 13, 5)
#' X <- array(rnorm(prod(dimX)), dim = dimX)
#' As <- Map(function(p, q) matrix(rnorm(p * q), p, q), dimA, dimX)
#' Bs <- Map(function(p, q) matrix(rnorm(p * q), p, q), dimB, dimA)
#' # (X x {A1, A2, A3, A4}) x {B1, B2, B3, B4} = X x {B1 A1, B2 A2, B3 A3, B4 A4}
#' all.equal(mlm(mlm(X, As), Bs), mlm(X, Map(`%*%`, Bs, As)))
#'
#' # Equivalent to
#' mlm_reference <- function(A, Bs, modes = seq_along(Bs), transposed = FALSE) {
#' # Collect all matrices in `Bs` (a single argument is wrapped into a list)
#' Bs <- Map(as.matrix, if (!is.list(Bs)) list(Bs) else Bs)
#'
#' # replicate transposition if of length one only
#' transposed <- if (length(transposed) == 1) {
#' rep(as.logical(transposed), length(Bs))
#' } else {
#' as.logical(transposed)
#' }
#'
#' # iteratively apply Tensor Times Matrix multiplication over modes
#' for (i in seq_along(modes)) {
#' A <- ttm(A, Bs[[i]], modes[i], transposed[i])
#' }
#'
#' # return result tensor
#' A
#' }
#'
#' @export
mlm <- function(A, Bs, modes = seq_along(Bs), transposed = FALSE) {
    # Collect all matrices in `Bs`; a single (non-list) argument is wrapped
    # into a list of length one.
    Bs <- if (!is.list(Bs)) list(Bs) else Bs
    # ensure all `B`s are matrices (as.matrix turns vectors into 1-col matrices)
    Bs <- Map(as.matrix, Bs)

    # `modes` is deliberately forced only here, AFTER `Bs` was wrapped, so that
    # the lazily evaluated default `seq_along(Bs)` counts matrices (not the
    # entries of a single matrix argument).
    modes <- as.integer(modes)
    # NA modes would reach the compiled code as NA_integer_; reject them here
    if (anyNA(modes)) {
        stop("Missing or non-numeric value in param. `modes`")
    }

    # replicate transposition if of length one only
    transposed <- if (length(transposed) == 1) {
        rep(as.logical(transposed), length(Bs))
    } else if (length(transposed) == length(modes)) {
        as.logical(transposed)
    } else {
        stop("Dim missmatch of param. `transposed`")
    }
    # A logical NA is not a valid transposition flag; without this check it
    # would be passed unchecked to the compiled code.
    if (anyNA(transposed)) {
        stop("Param. `transposed` must be TRUE or FALSE (no NA)")
    }

    .Call("C_mlm", A, Bs, modes, transposed, PACKAGE = "tensorPredictors")
}
# # general usage
# dimA <- c(3, 17, 19, 2)
# dimC <- c(7, 11, 13, 5)
# A <- array(rnorm(prod(dimA)), dim = dimA)
# trans <- c(TRUE, FALSE, TRUE, FALSE)
# Bs <- Map(function(p, q) matrix(rnorm(p * q), p, q), ifelse(trans, dimA, dimC), ifelse(trans, dimC, dimA))
# C <- mlm(A, Bs, transposed = trans)
# mlm(A, Bs[c(3, 2)], modes = c(3, 2), transposed = trans[c(3, 2)])
# microbenchmark::microbenchmark(
# mlm(A, Bs, transposed = trans),
# mlm_reference(A, Bs, transposed = trans)
# )
# microbenchmark::microbenchmark(
# mlm(A, Bs[c(3, 2)], modes = c(3, 2), transposed = trans[c(3, 2)]),
# mlm_reference(A, Bs[c(3, 2)], modes = c(3, 2), transposed = trans[c(3, 2)])
# )