forgot to add

This commit is contained in:
Daniel Kapla 2022-05-06 22:28:08 +02:00
parent b8064c90ef
commit 36dd08c7c9
3 changed files with 16 additions and 10 deletions

View File

@ -13,6 +13,7 @@ export(approx.kronecker)
export(colKronecker) export(colKronecker)
export(dist.projection) export(dist.projection)
export(dist.subspace) export(dist.subspace)
export(kpir.approx)
export(kpir.base) export(kpir.base)
export(kpir.kron) export(kpir.kron)
export(kpir.momentum) export(kpir.momentum)
@ -20,6 +21,7 @@ export(kpir.new)
export(mat) export(mat)
export(matpow) export(matpow)
export(matrixImage) export(matrixImage)
export(mcrossprod)
export(reduce) export(reduce)
export(rowKronecker) export(rowKronecker)
export(tensor_predictor) export(tensor_predictor)

View File

@ -9,10 +9,14 @@
/* Tensor Times Matrix a.k.a. Mode Product */ /* Tensor Times Matrix a.k.a. Mode Product */
extern SEXP ttm(SEXP A, SEXP X, SEXP mode); extern SEXP ttm(SEXP A, SEXP X, SEXP mode);
/* Tensor Mode Covariance e.g. `(1 / n) * A_(m) A_(m)^T` */
extern SEXP mcrossprod(SEXP A, SEXP mode);
/* List of registered routines (e.g. C entry points) */ /* List of registered routines (e.g. C entry points) */
static const R_CallMethodDef CallEntries[] = { static const R_CallMethodDef CallEntries[] = {
// {"FastPOI_C_sub", (DL_FUNC) &FastPOI_C_sub, 5}, // NOT USED // {"FastPOI_C_sub", (DL_FUNC) &FastPOI_C_sub, 5}, // NOT USED
{"C_ttm", (DL_FUNC) &ttm, 3}, {"C_ttm", (DL_FUNC) &ttm, 3},
{"C_mcrossprod", (DL_FUNC) &mcrossprod, 2},
{NULL, NULL, 0} {NULL, NULL, 0}
}; };

View File

@ -11,7 +11,7 @@
/** /**
* Tensor Times Matrix a.k.a. Mode Product * Tensor Times Matrix a.k.a. Mode Product
* *
* @param A multi-dimensionl array * @param A multi-dimensional array
* @param B matrix * @param B matrix
* @param m mode index (1-indexed) * @param m mode index (1-indexed)
*/ */
@ -41,12 +41,12 @@ extern SEXP ttm(SEXP A, SEXP B, SEXP m) {
error("Dimension missmatch (mode dim not equal to ncol)"); error("Dimension missmatch (mode dim not equal to ncol)");
} }
// calc nr of response elements `prod(dim(X)[-mode]) * ncol(X)`, // calc nr of response elements `prod(dim(A)[-mode]) * ncol(A)` (size of C),
int leny = 1; int sizeC = 1;
// and the strides // and the strides
// `stride[0] <- prod(dim(X)[seq_len(mode - 1)])` // `stride[0] <- prod(dim(A)[seq_len(mode - 1)])`
// `stride[1] <- dim(X)[mode]` // `stride[1] <- dim(A)[mode]`
// `stride[2] <- prod(dim(X)[-seq_len(mode)])` // `stride[2] <- prod(dim(A)[-seq_len(mode)])`
int stride[3] = {1, INTEGER(dim)[mode], 1}; int stride[3] = {1, INTEGER(dim)[mode], 1};
for (int i = 0; i < length(dim); ++i) { for (int i = 0; i < length(dim); ++i) {
int size = INTEGER(dim)[i]; int size = INTEGER(dim)[i];
@ -54,7 +54,7 @@ extern SEXP ttm(SEXP A, SEXP B, SEXP m) {
if (!size) { if (!size) {
error("Zero dimension detected"); error("Zero dimension detected");
} }
leny *= (i == mode) ? nrows(B) : size; sizeC *= (i == mode) ? nrows(B) : size;
stride[0] *= (i < mode) ? size : 1; stride[0] *= (i < mode) ? size : 1;
stride[2] *= (i > mode) ? size : 1; stride[2] *= (i > mode) ? size : 1;
} }
@ -63,7 +63,7 @@ extern SEXP ttm(SEXP A, SEXP B, SEXP m) {
int nrow = nrows(B); int nrow = nrows(B);
// create response object C // create response object C
SEXP C = PROTECT(allocVector(REALSXP, leny)); SEXP C = PROTECT(allocVector(REALSXP, sizeC));
// raw data access pointers // raw data access pointers
double* a = REAL(A); double* a = REAL(A);
@ -88,7 +88,7 @@ extern SEXP ttm(SEXP A, SEXP B, SEXP m) {
} }
/* /*
// Tensor Times Matrix / Mode Product (reference implementation) // Tensor Times Matrix / Mode Product (reference implementation)
memset(c, 0, leny * sizeof(double)); memset(c, 0, sizeC * sizeof(double));
for (int i2 = 0; i2 < stride[2]; ++i2) { for (int i2 = 0; i2 < stride[2]; ++i2) {
for (int i1 = 0; i1 < stride[1]; ++i1) { // stride[1] == ncols(B) for (int i1 = 0; i1 < stride[1]; ++i1) { // stride[1] == ncols(B)
for (int j = 0; j < nrow; ++j) { for (int j = 0; j < nrow; ++j) {