82 lines
2.5 KiB
R
82 lines
2.5 KiB
R
#' Multi Linear Multiplication
|
|
#'
|
|
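#' Computes the multi-linear product C = A x { B1, ..., Br }, i.e. the
#' successive tensor times matrix products (see \code{ttm}) of \code{A} with
#' the matrices B1, ..., Br along the modes given in \code{modes}, applied in
#' the order in which the matrices are supplied.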
|
|
#'
|
|
#' @param A tensor (multi-linear array)
|
|
#' @param B matrix or list of matrices
|
|
#' @param ... further matrices, concatenated with \code{B}
|
|
#' @param modes integer sequence of the same length as number of matrices
|
|
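#' supplied (in \code{B} and \code{...}). The i-th matrix is multiplied along
#' mode \code{modes[i]}. Defaults to \code{1, ..., r}, where \code{r} is the
#' total number of supplied matrices.
#'
#' @return The tensor obtained by successively applying \code{ttm} with each
#' matrix along its mode; \code{A} itself if no matrices are supplied.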
|
|
#'
|
|
#' @examples
|
|
#' # general usage
|
|
#' dimA <- c(3, 17, 19, 2)
|
|
#' dimC <- c(7, 11, 13, 5)
|
|
#' A <- array(rnorm(prod(dimA)), dim = dimA)
|
|
#' B <- Map(function(p, q) matrix(rnorm(p * q), p, q), dimC, dimA)
|
|
#' C1 <- mlm(A, B)
|
|
#' C2 <- mlm(A, B[[1]], B[[2]], B[[3]], B[[4]])
|
|
#' C3 <- mlm(A, B[[3]], B[[1]], B[[2]], B[[4]], modes = c(3, 1, 2, 4))
|
|
#' C4 <- mlm(A, B[1:3], B[[4]])
|
|
#' stopifnot(all.equal(C1, C2))
|
|
#' stopifnot(all.equal(C1, C3))
|
|
#' stopifnot(all.equal(C1, C4))
|
|
#'
|
|
#' # selected modes
|
|
#' C1 <- mlm(A, B[2:3], modes = 2:3)
|
|
#' C2 <- mlm(A, B[[2]], B[[3]], modes = 2:3)
|
|
#' C3 <- ttm(ttm(A, B[[2]], 2), B[[3]], 3)
|
|
#' stopifnot(all.equal(C1, C2))
|
|
#' stopifnot(all.equal(C1, C3))
|
|
#'
|
|
#' # analog to matrix multiplication
|
|
#' A <- matrix(rnorm( 6), 2, 3)
|
|
#' B <- matrix(rnorm(12), 3, 4)
|
|
#' C <- matrix(rnorm(20), 5, 4)
|
|
#' stopifnot(all.equal(
|
|
#' A %*% B %*% t(C),
|
|
#' mlm(B, list(A, C))
|
|
#' ))
|
|
#'
|
|
#' # usage with repeated modes (non commutative)
|
|
#' dimA <- c(3, 17, 19, 2)
|
|
#' A <- array(rnorm(prod(dimA)), dim = dimA)
|
|
#' B1 <- matrix(rnorm(9), 3, 3)
|
|
#' B2 <- matrix(rnorm(9), 3, 3)
|
|
#' C <- matrix(rnorm(4), 2, 2)
|
|
#' # same modes do NOT commute
|
|
#' all.equal(
|
|
#' mlm(A, B1, B2, C, modes = c(1, 1, 4)), # NOT equal!
|
|
#' mlm(A, B2, B1, C, modes = c(1, 1, 4))
|
|
#' )
|
|
#' # but different modes do commute
|
|
#' P1 <- mlm(A, C, B1, B2, modes = c(4, 1, 1))
|
|
#' P2 <- mlm(A, B1, C, B2, modes = c(1, 4, 1))
|
|
#' P3 <- mlm(A, B1, B2, C, modes = c(1, 1, 4))
|
|
#' stopifnot(all.equal(P1, P2))
|
|
#' stopifnot(all.equal(P1, P3))
|
|
#'
|
|
#' # concatenation (composition) of MLMs is an MLM
|
|
#' dimX <- c(4, 8, 6, 3)
|
|
#' dimA <- c(3, 17, 19, 2)
|
|
#' dimB <- c(7, 11, 13, 5)
|
|
#' X <- array(rnorm(prod(dimX)), dim = dimX)
|
|
#' As <- Map(function(p, q) matrix(rnorm(p * q), p, q), dimA, dimX)
|
|
#' Bs <- Map(function(p, q) matrix(rnorm(p * q), p, q), dimB, dimA)
|
|
#' # (X x {A1, A2, A3, A4}) x {B1, B2, B3, B4} = X x {B1 A1, B2 A2, B3 A3, B4 A4}
|
|
#' all.equal(mlm(mlm(X, As), Bs), mlm(X, Map(`%*%`, Bs, As)))
|
|
#'
|
|
#' @export
|
|
mlm <- function(A, B, ..., modes = seq_along(B)) {
  # Collect all matrices into one flat list: a single matrix `B` is wrapped,
  # a list `B` is used as-is, and the matrices passed via `...` are appended.
  B <- c(if (is.matrix(B)) list(B) else B, list(...))

  # NOTE: the default `modes = seq_along(B)` is a lazily evaluated promise.
  # It is first forced here, *after* `B` has been rebound to the full list
  # above, so by default the i-th supplied matrix is applied to mode i.
  # Do not access `modes` before the line above.
  if (length(modes) != length(B)) {
    stop(sprintf(
      "Number of modes (%d) does not match number of matrices (%d).",
      length(modes), length(B)
    ), call. = FALSE)
  }

  # Iteratively apply the tensor times matrix product, in the given order
  # (order matters when the same mode is repeated).
  for (i in seq_along(modes)) {
    A <- ttm(A, B[[i]], modes[i])
  }

  # Resulting tensor (A unchanged if no matrices were supplied).
  A
}
|