82 lines
2.2 KiB
C

#include "merge_matmul.h"
#include <stddef.h>
/**
 * Slice-wise matrix-matrix multiplication of two 3D arrays.
 *
 * For every slice k in [0, nslices) one BLAS dgemm call (no transposition
 * of either operand) updates
 *
 *     C_k = alpha * A_k * B_k + beta * C_k
 *
 * where A_k, B_k and C_k are the k-th slices of A, B and C. Slices are
 * addressed as consecutive blocks of nrowA * ncolA, nrowB * ncolB and
 * nrowA * ncolB doubles respectively, with leading dimensions nrowA, nrowB
 * and nrowA.
 *
 * @param nrowA   number of rows of each slice of A (and of C)
 * @param ncolA   number of columns of each slice of A; also used as the
 *                inner dimension of the product, so callers are expected
 *                to pass ncolA == nrowB
 * @param nslices number of slices in A, B and C
 * @param nrowB   number of rows of each slice of B (leading dimension of B)
 * @param ncolB   number of columns of each slice of B (and of C)
 * @param alpha   scalar multiplier of the product A_k * B_k
 * @param A       input array of nrowA * ncolA * nslices doubles
 * @param B       input array of nrowB * ncolB * nslices doubles
 * @param beta    scalar multiplier of the previous content of C
 * @param C       output array of nrowA * ncolB * nslices doubles
 */
void merge_matmul(
    const int nrowA, const int ncolA, const int nslices,
    const int nrowB, const int ncolB,
    const double alpha,
    const double* A, const double* B,
    const double beta,
    double* C
) {
    // Slice strides are computed in size_t: the element offset
    // k * nrow * ncol can exceed INT_MAX for large arrays even though every
    // individual dimension fits into an int, and signed int overflow is
    // undefined behaviour.
    const size_t strideA = (size_t)nrowA * (size_t)ncolA;
    const size_t strideB = (size_t)nrowB * (size_t)ncolB;
    const size_t strideC = (size_t)nrowA * (size_t)ncolB;

    for (int k = 0; k < nslices; ++k) {
        const size_t slice = (size_t)k;
        F77_CALL(dgemm)("N", "N", &nrowA, &ncolB, &ncolA,
                        &alpha,
                        A + slice * strideA, &nrowA,
                        B + slice * strideB, &nrowB,
                        &beta,
                        C + slice * strideC, &nrowA
                        FCONE FCONE);
    }
}
/**
 * Apply matrix-matrix multiplication over 3rd axis of 3D arrays
 *
 * For every index k along the third axis the k-th slice of the result is
 * the matrix product of the k-th slice of A with the k-th slice of B.
 *
 * Raises an R error when
 *   - A or B is not an array with exactly three dimensions,
 *   - A or B is not stored as double (REAL() is applied to both),
 *   - any of dim(A)[1:3] or dim(B)[1:2] is zero,
 *   - dim(A)[3] != dim(B)[3] or dim(A)[2] != dim(B)[1].
 *
 * @param A 3D array of doubles with dimension (n, m, s)
 * @param B 3D array of doubles with dimension (m, p, s)
 * @returns 3D array of doubles with dimension (n, p, s)
 */
extern SEXP R_merge_matmul(SEXP A, SEXP B) {
    // Check if both are arrays (i.e. carry a dim attribute)
    if (!(Rf_isArray(A) && Rf_isArray(B))) {
        Rf_error("Both A and B must be 3D arrays");
    }
    // REAL() is used on both inputs below, so they must hold doubles;
    // integer or logical arrays are rejected with an explicit message
    if (!(Rf_isReal(A) && Rf_isReal(B))) {
        Rf_error("Both A and B must be numeric (double) arrays");
    }

    // get object dimensions
    SEXP dimA = Rf_getAttrib(A, R_DimSymbol);
    SEXP dimB = Rf_getAttrib(B, R_DimSymbol);

    // Check that both arrays have exactly three axes
    if ((Rf_length(dimA) != 3) || (Rf_length(dimB) != 3)) {
        Rf_error("Both A and B must be 3D arrays");
    }

    // Extract dimensions
    const int nrowA = INTEGER(dimA)[0];
    const int ncolA = INTEGER(dimA)[1];
    const int nslices = INTEGER(dimA)[2];
    const int nrowB = INTEGER(dimB)[0];
    const int ncolB = INTEGER(dimB)[1];
    const int nslicesB = INTEGER(dimB)[2];

    // Validate dimension compatibility
    if (!nrowA || !ncolA || !nslices || !nrowB || !ncolB) {
        Rf_error("Zero dimension detected");
    }
    if (nslices != nslicesB) {
        Rf_error("Dimension misspatch dim(A)[3] != dim(B)[3]");
    }
    if (ncolA != nrowB) {
        Rf_error("Dimension misspatch ncol(A) != nrow(B)");
    }

    // create response object C; Rf_alloc3DArray allocates the data and sets
    // the dim attribute, so the element count is never formed as an int
    // product (which could overflow) in this function
    SEXP C = PROTECT(Rf_alloc3DArray(REALSXP, nrowA, ncolB, nslices));

    // Apply matrix-matrix multiplication over 3rd array axis
    // (alpha = 1, beta = 0: C is overwritten with the plain products)
    merge_matmul(nrowA, ncolA, nslices, nrowB, ncolB,
                 1.0, REAL(A), REAL(B), 0.0, REAL(C));

    // release C to the garbage collector
    UNPROTECT(1);
    return C;
}