tensor_predictors/tensorPredictors/src/mcrossprod.c

89 lines
2.7 KiB
C

// `USE_FC_LEN_T` and `FCONE` are required to correctly pass the hidden length
// arguments of Fortran character strings (e.g. the "U"/"N" flags of BLAS
// routines) from C. See: Writing R Extensions, 6.6.1 Fortran character strings
#define USE_FC_LEN_T
#include <limits.h>
#include <string.h>
#include <R.h>
#include <Rinternals.h>
#include <R_ext/BLAS.h>
#ifndef FCONE
#define FCONE
#endif
/**
 * Tensor Mode Crossproduct
 *
 * C = A_(m) t(A_(m))
 *
 * For a matrix `A`, the first mode `mcrossprod(A, 1)` is equivalent to
 * `A %*% t(A)` (`tcrossprod`). On the other hand, for mode two
 * `mcrossprod(A, 2)` the equivalence is `t(A) %*% A` (`crossprod`).
 *
 * @param A multi-dimensional array of type double, integer or logical
 *  (integer and logical arrays are coerced to double)
 * @param m mode index (1-indexed), must satisfy `1 <= m <= length(dim(A))`
 * @return symmetric `dim(A)[m] x dim(A)[m]` double matrix; if `A` has a zero
 *  extent in any mode the result is the zero matrix of that size
 */
extern SEXP mcrossprod(SEXP A, SEXP m) {
    // 1-indexed mode, validated *before* shifting to a 0-indexed mode since
    // `NA_INTEGER - 1` (NA_INTEGER == INT_MIN) would be signed overflow (UB)
    int mode = asInteger(m);
    // dimension attribute of A (R_NilValue, of length 0, for plain vectors)
    SEXP dim = getAttrib(A, R_DimSymbol);
    if (mode == NA_INTEGER || mode < 1 || length(dim) < mode) {
        error("Illegal mode");
    }
    mode -= 1;

    int nprotect = 0;
    // BLAS works on doubles, coerce integer and logical arrays
    SEXP Ad = A;
    switch (TYPEOF(A)) {
    case REALSXP:
        break;
    case INTSXP:
    case LGLSXP:
        Ad = PROTECT(coerceVector(A, REALSXP));
        ++nprotect;
        break;
    default:
        error("Argument 'A' must be a numeric array");
    }

    const int* dims = INTEGER(dim);
    const int rank = length(dim);
    const int n = dims[mode];

    // create response matrix C, zero initialized since the BLAS calls below
    // accumulate into it (beta = 1)
    SEXP C = PROTECT(allocMatrix(REALSXP, n, n));
    ++nprotect;
    double* c = REAL(C);
    if (n > 0) {
        memset(c, 0, (size_t)n * (size_t)n * sizeof(double));
    }
    // A zero extent in any mode makes every entry of C an empty sum (= 0).
    // Return early, BLAS rejects the resulting zero leading dimensions.
    if (XLENGTH(Ad) == 0) {
        UNPROTECT(nprotect);
        return C;
    }

    // View A as a column-major `rows x n x cols` array where
    //   rows = prod(dim(A)[seq_len(mode - 1)])
    //   n    = dim(A)[mode]
    //   cols = prod(dim(A)[-seq_len(mode)])
    // Accumulated in R_xlen_t as the products may exceed the `int` range for
    // long vectors (A is non-empty here, so rows, cols <= XLENGTH(A)).
    R_xlen_t rows = 1, cols = 1;
    for (int i = 0; i < rank; ++i) {
        if (i < mode) {
            rows *= dims[i];
        } else if (i > mode) {
            cols *= dims[i];
        }
    }

    // The contracted BLAS dimension `k` (also `lda` for modes > 1) is a
    // Fortran integer and must fit into an `int`
    const R_xlen_t inner = (mode == 0) ? cols : rows;
    if (inner > INT_MAX) {
        error("Tensor dimensions too large for BLAS");
    }
    const int k = (int)inner;

    double* a = REAL(Ad);
    const double one = 1.0;
    // BLAS dsyrk (Double SYmmetric Rank K update):
    //   C = alpha A A^T + beta C   ("N")  or   C = alpha A^T A + beta C   ("T")
    if (mode == 0) {
        // mode 1: A_(1) is A itself as a column-major n x cols matrix
        // C = 1 A A^T + 1 C
        F77_CALL(dsyrk)("U", "N", &n, &k,
            &one, a, &n, &one, c, &n FCONE FCONE);
    } else {
        // Other modes written as accumulated sum over the trailing modes:
        // each slice s is a column-major rows x n matrix B_s (lda = rows)
        // starting at offset s * rows * n, C = sum_s B_s^T B_s
        const R_xlen_t slice = rows * (R_xlen_t)n;
        for (R_xlen_t s = 0; s < cols; ++s) {
            // C = 1 B_s^T B_s + 1 C
            F77_CALL(dsyrk)("U", "T", &n, &k,
                &one, a + s * slice, &k, &one, c, &n FCONE FCONE);
        }
    }

    // dsyrk only updates the upper triangle ("U"), mirror it into the lower
    // triangle; indices in R_xlen_t as `i * n` may exceed the `int` range
    for (int j = 0; j < n; ++j) {
        for (int i = j + 1; i < n; ++i) {
            c[i + (R_xlen_t)j * n] = c[j + (R_xlen_t)i * n];
        }
    }

    // release protected objects to the garbage collector
    UNPROTECT(nprotect);
    return C;
}