138 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
			
		
		
	
	
			138 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
#' Multi Linear Multiplication
#'
#' Multiplies the tensor `A` with the matrices in `Bs` along the axes given
#' by `modes`,
#' \deqn{C = A\times\{ B_1, ..., B_r \}}{%
#'       C = A x { B1, ..., Br }}
#' where the `i`-th matrix `Bs[[i]]` is applied to mode `modes[i]` of `A`.
#' The multiplications are applied in the order given by `modes`, which is
#' equivalent to iterated Tensor Times Matrix (`ttm()`) calls (see
#' `mlm_reference` in the examples).
#'
#' @param A tensor (multi-linear array)
#' @param Bs matrix or list of matrices. A non-list argument is wrapped into a
#'  list, and every element is coerced with `as.matrix()`.
#' @param modes integer vector of the same length as `Bs` specifying the
#'  multiplication axis of each matrix (defaults to `seq_along(Bs)`). Modes
#'  may repeat; repeated modes are applied in the given order and do NOT
#'  commute.
#' @param transposed single boolean or boolean vector of the same length as
#'  `Bs` to transpose the `Bs` of matching index before multiplication.
#'
#' @return The tensor resulting from the multi-linear multiplication
#'  (computed in compiled code).
#'
#' An error is signaled if `modes` and `Bs` differ in length, if `modes`
#' contains `NA` or (when `A` has a `dim` attribute) axes outside of
#' `1:length(dim(A))`, or if `transposed` is neither of length one nor of
#' the same length as `Bs`.
#'
#' @examples
#' # general usage
#' dimA <- c(3, 17, 19, 2)
#' dimC <- c(7, 11, 13, 5)
#' A <- array(rnorm(prod(dimA)), dim = dimA)
#' Bs <- Map(function(p, q) matrix(rnorm(p * q), p, q), dimC, dimA)
#' C1 <- mlm(A, Bs)
#' C2 <- mlm(A, Bs, modes = 1:4)
#' C3 <- mlm(A, Bs[c(3, 1, 2, 4)], modes = c(3, 1, 2, 4))
#' stopifnot(all.equal(C1, C2))
#' stopifnot(all.equal(C1, C3))
#'
#' # selected modes
#' stopifnot(all.equal(
#'     mlm(A, Bs[2:3], modes = 2:3),
#'     ttm(ttm(A, Bs[[2]], 2), Bs[[3]], 3)
#' ))
#'
#' # analog to matrix multiplication
#' A <- matrix(rnorm( 6), 2, 3)
#' B <- matrix(rnorm(12), 3, 4)
#' C <- matrix(rnorm(20), 5, 4)
#' stopifnot(all.equal(
#'     A %*% B %*% t(C),
#'     mlm(B, list(A, C))
#' ))
#'
#' # usage of transposed
#' A <- matrix(rnorm( 6), 2, 3)
#' B <- matrix(rnorm(15), 3, 5)
#' C <- matrix(rnorm(35), 5, 7)
#'
#' stopifnot(all.equal(
#'     A %*% B %*% C,
#'     mlm(B, list(A, C), transposed = c(FALSE, TRUE))
#' ))
#'
#' # usage with repeated modes (non commutative)
#' dimA <- c(3, 17, 19, 2)
#' A <- array(rnorm(prod(dimA)), dim = dimA)
#' B1 <- matrix(rnorm(9), 3, 3)
#' B2 <- matrix(rnorm(9), 3, 3)
#' C <- matrix(rnorm(4), 2, 2)
#' # same modes do NOT commute
#' all.equal(
#'     mlm(A, list(B1, B2, C), c(1, 1, 4)),  # NOT equal!
#'     mlm(A, list(B2, B1, C), c(1, 1, 4))
#' )
#' # but different modes do commute
#' P1 <- mlm(A, list(C, B1, B2), c(4, 1, 1))
#' P2 <- mlm(A, list(B1, C, B2), c(1, 4, 1))
#' P3 <- mlm(A, list(B1, B2, C), c(1, 1, 4))
#' stopifnot(all.equal(P1, P2))
#' stopifnot(all.equal(P1, P3))
#'
#' # Concatenation of MLM is MLM
#' dimX <- c(4, 8, 6, 3)
#' dimA <- c(3, 17, 19, 2)
#' dimB <- c(7, 11, 13, 5)
#' X <- array(rnorm(prod(dimX)), dim = dimX)
#' As <- Map(function(p, q) matrix(rnorm(p * q), p, q), dimA, dimX)
#' Bs <- Map(function(p, q) matrix(rnorm(p * q), p, q), dimB, dimA)
#' # (X x {A1, A2, A3, A4}) x {B1, B2, B3, B4} = X x {B1 A1, B2 A2, B3 A3, B4 A4}
#' all.equal(mlm(mlm(X, As), Bs), mlm(X, Map(`%*%`, Bs, As)))
#'
#' # Equivalent to
#' mlm_reference <- function(A, Bs, modes = seq_along(Bs), transposed = FALSE) {
#'     # Collect all matrices in `B`
#'     Bs <- if (is.matrix(Bs)) list(Bs) else Bs
#'
#'     # replicate transposition if of length one only
#'     transposed <- if (length(transposed) == 1) {
#'         rep(as.logical(transposed), length(Bs))
#'     } else {
#'         as.logical(transposed)
#'     }
#'
#'     # iteratively apply Tensor Times Matrix multiplication over modes
#'     for (i in seq_along(modes)) {
#'         A <- ttm(A, Bs[[i]], modes[i], transposed[i])
#'     }
#'
#'     # return result tensor
#'     A
#' }
#'
#' @export
mlm <- function(A, Bs, modes = seq_along(Bs), transposed = FALSE) {
    # Collect all matrices in `Bs`.
    # NOTE: this must happen BEFORE `modes` is first evaluated. Its default
    # `seq_along(Bs)` is a lazily evaluated promise, so it sees the wrapped
    # list (a single matrix yields mode 1, not one mode per matrix entry).
    Bs <- if (!is.list(Bs)) list(Bs) else Bs

    # ensure all `B`s are matrices
    Bs <- Map(as.matrix, Bs)

    # every matrix needs exactly one multiplication axis; validate here so
    # that inconsistent input never reaches the compiled routine
    modes <- as.integer(modes)
    if (length(modes) != length(Bs)) {
        stop("Dim mismatch of param. `modes` and `Bs`")
    }
    if (anyNA(modes) || (!is.null(dim(A)) &&
            any(modes < 1L | modes > length(dim(A))))) {
        stop("Param. `modes` out of range")
    }

    # replicate transposition if of length one only
    transposed <- if (length(transposed) == 1L) {
        rep(as.logical(transposed), length(Bs))
    } else if (length(transposed) == length(Bs)) {
        as.logical(transposed)
    } else {
        stop("Dim missmatch of param. `transposed`")
    }

    .Call("C_mlm", A, Bs, modes, transposed, PACKAGE = "tensorPredictors")
}
 | 
						|
 | 
						|
 | 
						|
# # general usage
# # (`mlm_reference` is the reference implementation from the `mlm` examples)
# dimA <- c(3, 17, 19, 2)
# dimC <- c(7, 11, 13, 5)
# A <- array(rnorm(prod(dimA)), dim = dimA)
# trans <- c(TRUE, FALSE, TRUE, FALSE)
# Bs <- Map(function(p, q) matrix(rnorm(p * q), p, q), ifelse(trans, dimA, dimC), ifelse(trans, dimC, dimA))

# C <- mlm(A, Bs, transposed = trans)
# mlm(A, Bs[c(3, 2)], modes = c(3, 2), transposed = trans[c(3, 2)])

# microbenchmark::microbenchmark(
#     mlm(A, Bs, transposed = trans),
#     mlm_reference(A, Bs, transposed = trans)
# )

# microbenchmark::microbenchmark(
#     mlm(A, Bs[c(3, 2)], modes = c(3, 2), transposed = trans[c(3, 2)]),
#     mlm_reference(A, Bs[c(3, 2)], modes = c(3, 2), transposed = trans[c(3, 2)])
# )
 |