35 lines
872 B
R
35 lines
872 B
R
#' Slice-wise matrix multiplication along the third axis of 3D arrays
#'
#' Multiplies the matching slices of two 3D arrays, i.e. the `i`-th slice of
#' the result is `A[, , i] %*% B[, , i]` (see the reference implementation in
#' the examples).
#'
#' @param A 3D numeric array; its slices `A[, , i]` are the left factors.
#' @param B 3D numeric array; its slices `B[, , i]` are the right factors.
#' @returns 3D numeric array holding the slice-wise matrix products.
#'
#' @examples
#' # Equivalent to the reference implementation
#' merge.matmul.reference <- function(A, B) {
#'   C <- array(dim = c(nrow(A), ncol(B), dim(A)[3]))
#'   for (i in seq_len(dim(A)[3])) {
#'     C[, , i] <- A[, , i] %*% B[, , i]
#'   }
#'   C
#' }
#'
#' dimA <- c(3, 5, 101)
#' dimB <- c(5, 2, 101)
#' A <- array(rnorm(prod(dimA)), dim = dimA)
#' B <- array(rnorm(prod(dimB)), dim = dimB)
#'
#' C <- merge.matmul(A, B)
#' dim(C) # c(3, 2, 101)
#'
#' all.equal(
#'   merge.matmul.reference(A, B),
#'   merge.matmul(A, B)
#' )
#'
#' @export
merge.matmul <- function(A, B) {
  # Switch both operands to double storage (B first, then A) before they are
  # handed over to the compiled routine.
  storage.mode(B) <- "double"
  storage.mode(A) <- "double"

  # The actual computation is delegated to the package's native code.
  .Call("C_merge_matmul", A, B, PACKAGE = "tensorPredictors")
}
|