82 lines
2.2 KiB
C
82 lines
2.2 KiB
C
#include "merge_matmul.h"

#include <stddef.h>
|
|
|
|
void merge_matmul(
|
|
const int nrowA, const int ncolA, const int nslices,
|
|
const int nrowB, const int ncolB,
|
|
const double alpha,
|
|
const double* A, const double* B,
|
|
const double beta,
|
|
double* C
|
|
) {
|
|
for (int k = 0; k < nslices; ++k) {
|
|
F77_CALL(dgemm)("N", "N", &nrowA, &ncolB, &ncolA,
|
|
&alpha,
|
|
&A[k * nrowA * ncolA], &nrowA,
|
|
&B[k * nrowB * ncolB], &nrowB,
|
|
&beta,
|
|
&C[k * nrowA * ncolB], &nrowA
|
|
FCONE FCONE);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Apply matrix-matrix multiplication over 3rd axis of 3D arrays
|
|
*
|
|
* @param A 3D array
|
|
* @param B 3D array
|
|
* @returns 3D array
|
|
*/
|
|
extern SEXP R_merge_matmul(SEXP A, SEXP B) {
|
|
|
|
// Check if both are 3D arrays
|
|
if (!(Rf_isArray(A) && Rf_isArray(B))) {
|
|
Rf_error("Both A and B must be 3D arrays");
|
|
}
|
|
|
|
// get object dimensions
|
|
SEXP dimA = Rf_getAttrib(A, R_DimSymbol);
|
|
SEXP dimB = Rf_getAttrib(B, R_DimSymbol);
|
|
|
|
// Check dimension compatibility
|
|
if ((Rf_length(dimA) != 3) || (Rf_length(dimB) != 3)) {
|
|
Rf_error("Both A and B must be 3D arrays");
|
|
}
|
|
|
|
// Extract dimensions
|
|
int nrowA = INTEGER(dimA)[0];
|
|
int ncolA = INTEGER(dimA)[1];
|
|
int nslices = INTEGER(dimA)[2];
|
|
int nrowB = INTEGER(dimB)[0];
|
|
int ncolB = INTEGER(dimB)[1];
|
|
|
|
// Validate dimension compatibility
|
|
if (!nrowA | !ncolA | !nslices | !nrowB | !ncolB) {
|
|
Rf_error("Zero dimension detected");
|
|
}
|
|
if (nslices != INTEGER(dimB)[2]) {
|
|
Rf_error("Dimension misspatch dim(A)[3] != dim(B)[3]");
|
|
}
|
|
if (ncolA != nrowB) {
|
|
Rf_error("Dimension misspatch ncol(A) != nrow(B)");
|
|
}
|
|
|
|
// create response object C
|
|
int sizeC = nrowA * ncolB * nslices;
|
|
SEXP C = PROTECT(Rf_allocVector(REALSXP, sizeC));
|
|
SEXP dimC = PROTECT(Rf_allocVector(INTSXP, 3));
|
|
INTEGER(dimC)[0] = nrowA;
|
|
INTEGER(dimC)[1] = ncolB;
|
|
INTEGER(dimC)[2] = nslices;
|
|
Rf_setAttrib(C, R_DimSymbol, dimC);
|
|
|
|
// Apply matxir-matrix multiplication over 3rd array axis
|
|
merge_matmul(nrowA, ncolA, nslices, nrowB, ncolB,
|
|
1.0, REAL(A), REAL(B), 0.0, REAL(C)
|
|
);
|
|
|
|
// release C to the garbage collector
|
|
UNPROTECT(2);
|
|
|
|
return C;
|
|
}
|