2022-05-11 15:26:37 +00:00
|
|
|
#' Multi-Linear Multiplication
|
|
|
|
#'
|
2022-10-06 12:25:40 +00:00
|
|
|
#' \deqn{C = A\times\{ B_1, \ldots, B_r \}}{%
|
|
|
|
#' C = A x { B1, ..., Br }}
|
2022-05-11 15:26:37 +00:00
|
|
|
#'
|
|
|
|
#' @param A tensor (multi-linear array)
|
2022-10-06 12:25:40 +00:00
|
|
|
#' @param Bs matrix or list of matrices
|
|
|
|
#' @param modes integer sequence of the same length as `Bs` specifying the
|
|
|
|
#' multiplication axis (defaults to `seq_along(Bs)`)
|
|
|
|
#' @param transposed single logical or logical vector of the same length as \code{modes}
|
|
|
|
#' to transpose the \code{Bs} of matching index before multiplication.
|
2022-05-11 15:26:37 +00:00
|
|
|
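#' @return tensor (multi-linear array) \code{C}, the multi-linear product of
#' \code{A} with the matrices \code{Bs} along the given \code{modes}
#'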
|
|
|
|
#' @examples
|
|
|
|
#' # general usage
|
|
|
|
#' dimA <- c(3, 17, 19, 2)
|
|
|
|
#' dimC <- c(7, 11, 13, 5)
|
|
|
|
#' A <- array(rnorm(prod(dimA)), dim = dimA)
|
2022-10-06 12:25:40 +00:00
|
|
|
#' Bs <- Map(function(p, q) matrix(rnorm(p * q), p, q), dimC, dimA)
|
|
|
|
#' C1 <- mlm(A, Bs)
|
|
|
|
#' C2 <- mlm(A, Bs, modes = 1:4)
|
|
|
|
#' C3 <- mlm(A, Bs[c(3, 1, 2, 4)], modes = c(3, 1, 2, 4))
|
2022-05-11 15:26:37 +00:00
|
|
|
#' stopifnot(all.equal(C1, C2))
|
|
|
|
#' stopifnot(all.equal(C1, C3))
|
|
|
|
#'
|
|
|
|
#' # selected modes
|
2022-10-06 12:25:40 +00:00
|
|
|
#' stopifnot(all.equal(
|
|
|
|
#' mlm(A, Bs[2:3], modes = 2:3),
|
|
|
|
#' ttm(ttm(A, Bs[[2]], 2), Bs[[3]], 3)
|
|
|
|
#' ))
|
2022-05-11 15:26:37 +00:00
|
|
|
#'
|
|
|
|
#' # analog to matrix multiplication
|
|
|
|
#' A <- matrix(rnorm( 6), 2, 3)
|
|
|
|
#' B <- matrix(rnorm(12), 3, 4)
|
2022-05-27 18:11:48 +00:00
|
|
|
#' C <- matrix(rnorm(20), 5, 4)
|
2022-05-11 15:26:37 +00:00
|
|
|
#' stopifnot(all.equal(
|
|
|
|
#' A %*% B %*% t(C),
|
|
|
|
#' mlm(B, list(A, C))
|
|
|
|
#' ))
|
|
|
|
#'
|
2022-10-06 12:25:40 +00:00
|
|
|
#' # usage of transposed
|
|
|
|
#' A <- matrix(rnorm( 6), 2, 3)
|
|
|
|
#' B <- matrix(rnorm(15), 3, 5)
|
|
|
|
#' C <- matrix(rnorm(35), 5, 7)
|
|
|
|
#'
|
|
|
|
#' stopifnot(all.equal(
|
|
|
|
#' A %*% B %*% C,
|
|
|
|
#' mlm(B, list(A, C), transposed = c(FALSE, TRUE))
|
|
|
|
#' ))
|
|
|
|
#'
|
2022-05-11 15:26:37 +00:00
|
|
|
#' # usage with repeated modes (non-commutative)
|
|
|
|
#' dimA <- c(3, 17, 19, 2)
|
|
|
|
#' A <- array(rnorm(prod(dimA)), dim = dimA)
|
|
|
|
#' B1 <- matrix(rnorm(9), 3, 3)
|
|
|
|
#' B2 <- matrix(rnorm(9), 3, 3)
|
|
|
|
#' C <- matrix(rnorm(4), 2, 2)
|
|
|
|
#' # same modes do NOT commute
|
|
|
|
#' all.equal(
|
2022-10-06 12:25:40 +00:00
|
|
|
#' mlm(A, list(B1, B2, C), c(1, 1, 4)), # NOT equal!
|
|
|
|
#' mlm(A, list(B2, B1, C), c(1, 1, 4))
|
2022-05-11 15:26:37 +00:00
|
|
|
#' )
|
|
|
|
#' # but different modes do commute
|
2022-10-06 12:25:40 +00:00
|
|
|
#' P1 <- mlm(A, list(C, B1, B2), c(4, 1, 1))
|
|
|
|
#' P2 <- mlm(A, list(B1, C, B2), c(1, 4, 1))
|
|
|
|
#' P3 <- mlm(A, list(B1, B2, C), c(1, 1, 4))
|
2022-05-11 15:26:37 +00:00
|
|
|
#' stopifnot(all.equal(P1, P2))
|
|
|
|
#' stopifnot(all.equal(P1, P3))
|
|
|
|
#'
|
2022-10-06 12:25:40 +00:00
|
|
|
#' # Concatenation of MLMs is an MLM
|
2022-05-27 18:11:48 +00:00
|
|
|
#' dimX <- c(4, 8, 6, 3)
|
|
|
|
#' dimA <- c(3, 17, 19, 2)
|
|
|
|
#' dimB <- c(7, 11, 13, 5)
|
|
|
|
#' X <- array(rnorm(prod(dimX)), dim = dimX)
|
|
|
|
#' As <- Map(function(p, q) matrix(rnorm(p * q), p, q), dimA, dimX)
|
|
|
|
#' Bs <- Map(function(p, q) matrix(rnorm(p * q), p, q), dimB, dimA)
|
|
|
|
#' # (X x {A1, A2, A3, A4}) x {B1, B2, B3, B4} = X x {B1 A1, B2 A2, B3 A3, B4 A4}
|
|
|
|
#' all.equal(mlm(mlm(X, As), Bs), mlm(X, Map(`%*%`, Bs, As)))
|
|
|
|
#'
|
2023-11-14 13:35:43 +00:00
|
|
|
#' # Equivalent to
|
|
|
|
#' mlm_reference <- function(A, Bs, modes = seq_along(Bs), transposed = FALSE) {
|
|
|
|
#' # Collect all matrices in `Bs`
|
|
|
|
#' Bs <- if (is.matrix(Bs)) list(Bs) else Bs
|
|
|
|
#'
|
|
|
|
#' # replicate transposition if of length one only
|
|
|
|
#' transposed <- if (length(transposed) == 1) {
|
|
|
|
#' rep(as.logical(transposed), length(Bs))
|
|
|
|
#' } else {
|
|
|
|
#' as.logical(transposed)
|
|
|
|
#' }
|
|
|
|
#'
|
|
|
|
#' # iteratively apply Tensor Times Matrix multiplication over modes
|
|
|
|
#' for (i in seq_along(modes)) {
|
|
|
|
#' A <- ttm(A, Bs[[i]], modes[i], transposed[i])
|
|
|
|
#' }
|
|
|
|
#'
|
|
|
|
#' # return result tensor
|
|
|
|
#' A
|
|
|
|
#' }
|
|
|
|
#'
|
2022-05-11 15:26:37 +00:00
|
|
|
#' @export
|
2022-10-06 12:25:40 +00:00
|
|
|
mlm <- function(A, Bs, modes = seq_along(Bs), transposed = FALSE) {
    # Normalise `Bs` to a list of matrices; a lone (non-list) argument counts
    # as a single factor. This has to run before `modes` is first touched:
    # its default `seq_along(Bs)` is a lazily evaluated promise and must see
    # the list (one entry per factor), not the raw matrix.
    if (!is.list(Bs)) {
        Bs <- list(Bs)
    }
    Bs <- lapply(Bs, as.matrix)

    # A scalar `transposed` flag is recycled to one flag per factor, while a
    # vector of flags has to provide exactly one entry per mode.
    n_flags <- length(transposed)
    if (n_flags == 1L) {
        transposed <- rep(as.logical(transposed), length(Bs))
    } else if (n_flags == length(modes)) {
        transposed <- as.logical(transposed)
    } else {
        stop("Dim missmatch of param. `transposed`")
    }

    # Hand the prepared arguments to the compiled implementation.
    .Call("C_mlm", A, Bs, as.integer(modes), transposed,
          PACKAGE = "tensorPredictors")
}
|
2023-11-14 13:35:43 +00:00
|
|
|
|
|
|
|
|
|
|
|
# # general usage
|
|
|
|
# dimA <- c(3, 17, 19, 2)
|
|
|
|
# dimC <- c(7, 11, 13, 5)
|
|
|
|
# A <- array(rnorm(prod(dimA)), dim = dimA)
|
|
|
|
# trans <- c(TRUE, FALSE, TRUE, FALSE)
|
|
|
|
# Bs <- Map(function(p, q) matrix(rnorm(p * q), p, q), ifelse(trans, dimA, dimC), ifelse(trans, dimC, dimA))
|
|
|
|
|
|
|
|
# C <- mlm(A, Bs, transposed = trans)
|
|
|
|
# mlm(A, Bs[c(3, 2)], modes = c(3, 2), transposed = trans[c(3, 2)])
|
|
|
|
|
|
|
|
# microbenchmark::microbenchmark(
|
|
|
|
# mlm(A, Bs, transposed = trans),
|
|
|
|
# mlm_reference(A, Bs, transposed = trans)
|
|
|
|
# )
|
|
|
|
|
|
|
|
# microbenchmark::microbenchmark(
|
|
|
|
# mlm(A, Bs[c(3, 2)], modes = c(3, 2), transposed = trans[c(3, 2)]),
|
|
|
|
# mlm_reference(A, Bs[c(3, 2)], modes = c(3, 2), transposed = trans[c(3, 2)])
|
|
|
|
# )
|