137 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
			
		
		
	
	
			137 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
#include "ttm.h"

#include <limits.h>
#include <stddef.h>
 | 
						|
 | 
						|
void ttm(
 | 
						|
    const int transB, const int mode,
 | 
						|
    const int* dimA, const int ordA, const int nrowB, const int ncolB,
 | 
						|
    const double alpha,
 | 
						|
    const double* A,
 | 
						|
    const double* B, const int ldB, // TODO: ldB is IGNORED!!!
 | 
						|
    const double beta,
 | 
						|
    double* C
 | 
						|
) {
 | 
						|
 | 
						|
    // Strides are the "leading" and "trailing" dimensions of the matricized
 | 
						|
    // tensor `A` in the following matrix-matrix multiplications
 | 
						|
    //  `stride[0] <- prod(dim(A)[seq_len(mode - 1)])`
 | 
						|
    //  `stride[1] <- dim(A)[mode]`
 | 
						|
    //  `stride[2] <- prod(dim(A)[-seq_len(mode)])`
 | 
						|
    int stride[3] = {1, dimA[mode], 1};
 | 
						|
    for (int i = 0; i < ordA; ++i) {
 | 
						|
        stride[0] *= (i < mode) ? dimA[i] : 1;
 | 
						|
        stride[2] *= (i > mode) ? dimA[i] : 1;
 | 
						|
    }
 | 
						|
 | 
						|
    if (mode == 0) {
 | 
						|
        // mode 1: C = alpha (A x_1 op(B))_(1) + beta C
 | 
						|
        //           = alpha op(B) A_(1) + beta C
 | 
						|
        // as a single Matrix-Matrix multiplication
 | 
						|
        F77_CALL(dgemm)(transB ? "T" : "N", "N",
 | 
						|
            (transB ? &ncolB : &nrowB), &stride[2], &stride[1], &alpha,
 | 
						|
            B, &nrowB, A, &stride[1],
 | 
						|
            &beta, C, (transB ? &ncolB : &nrowB)
 | 
						|
            FCONE FCONE);
 | 
						|
    } else {
 | 
						|
        // Other modes can be written as blocks of matrix multiplications
 | 
						|
        // C_:,:,i2 = alpha (A x_m op(B))_(m)' + beta C_:,:,i2
 | 
						|
        //          = alpha A_(m)' op(B)' + beta C_:,:,i2
 | 
						|
        for (int i2 = 0; i2 < stride[2]; ++i2) {
 | 
						|
            F77_CALL(dgemm)("N", transB ? "N" : "T",
 | 
						|
                &stride[0], (transB ? &ncolB : &nrowB), &stride[1], &alpha,
 | 
						|
                &A[i2 * stride[0] * stride[1]], &stride[0], B, &nrowB,
 | 
						|
                &beta, &C[i2 * stride[0] * (transB ? ncolB : nrowB)], &stride[0]
 | 
						|
                FCONE FCONE);
 | 
						|
        }
 | 
						|
    }
 | 
						|
 | 
						|
    /*
 | 
						|
    // (reference implementation)
 | 
						|
    // Tensor Times Matrix / Mode Product for `op(B) == B`
 | 
						|
    memset(c, 0, sizeC * sizeof(double));
 | 
						|
    for (int i2 = 0; i2 < stride[2]; ++i2) {
 | 
						|
        for (int i1 = 0; i1 < stride[1]; ++i1) { // stride[1] == ncols(B)
 | 
						|
            for (int j = 0; j < nrow; ++j) {
 | 
						|
                for (int i0 = 0; i0 < stride[0]; ++i0) {
 | 
						|
                    c[i0 + (j + i2 * nrow) * stride[0]] +=
 | 
						|
                        a[i0 + (i1 + i2 * stride[1]) * stride[0]] * b[j + i1 * nrow];
 | 
						|
                }
 | 
						|
            }
 | 
						|
        }
 | 
						|
    }
 | 
						|
    */
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
/**
 | 
						|
 * Tensor Times Matrix a.k.a. Mode Product
 | 
						|
 *
 | 
						|
 * @param A multi-dimensional array
 | 
						|
 * @param B matrix
 | 
						|
 * @param m mode index (1-indexed)
 | 
						|
 * @param op boolean if `B` is transposed
 | 
						|
 */
 | 
						|
extern SEXP R_ttm(SEXP A, SEXP B, SEXP m, SEXP op) {
 | 
						|
 | 
						|
    // get zero indexed mode
 | 
						|
    const int mode = Rf_asInteger(m) - 1;
 | 
						|
 | 
						|
    // get dimension attribute of A
 | 
						|
    SEXP dimA = Rf_getAttrib(A, R_DimSymbol);
 | 
						|
 | 
						|
    // operation on `B` (transposed or not)
 | 
						|
    const int transB = Rf_asLogical(op);
 | 
						|
 | 
						|
    // as well as `B`s dimensions
 | 
						|
    const int nrowB = Rf_nrows(B);
 | 
						|
    const int ncolB = Rf_ncols(B);
 | 
						|
 | 
						|
    // validate mode (mode must be smaller than the nr of dimensions)
 | 
						|
    if (mode < 0 || Rf_length(dimA) <= mode) {
 | 
						|
        Rf_error("Illegal mode");
 | 
						|
    }
 | 
						|
 | 
						|
    // and check if B is a matrix of non degenetate size
 | 
						|
    if (!Rf_isMatrix(B)) {
 | 
						|
        Rf_error("Expected a matrix as second argument");
 | 
						|
    }
 | 
						|
    if (!Rf_nrows(B) || !Rf_ncols(B)) {
 | 
						|
        Rf_error("Zero dimension detected");
 | 
						|
    }
 | 
						|
 | 
						|
    // check matching of dimensions
 | 
						|
    if (INTEGER(dimA)[mode] != (transB ? nrowB : ncolB)) {
 | 
						|
        Rf_error("Dimension missmatch");
 | 
						|
    }
 | 
						|
 | 
						|
    // calc nr of response elements (size of C)
 | 
						|
    // `prod(dim(C)) = prod(dim(A)[-mode]) * nrow(if(transB) t(B) else B)`
 | 
						|
    int sizeC = 1;
 | 
						|
    for (int i = 0; i < Rf_length(dimA); ++i) {
 | 
						|
        int size = INTEGER(dimA)[i];
 | 
						|
        // check for non-degenetate dimensions
 | 
						|
        if (!size) {
 | 
						|
            Rf_error("Zero dimension detected");
 | 
						|
        }
 | 
						|
        sizeC *= (i == mode) ? (transB ? ncolB : nrowB) : size;
 | 
						|
    }
 | 
						|
 | 
						|
    // create response object C
 | 
						|
    SEXP C = PROTECT(Rf_allocVector(REALSXP, sizeC));
 | 
						|
 | 
						|
    // Tensor Times Matrix / Mode Product
 | 
						|
    ttm(transB, mode,
 | 
						|
        INTEGER(dimA), Rf_length(dimA), nrowB, ncolB,
 | 
						|
        1.0, REAL(A), REAL(B), nrowB, 0.0, REAL(C));
 | 
						|
 | 
						|
    // finally, set result dimensions
 | 
						|
    SEXP dimC = PROTECT(Rf_allocVector(INTSXP, Rf_length(dimA)));
 | 
						|
    for (int i = 0; i < Rf_length(dimA); ++i) {
 | 
						|
        INTEGER(dimC)[i] = (i == mode) ? (transB ? ncolB : nrowB) : INTEGER(dimA)[i];
 | 
						|
    }
 | 
						|
    Rf_setAttrib(C, R_DimSymbol, dimC);
 | 
						|
 | 
						|
    // release C to the hands of the garbage collector
 | 
						|
    UNPROTECT(2);
 | 
						|
 | 
						|
    return C;
 | 
						|
}
 |