73 lines
2.1 KiB
R
73 lines
2.1 KiB
R
#' Tensor Times Matrix (n-mode tensor matrix product)
#'
#' Multiplies the tensor \code{T} along its \code{mode}-th axis with the
#' matrix \code{M}: every mode-\code{mode} fiber of \code{T} (a vector of
#' length \code{dim(T)[mode]}) is replaced by the matrix-vector product of
#' \code{M} with that fiber.
#'
#' @param T array of order at least \code{mode}
#' @param M matrix, the right hand side of the mode product such that
#'  \code{ncol(M)} equals \code{dim(T)[mode]}
#' @param mode the mode of the product in the range \code{1:length(dim(T))},
#'  a single value coerced with \code{as.integer}; defaults to the last mode
#'
#' @returns multi-dimensional array of the same order as \code{T} with the
#'  \code{mode} dimension equal to \code{nrow(M)}
#'
#' @examples
#' A <- array(rnorm(24), dim = c(2, 3, 4))
#' B <- matrix(rnorm(15), nrow = 5, ncol = 3)
#' dim(ttm(A, B, 2)) # 2 5 4
#'
#' @export
ttm <- function(T, M, mode = length(dim(T))) {
  dims <- dim(T)

  # Validate the inputs up front; without these checks invalid arguments only
  # fail later with unrelated messages (e.g. "argument is of length zero").
  if (is.null(dims)) {
    stop("T must be an array (an object with a 'dim' attribute)")
  }
  # Checked through `dim` (not `is.matrix`) so that any two-dimensional
  # matrix-like object providing `dim` is still accepted.
  if (length(dim(M)) != 2L) {
    stop("M must be a matrix")
  }
  if (length(mode) != 1L) {
    stop("Mode must be a single integer")
  }
  mode <- as.integer(mode)
  if (is.na(mode) || mode < 1L) {
    stop(sprintf("Mode (%d) must be a positive integer", mode))
  }

  n_modes <- length(dims)
  if (n_modes < mode) {
    stop(sprintf(
      "Mode (%d) must be smaller than or equal to the tensor order %d",
      mode, n_modes
    ))
  }
  if (dims[mode] != ncol(M)) {
    stop(sprintf(
      "Dim. mismatch, mode %d has dim %d but ncol is %d.",
      mode, dims[mode], ncol(M)
    ))
  }

  # Special case of mode being equal to tensor order, then an alternative
  # (but more efficient) version is Z M' where Z is only the reshaped tensor
  # and no permutation of elements is required (as in the case of mode 1)
  if (mode == n_modes) {
    # Convert tensor to matrix (similar to matricization); the last axis
    # becomes the columns, all leading axes are collapsed into the rows
    dim(T) <- c(prod(dims[-mode]), dims[mode])

    # Equivalent matrix product Z M'
    C <- tcrossprod(T, M)

    # Shape back to tensor
    dim(C) <- c(dims[-mode], nrow(M))

    C
  } else {
    # Matricize tensor T: bring axis `mode` to the front (nothing to do for
    # mode 1) and collapse the remaining axes into the columns
    if (mode != 1L) {
      perm <- c(mode, seq_along(dims)[-mode])
      T <- aperm(T, perm)
    }
    dim(T) <- c(dims[mode], prod(dims[-mode]))

    # Perform equivalent matrix multiplication
    C <- M %*% T

    # Reshape and rearrange matricized result back to a tensor
    dim(C) <- c(nrow(M), dims[-mode])
    if (mode == 1L) {
      C
    } else {
      # `order(perm)` is the inverse permutation, restoring the axis order
      aperm(C, order(perm))
    }
  }
}
|
|
|
|
#' @rdname ttm
#' @export
`%x_1%` <- function(T, M) {
  # Infix shorthand for the mode-1 product `ttm(T, M, mode = 1L)`.
  ttm(T, M, mode = 1L)
}
|
|
#' @rdname ttm
#' @export
`%x_2%` <- function(T, M) {
  # Infix shorthand for the mode-2 product `ttm(T, M, mode = 2L)`.
  ttm(T, M, mode = 2L)
}
|
|
#' @rdname ttm
#' @export
`%x_3%` <- function(T, M) {
  # Infix shorthand for the mode-3 product `ttm(T, M, mode = 3L)`.
  ttm(T, M, mode = 3L)
}
|
|
#' @rdname ttm
#' @export
`%x_4%` <- function(T, M) {
  # Infix shorthand for the mode-4 product `ttm(T, M, mode = 4L)`.
  ttm(T, M, mode = 4L)
}
|