#include "merge_matmul.h"

#include <stddef.h>

/**
 * Batched matrix product over the third axis of column-major 3D arrays.
 *
 * For every slice k in [0, nslices) this computes
 *     C[, , k] = alpha * A[, , k] %*% B[, , k] + beta * C[, , k]
 * with one call to the Fortran BLAS routine dgemm (no transposition).
 *
 * @param nrowA   rows of each A slice (also rows of each C slice)
 * @param ncolA   columns of each A slice; the caller must ensure ncolA == nrowB
 * @param nslices number of slices along the third axis
 * @param nrowB   rows of each B slice
 * @param ncolB   columns of each B slice (also columns of each C slice)
 * @param alpha   scalar multiplier for A %*% B
 * @param A       nrowA * ncolA * nslices doubles, column-major, slice after slice
 * @param B       nrowB * ncolB * nslices doubles, column-major, slice after slice
 * @param beta    scalar multiplier applied to the existing contents of C
 * @param C       nrowA * ncolB * nslices doubles; overwritten with the result
 */
void merge_matmul(
    const int nrowA, const int ncolA, const int nslices,
    const int nrowB, const int ncolB,
    const double alpha, const double* A, const double* B,
    const double beta, double* C
) {
    /* Per-slice strides in elements. They are computed in size_t so that
       k * stride cannot overflow int for arrays with more than INT_MAX
       elements (R long vectors). */
    const size_t strideA = (size_t)nrowA * (size_t)ncolA;
    const size_t strideB = (size_t)nrowB * (size_t)ncolB;
    const size_t strideC = (size_t)nrowA * (size_t)ncolB;

    for (int k = 0; k < nslices; ++k) {
        const size_t slice = (size_t)k;
        /* Leading dimensions equal the slice row counts because each slice
           is stored contiguously in column-major order. */
        F77_CALL(dgemm)("N", "N", &nrowA, &ncolB, &ncolA, &alpha,
                        A + slice * strideA, &nrowA,
                        B + slice * strideB, &nrowB,
                        &beta, C + slice * strideC, &nrowA FCONE FCONE);
    }
}

/* Non-zero when x's storage can be meaningfully coerced to double:
   double, integer (factors excluded by Rf_isInteger) or logical. */
static int has_numeric_storage(SEXP x) {
    return Rf_isReal(x) || Rf_isInteger(x) || Rf_isLogical(x);
}

/**
 * Apply matrix-matrix multiplication over the 3rd axis of 3D arrays.
 *
 * For dim(A) = c(n, m, s) and dim(B) = c(m, p, s), returns a double array
 * C with dim(C) = c(n, p, s) and C[, , k] = A[, , k] %*% B[, , k].
 * Integer and logical inputs are coerced to double first.
 *
 * @param A 3D array with double, integer or logical storage
 * @param B 3D array with double, integer or logical storage
 * @returns 3D double array
 */
extern SEXP R_merge_matmul(SEXP A, SEXP B) {
    if (!(Rf_isArray(A) && Rf_isArray(B))) {
        Rf_error("Both A and B must be 3D arrays");
    }

    /* Rf_isArray guarantees an integer dim attribute of positive length. */
    SEXP dimA = Rf_getAttrib(A, R_DimSymbol);
    SEXP dimB = Rf_getAttrib(B, R_DimSymbol);
    if (Rf_length(dimA) != 3 || Rf_length(dimB) != 3) {
        Rf_error("Both A and B must be 3D arrays");
    }

    const int nrowA = INTEGER(dimA)[0];
    const int ncolA = INTEGER(dimA)[1];
    const int nslices = INTEGER(dimA)[2];
    const int nrowB = INTEGER(dimB)[0];
    const int ncolB = INTEGER(dimB)[1];

    if (nrowA == 0 || ncolA == 0 || nslices == 0 || nrowB == 0 || ncolB == 0) {
        Rf_error("Zero dimension detected");
    }
    if (nslices != INTEGER(dimB)[2]) {
        Rf_error("Dimension mismatch dim(A)[3] != dim(B)[3]");
    }
    if (ncolA != nrowB) {
        Rf_error("Dimension mismatch ncol(A) != nrow(B)");
    }

    /* REAL() is only valid on double storage; Rf_isArray alone also admits
       character, complex and list arrays. */
    if (!has_numeric_storage(A) || !has_numeric_storage(B)) {
        Rf_error("Both A and B must be numeric arrays");
    }

    /* The product of three int dimensions can exceed both INT_MAX and the
       largest R vector length, so check it in double before converting. */
    const double nC = (double)nrowA * (double)ncolB * (double)nslices;
    if (nC > (double)R_XLEN_T_MAX) {
        Rf_error("Result array is too large");
    }
    const R_xlen_t sizeC = (R_xlen_t)nrowA * ncolB * nslices;

    /* Rf_coerceVector returns its argument unchanged when it is already
       double, so this costs nothing for the common case. */
    SEXP realA = PROTECT(Rf_coerceVector(A, REALSXP));
    SEXP realB = PROTECT(Rf_coerceVector(B, REALSXP));

    SEXP C = PROTECT(Rf_allocVector(REALSXP, sizeC));
    SEXP dimC = PROTECT(Rf_allocVector(INTSXP, 3));
    INTEGER(dimC)[0] = nrowA;
    INTEGER(dimC)[1] = ncolB;
    INTEGER(dimC)[2] = nslices;
    Rf_setAttrib(C, R_DimSymbol, dimC);

    /* beta = 0: dgemm overwrites C, so its uninitialised contents are not used. */
    merge_matmul(nrowA, ncolA, nslices, nrowB, ncolB,
                 1.0, REAL(realA), REAL(realB), 0.0, REAL(C));

    /* Drop local protection; C stays reachable because it is returned to R. */
    UNPROTECT(4);
    return C;
}