tensor_predictors/tensorPredictors/R/merge_matmul.R

35 lines
872 B
R

#' Batched matrix-matrix multiplication along the third axis of 3D arrays
#'
#' For a 3D array \code{A} of dimension \code{c(p, q, n)} and a 3D array
#' \code{B} of dimension \code{c(q, r, n)}, computes the 3D array \code{C} of
#' dimension \code{c(p, r, n)} whose \code{i}-th slice \code{C[, , i]} is the
#' matrix product of the slices \code{A[, , i]} and \code{B[, , i]}.
#'
#' @param A 3D numeric array of dimension \code{c(p, q, n)}. Integer or
#'   logical storage is coerced to double.
#' @param B 3D numeric array of dimension \code{c(q, r, n)}. Integer or
#'   logical storage is coerced to double.
#' @returns 3D numeric array of dimension \code{c(p, r, n)}.
#'
#' @examples
#' # Equivalent to the reference implementation
#' merge.matmul.reference <- function(A, B) {
#'   C <- array(dim = c(nrow(A), ncol(B), dim(A)[3]))
#'   for (i in seq_len(dim(A)[3])) {
#'     C[, , i] <- A[, , i] %*% B[, , i]
#'   }
#'   C
#' }
#'
#' dimA <- c(3, 5, 101)
#' dimB <- c(5, 2, 101)
#' A <- array(rnorm(prod(dimA)), dim = dimA)
#' B <- array(rnorm(prod(dimB)), dim = dimB)
#'
#' C <- merge.matmul(A, B)
#' dim(C) # c(3, 2, 101)
#'
#' all.equal(
#'   merge.matmul.reference(A, B),
#'   merge.matmul(A, B)
#' )
#'
#' @export
merge.matmul <- function(A, B) {
  # NOTE(review): the dotted name looks like an S3 method of base::merge()
  # for class "matmul"; confirm roxygen2 emits export(merge.matmul) rather
  # than only S3method(merge, matmul) in NAMESPACE.

  # The arrays are handed to compiled code, so enforce the documented shape
  # contract here; whether C_merge_matmul checks it itself is not visible
  # from this file.
  dimA <- dim(A)
  dimB <- dim(B)
  if (length(dimA) != 3L || length(dimB) != 3L) {
    stop("`A` and `B` must be 3D arrays", call. = FALSE)
  }
  if (dimA[2L] != dimB[1L]) {
    stop("non-conformable arguments: ncol(A) != nrow(B)", call. = FALSE)
  }
  if (dimA[3L] != dimB[3L]) {
    stop(
      "`A` and `B` must have the same extent along the third axis",
      call. = FALSE
    )
  }

  # Presumably the C routine expects REAL (double) storage; coerce
  # integer/logical input so it never sees another storage mode.
  storage.mode(A) <- "double"
  storage.mode(B) <- "double"

  .Call("C_merge_matmul", A, B, PACKAGE = "tensorPredictors")
}