add: some more helper functions and alike ...

This commit is contained in:
Daniel Kapla 2025-03-27 13:36:11 +01:00
parent 57eb279e6a
commit e910db0377
9 changed files with 666 additions and 0 deletions

View File

@ -0,0 +1,44 @@
#' A multi-linear model
#'
#' Solves the vectorized linear model of the centered `X` on the centered `F`,
#' passes the coefficient matrix to `approx.kron` (decomposition as Kronecker
#' product) and inverts the covariances computed by `mcov` from the residuals
#' `X - mlm(F, betas)`.
#'
#' @param X array of observations, one axis enumerates the samples.
#' @param F array with the same number of axes as `X`. If `F` is not an array
#'  it gets the dimensions `1` in every mode and `dim(X)[sample.axis]` along
#'  the sample axis (one scalar per sample).
#' @param sample.axis index of the axis of `X` (and `F`) enumerating the
#'  samples.
#'
#' @returns list with the components
#'  \describe{
#'    \item{eta1}{mean of `X` over the samples as array of dimension
#'      `dim(X)[-sample.axis]`,}
#'    \item{betas}{result of `approx.kron` applied to the solution of the
#'      vectorized linear model,}
#'    \item{Omegas}{list of the inverted covariances returned by `mcov`.}
#'  }
#'
#' @export
gmlm_mlm <- function(X, F, sample.axis) {
    # get (set) problem (observation) dimensions
    modes <- seq_along(dim(X))[-sample.axis]
    dimX <- dim(X)[modes]
    sample.size <- dim(X)[sample.axis]
    if (!is.array(F)) {
        # extent 1 in every mode, `sample.size` along the sample axis
        dim(F) <- c(1L, sample.size)[(seq_along(dim(X)) == sample.axis) + 1L]
    }
    dimF <- dim(F)[modes]

    # vectorize the tensor valued data (columns are the vectorized samples)
    matX <- mat(X, modes)
    matF <- mat(F, modes)

    # center
    matX <- matX - (meanX <- rowMeans(matX))
    matF <- matF - rowMeans(matF)

    # solve vectorized linear model
    B <- tcrossprod(matX, matF) %*% pinv(tcrossprod(matF))

    # decompose linear model solution as Kronecker product
    betas <- approx.kron(B, dimX, dimF)

    # reshape centered vectorized `X` and `F` into tensors (now, sample axis
    # is last axis)
    X <- `dim<-`(matX, c(dimX, sample.size))
    F <- `dim<-`(matF, c(dimF, sample.size))

    # and estimate covariances (sample axis is last axis)
    Sigmas <- mcov(X - mlm(F, betas), sample.axis = length(dim(X)))

    # finally, invert covariances to get the scatter matrices `Omegas`
    Omegas <- Map(solve, Sigmas)

    list(
        # `X` has been reshaped above (sample axis moved to the end), its
        # current dimensions must NOT be indexed with the original
        # `sample.axis`. The observation dimensions are `dimX`.
        eta1 = array(meanX, dim = dimX),
        betas = betas,
        Omegas = Omegas
    )
}

View File

@ -0,0 +1,131 @@
#' Tensor normal parameter estimation via the native routine
#' `C_icu_tensor_normal`
#'
#' Validates the dimensions of `X` and `F`, converts both to storage mode
#' `"double"` and hands them to the compiled routine.
#'
#' @param X numeric array with at least two axes. The dimensions of the
#'  returned `eta` are `dim(X)` without the last axis (presumably the last
#'  axis is the sample axis -- TODO confirm against the C implementation).
#' @param F numeric array with the same number of axes as `X`.
#' @param max.iter coerced to integer and passed to the native routine
#'  (presumably the maximum number of iterations).
#'
#' @returns list with the components `eta` (first element of the native
#'  result with dimensions `head(dim(X), -1)`), `alphas` (elements `2` to
#'  `r + 1`) and `omegas` (elements `r + 2` to `2 r + 1`) where
#'  `r = length(dim(X)) - 1`.
#'
#' @export
icu_tensor_normal <- function(X, F, max.iter = 100L) {
    if (is.null(dim(X)) || is.null(dim(F))) {
        stop("Dim. of `X` and `F` must be non-null")
    }
    if (length(dim(X)) != length(dim(F))) {
        stop("`X` and `F` must have the same number of dimensions")
    }
    if (length(dim(X)) < 2L) {
        stop("`X` and `F` must have at least two dimensions")
    }

    storage.mode(X) <- storage.mode(F) <- "double"

    est <- .Call("C_icu_tensor_normal", X, F, as.integer(max.iter))

    # number of modes (all axes except the last one)
    r <- length(dim(X)) - 1L

    list(
        eta = structure(est[[1]], dim = head(dim(X), -1)),
        alphas = est[seq.int(2, r + 1)],
        omegas = est[seq.int(r + 2, 2 * r + 1)]
    )
}
# Reference implementation (plain R) of the iterative cyclic updating scheme.
#
# The last axis of `X` and `F` is the sample axis, all other axes are the
# modes of a single observation.
#
# NOTE(review): work in progress. Only the `alphas` are updated, the update
# of the `Omegas` is not implemented. They stay at their initial value (the
# identity), hence the recomputed `Deltas` never change.
#
# @param X array of observations (sample axis last)
# @param F array with the same number of axes as `X` (sample axis last)
# @param max.iter number of cyclic updating iterations
#
# @returns list of three elements: `eta_bar = mlm(meanX, Omegas)`, the list
#   of `alphas` and the list of `Omegas` (the last two are unnamed)
icu_tensor_normal_ref <- function(X, F, max.iter) {
    dimX <- head(dim(X), -1)
    dimF <- head(dim(F), -1)
    ord <- length(dimX)
    sample.size <- tail(dim(X), 1)

    # center X (the mean is recycled along the last, the sample, axis)
    meanX <- rowMeans(X, dims = ord)
    X <- X - c(meanX)

    # initial alphas
    alphas <- Map(function(k) {
        XX <- tcrossprod(mat(X, k)) / prod(dimX[-k], sample.size)
        FF <- tcrossprod(mat(F, k)) / prod(dimF[-k], sample.size)
        svdXX <- La.svd(XX, dimF[k], 0)
        svdFF <- La.svd(FF, 0, dimX[k])
        rank <- min(dimX[k], dimF[k])
        dd <- head(svdXX$d, rank) * head(svdFF$d, rank)
        # `nrow` is required, `diag(x)` with a single number `x` creates an
        # identity matrix instead of the 1 x 1 matrix with entry `x`
        svdXX$u %*% diag(sqrt(dd), nrow = length(dd)) %*% svdFF$vt
    }, seq_len(ord))

    # initial Omegas are set to the identity
    Omegas <- Map(diag, dimX)
    Deltas <- Map(solve, Omegas)

    # iterative cyclic updating loop
    for (iter in seq_len(max.iter)) {
        # update alphas one at a time
        for (j in seq_len(ord)) {
            Fa <- mlm(F, alphas[-j], seq_len(ord)[-j])
            FDa <- mlm(Fa, Deltas[-j], seq_len(ord)[-j])
            alphas[[j]] <- t(solve(
                mcrossprod(FDa, Fa, j), mcrossprod(Fa, X, j) %*% Omegas[[j]]
            ))
        }
        # update Omegas one at a time
        for (k in seq_len(ord)) {
            # TODO: the update of `Omegas[[k]]` itself is missing

            # update changed Delta_k = Omega_k^-1 (indexed by the loop
            # variable `k`, not by the stale `j` of the previous loop)
            Deltas[[k]] <- solve(Omegas[[k]])
        }
    }

    list(eta_bar = mlm(meanX, Omegas), alphas, Omegas)
}
# ################################################################################
# # setup dimensions
# n <- 1e3
# p <- c(3, 5, 11)
# q <- c(2, 7, 13)
# # create "true" GLM parameters
# eta1 <- array(rnorm(prod(p)), dim = p)
# alphas <- Map(matrix, Map(rnorm, p * q), p)
# Omegas <- Map(function(p_j) {
# solve(0.5^abs(outer(seq_len(p_j), seq_len(p_j), `-`)))
# }, p)
# true.params <- list(eta1, alphas, Omegas)
# # compute tensor normal parameters from GLM parameters
# Deltas <- Map(solve, Omegas)
# mu <- mlm(eta1, Deltas)
# # sample some test data
# sample.axis <- length(p) + 1L
# Fy <- array(rnorm(n * prod(q)), dim = c(q, n))
# X <- mlm(Fy, Map(`%*%`, Deltas, alphas)) + rtensornorm(n, mu, Deltas, sample.axis)
# ################################################################################
# kronperm <- function(A) {
# # force A to have even number of dimensions
# dim(A) <- c(dim(A), rep(1L, length(dim(A)) %% 2L))
# # compute axis permutation
# perm <- as.vector(t(matrix(seq_along(dim(A)), ncol = 2)[, 2:1]))
# # permute elements of A
# K <- aperm(A, perm, resize = FALSE)
# # collapse/set dimensions
# dim(K) <- head(dim(A), length(dim(A)) / 2) * tail(dim(A), length(dim(A)) / 2)
# K
# }
# p <- c(3, 5, 11)
# q <- c(2, 7, 13)
# A <- array(rnorm(prod(p)), dim = p)
# B <- array(rnorm(prod(q)), dim = q)
# Cs <- Map(function(d1, d2) matrix(rnorm(d1 * d2), d1), sample.int(5, length(p), TRUE), p)
# Ds <- Map(function(d1, d2) matrix(rnorm(d1 * d2), d1), sample.int(5, length(q), TRUE), q)
# all.equal(
# kronperm(mlm(outer(A, B), c(Cs, Ds))),
# kronecker(mlm(A, Cs), mlm(B, Ds))
# )
# all.equal(
# mlm(outer(A, B), c(Cs, Ds)),
# outer(mlm(A, Cs), mlm(B, Ds))
# )

View File

@ -0,0 +1,34 @@
#' Slice-wise matrix-matrix multiplication over the 3rd axis of 3D arrays
#'
#' Multiplies the matrices `A[, , i]` and `B[, , i]` for every index `i` of
#' the third axis, see the reference implementation in the examples.
#'
#' @param A 3D numeric array
#' @param B 3D numeric array
#' @returns 3D numeric array
#'
#' @examples
#' # Equivalent to the reference implementation
#' merge.matmul.reference <- function(A, B) {
#'     C <- array(dim = c(nrow(A), ncol(B), dim(A)[3]))
#'     for (i in seq_len(dim(A)[3])) {
#'         C[, , i] <- A[, , i] %*% B[, , i]
#'     }
#'     C
#' }
#'
#' dimA <- c(3, 5, 101)
#' dimB <- c(5, 2, 101)
#' A <- array(rnorm(prod(dimA)), dim = dimA)
#' B <- array(rnorm(prod(dimB)), dim = dimB)
#'
#' C <- merge.matmul(A, B)
#' dim(C) # c(3, 2, 101)
#'
#' all.equal(
#'     merge.matmul.reference(A, B),
#'     merge.matmul(A, B)
#' )
#'
#' @export
merge.matmul <- function(A, B) {
    # both arrays are handed to native code as doubles
    storage.mode(A) <- "double"
    storage.mode(B) <- "double"

    .Call("C_merge_matmul", A, B, PACKAGE = "tensorPredictors")
}

View File

@ -0,0 +1,32 @@
#' Solve the matrix equation
#'
#'      X A X + X = B
#'
#' for symmetric, positive definite `A`, `B`. This is a special case of an
#' algebraic Riccati equation.
#'
#' The solution is computed from the block matrix
#'
#'      V = [ I / 2,   A    ]
#'          [   B  , -I / 2 ]
#'
#' by iterating `V <- (V + pinv(V)) / 2` for `max.iter` steps (presumably a
#' matrix sign function iteration -- TODO confirm).
#'
#' @param A,B square matrices of identical dimensions.
#' @param max.iter number of iterations, there is no convergence check.
#'
#' @returns matrix `X` with the attribute `resid` set to the Frobenius norm
#'  of the residual `X A X + X - B`.
#'
#' @examples
#' A <- crossprod(rmatrix(5, 5))
#' B <- crossprod(rmatrix(5, 5))
#'
#' X <- riccati(A, B)
#'
#' all.equal(X %*% A %*% X + X, B)
#'
#' @export
riccati <- function(A, B, max.iter = 20L) {
    # without this check a `B` of different shape is silently recycled when
    # written into the block matrix `V`
    if (length(dim(A)) != 2L || nrow(A) != ncol(A)) {
        stop("`A` must be a square matrix")
    }
    if (length(dim(B)) != 2L || any(dim(B) != dim(A))) {
        stop("`B` must be a square matrix of the same dimension as `A`")
    }

    p <- nrow(A)
    upper <- seq_len(p)
    lower <- upper + p

    # diagonal is `0.5` for the first `p` and `-0.5` for the last `p` entries
    V <- diag(0.5 - (p < seq_len(2 * p)))
    V[upper, lower] <- A
    V[lower, upper] <- B

    for (iter in seq_len(max.iter)) {
        V <- 0.5 * (V + pinv(V))
    }

    # solve `G[, lower] X = -G[, upper]` for `X`
    G <- V + diag(2 * p)
    X <- pinv(G[, lower]) %*% -G[, upper]

    structure(X, resid = norm(X %*% A %*% X + X - B, "F"))
}

View File

@ -0,0 +1,18 @@
#' Sylvester Equation
#'
#'      A X + X B = C
#'
#' Solved for `X` through the vectorized linear system
#'
#'      (I (x) A + t(B) (x) I) vec(X) = vec(C)
#'
#' where `(x)` denotes the Kronecker product.
#'
#' @param A square matrix with `nrow(C)` rows.
#' @param B square matrix with `ncol(C)` rows.
#' @param C right hand side matrix.
#'
#' @returns matrix `X` with the dimensions of `C`.
#'
#' @examples
#' A <- rmatrix(5, 5)
#' B <- rmatrix(3, 3)
#' C <- rmatrix(5, 3)
#'
#' X <- sylvester(A, B, C)
#'
#' all.equal(A %*% X + X %*% B, C)
#'
#' @export
sylvester <- function(A, B, C) {
    n.rows <- nrow(C)
    n.cols <- ncol(C)

    # coefficient matrix of the vectorized equation
    lhs <- kronecker(diag(n.cols), A) + kronecker(t(B), diag(n.rows))

    matrix(solve(lhs, as.vector(C)), n.rows, n.cols)
}

View File

@ -0,0 +1,15 @@
# random tensor of dimension `p`
p <- c(3L, 5L, 2L, 5L)
A <- array(rnorm(prod(p)), p)

# error bound handed to `TTSVD`, checked relative to the norm of `A` below
eps <- 1
Gs <- TTSVD(A, eps = eps)

# chain the elements of `Gs` into a single tensor (presumably `ttt`
# contracts the last axis of the left with the first axis of the right
# tensor -- TODO confirm)
B <- Reduce(function(left, right) {
    ttt(left, right, length(dim(left)), 1)
}, Gs)

# the approximation error must not exceed `eps` times the norm of `A`
stopifnot(sqrt(sum((A - B)^2)) <= eps * sqrt(sum(A^2)))

# ranks of all mode-wise matricizations of `A`, to be compared against ...
unlist(lapply(seq_along(dim(A)), function(k) qr(mat(A, k))$rank))
# ... the "TT.rank" attribute of the decomposition
attr(Gs, "TT.rank")

View File

@ -0,0 +1,291 @@
#include "R_api.h"
#include "ttm.h"
#include "mlm.h"
#include "solve.h"
#include "det.h"
/**
 * Generalized Multi Linear Model fitting for the Tensor Normal Distribution
 *
 * NOTE(review): work in progress. The line search body ends in a
 * `TODO: continue`; inside the main loop neither `alphas`, `Omegas` nor the
 * `prev_*` copies are updated. On exit `Omegas` is overwritten with the
 * content of `Deltas` and `alphas[0][0]` with the initial loss, which looks
 * like temporary debugging output -- confirm before relying on the results.
 *
 * @param max_iter number of iterations of the main descent loop
 * @param max_line_iter number of iterations of the inner line search loop
 * @param dimX dimensions of a single observation `X_i` (`ord` elements)
 * @param dimF dimensions of a single `F_i` (`ord` elements)
 * @param ord number of modes (axes) of a single observation
 * @param sample_size number of observations
 * @param X observations, sample `i` is read at `X + i * sizeX` where `sizeX`
 *  is the product of `dimX`
 * @param F sample `i` is read at `F + i * sizeF` where `sizeF` is the product
 *  of `dimF`
 * @param meanX `sizeX` elements, only read by this routine (its computation
 *  is disabled below, it is expected from the caller)
 * @param alphas `ord` matrices, `alphas[k]` has `dimX[k] * dimF[k]` elements
 * @param Omegas `ord` matrices, `Omegas[k]` has `dimX[k] * dimX[k]` elements
 * @param error handed to `solve` and `det` (presumably set to a non-zero
 *  error code on failure -- TODO confirm against `solve.h` and `det.h`)
 */
void gmlm_tensor_normal(
/* options */ const int max_iter, const int max_line_iter,
/* dims */ const int* dimX, const int* dimF,
const int ord, const int sample_size,
/* data */ const double* X, const double* F,
/* params */ double* meanX, double** alphas, double** Omegas,
int* error
) {
// product of X and F dimensions (dim of sample vec(X_i) and vec(F_i))
int sizeX = 1;
int sizeF = 1;
for (int k = 0; k < ord; ++k) {
sizeX *= dimX[k];
sizeF *= dimF[k];
}
// Compute maximum temporary and working memory size for all called subroutines
// (presumably `mlm` called with `NULL` data acts as a work size query and
// returns the required number of elements -- TODO confirm against `mlm.h`)
int work_size = mlm(FALSE, NULL, 0, dimX, dimF, ord, 0, NULL, NULL, NULL, 0, NULL, NULL);
int tmp_size = work_size;
// Update required working memory to also accommodate `solve` and `det`
// (called without data, both write into `size`; presumably the required
// work size -- TODO confirm)
for (int k = 0, size = 0; k < ord; ++k) {
solve(dimX[k], dimX[k], NULL, 0, NULL, 0, NULL, 0, NULL, &size);
work_size = work_size < size ? size : work_size;
(void)det(dimX[k], NULL, 0, NULL, &size);
work_size = work_size < size ? size : work_size;
}
// allocate temporary and working memory
// NOTE(review): `tmp1_mem` is used below to hold `sizeX` elements, this
// assumes the size reported by `mlm` is at least `sizeX` -- confirm
double* work_mem = (double*)R_alloc(work_size, sizeof(double));
double* tmp1_mem = (double*)R_alloc(tmp_size, sizeof(double));
// tmp2_mem also needs to hold `sizeX^2` elements
// NOTE(review): `sizeX * sizeX` is an `int` product
tmp_size = tmp_size < sizeX * sizeX ? sizeX * sizeX : tmp_size;
double* tmp2_mem = (double*)R_alloc(tmp_size, sizeof(double));
// initialize momentum and previous iteration terms for alphas and Omegas
// (all four start as copies of the initial parameters)
double** moment_alphas = (double**)R_alloc(ord, sizeof(double*));
double** moment_Omegas = (double**)R_alloc(ord, sizeof(double*));
double** prev_alphas = (double**)R_alloc(ord, sizeof(double*));
double** prev_Omegas = (double**)R_alloc(ord, sizeof(double*));
for (int k = 0; k < ord; ++k) {
moment_alphas[k] = (double*)R_alloc(dimX[k] * dimF[k], sizeof(double));
moment_Omegas[k] = (double*)R_alloc(dimX[k] * dimX[k], sizeof(double));
prev_alphas[k] = (double*)R_alloc(dimX[k] * dimF[k], sizeof(double));
prev_Omegas[k] = (double*)R_alloc(dimX[k] * dimX[k], sizeof(double));
memcpy(moment_alphas[k], alphas[k], dimX[k] * dimF[k] * sizeof(double));
memcpy(moment_Omegas[k], Omegas[k], dimX[k] * dimX[k] * sizeof(double));
memcpy(prev_alphas[k], alphas[k], dimX[k] * dimF[k] * sizeof(double));
memcpy(prev_Omegas[k], Omegas[k], dimX[k] * dimX[k] * sizeof(double));
}
// Allocate and compute initial values of `Deltas`, inverted `Omegas`
// Delta_k = Omega_k^-1 for k = 1, ..., ord
double** Deltas = (double**)R_alloc(ord, sizeof(double*));
for (int k = 0; k < ord; ++k) {
Deltas[k] = (double*)R_alloc(dimX[k] * dimX[k], sizeof(double));
solve(dimX[k], dimX[k], Omegas[k], dimX[k], NULL, 0,
Deltas[k], dimX[k], work_mem, error);
}
// modes array for MLM (all modes 0, ..., ord - 1)
int* modes = (int*)R_alloc(ord, sizeof(int));
for (int k = 0; k < ord; ++k) { modes[k] = k; }
///// Already computed by initial parameter estimate provided by calling R code
// /* Step 1: compute mean(X) */
// memset(meanX, 0, sizeX * sizeof(double));
// for (int i = 0; i < sample_size; ++i) {
// F77_CALL(daxpy)(&sizeX, &d_one, X + i * sizeX, &i_one, meanX, &i_one);
// }
// const double inv_sample_size = 1.0 / (double)sample_size;
// F77_CALL(dscal)(&sizeX, &inv_sample_size, meanX, &i_one);
/* iteration 0, compute initial loss = log-likelihood */
// TODO: figure out how to properly avoid the use of an explicit cast
// see: https://stackoverflow.com/questions/12992407/warning-when-passing-non-const-parameter-to-a-function-that-expects-const-parame
// see: https://c-faq.com/ansi/constmismatch.html
double loss = 0.0;
for (int i = 0; i < sample_size; ++i) {
memcpy(tmp1_mem, meanX, sizeX * sizeof(double)); // tmp1_mem <- meanX
(void)mlm(FALSE, modes, ord, dimF, dimX, ord, // tmp1_mem <- mlm(F_i, alphas) + tmp1_mem
1.0, F + i * sizeF, (const double**)alphas, dimX, 1.0, tmp1_mem, work_mem);
// NOTE(review): `tmp1_mem` is input and output of the next call, confirm
// that `mlm` supports this in-place use
(void)mlm(FALSE, modes, ord, dimX, dimX, ord, // tmp1_mem <- mlm(tmp1_mem, Deltas) = mean(X | Y)
1.0, tmp1_mem, (const double**)Deltas, dimX, 0.0, tmp1_mem, work_mem);
// tmp1_mem <- tmp1_mem - X_i
axpy(sizeX, -1.0, X + i * sizeX, 1, tmp1_mem, 1);
// tmp2_mem <- mlm(tmp1_mem, Omegas)
(void)mlm(FALSE, modes, ord, dimX, dimX, ord,
1.0, tmp1_mem, (const double**)Omegas, dimX, 0.0, tmp2_mem, work_mem);
// loss <- loss + <tmp1_mem, tmp2_mem>
loss += dot(sizeX, tmp1_mem, 1, tmp2_mem, 1);
}
// average of the quadratic forms over the samples, times 1 / 2
loss /= (double)(2 * sample_size);
for (int k = 0; k < ord; ++k) {
// loss <- loss - 1 / 2 (prod_{j != k} p_j) log(det(Omega_k))
loss -= ((double)sizeX / (double)(2 * dimX[k]))
* log(det(dimX[k], Omegas[k], dimX[k], work_mem, error));
}
/* NAGD parameters */
int iter = 0; /* iteration counter */
double m0 = 0.0, m1 = 1.0; /* momentum extrapolation weights */
double gamma = 0.61803398875; /* line search step scaling */
double step_size = 1e-2; /* initial step size */
/* Main parameter descent loop */
for (iter = 0; iter < max_iter; ++iter) {
/* momentum extrapolation */
// `M <- theta_t + (m_{t-1} - 1) / m_t (theta_{t} - theta_{t-1})`
// NOTE(review): `m0` and `m1` are never changed in this loop
for (int k = 0; k < ord; ++k) {
lincomb(dimX[k] * dimF[k],
1.0 + (m0 - 1.0) / m1, alphas[k], 1, (1.0 - m0) / m1, prev_alphas[k], 1,
moment_alphas[k], 1);
lincomb(dimX[k] * dimX[k],
1.0 + (m0 - 1.0) / m1, Omegas[k], 1, (1.0 - m0) / m1, prev_Omegas[k], 1,
moment_Omegas[k], 1);
}
/* Gradient at momentum extrapolated parameters */
// Delta_k <- moment_Omega_k^-1, for k = 1, ..., ord
for (int k = 0; k < ord; ++k) {
solve(dimX[k], dimX[k], moment_Omegas[k], dimX[k], NULL, 0,
Deltas[k], dimX[k], work_mem, error);
}
//
// line search (initially try to increase the step size)
step_size /= gamma;
for (int line_iter = 0; line_iter < max_line_iter; ++line_iter) {
/* gradient update for gradients at momentum extrapolated parameters */
// iteration over observations
for (int i = 0; i < sample_size; ++i) {
// tmp1_mem <- meanX
memcpy(tmp1_mem, meanX, sizeX * sizeof(double));
// tmp1_mem <- mlm(F_i, moment_alphas) + tmp1_mem
(void)mlm(FALSE, modes, ord, dimF, dimX, ord,
1.0, F + i * sizeF, (const double**)moment_alphas, dimX, 1.0, tmp1_mem, work_mem);
// tmp1_mem <- mlm(tmp1_mem, Deltas)
(void)mlm(FALSE, modes, ord, dimX, dimX, ord,
1.0, tmp1_mem, (const double**)Deltas, dimX, 0.0, tmp1_mem, work_mem);
// TODO: continue
}
// decrease step size
step_size *= gamma;
}
}
// NOTE(review): the remaining lines overwrite the parameters, `Omegas` with
// `Deltas` and `alphas[0][0]` with the initial loss; this looks like
// temporary debugging output of the unfinished routine -- confirm
for (int i = 0; i < ord; ++i) {
memcpy(Omegas[i], Deltas[i], dimX[i] * dimX[i] * sizeof(double));
}
alphas[0][0] = loss;
}
/**
 * Generalized Multi Linear Model fitting for the Tensor Normal Distribution
 * (`R` binding)
 *
 * Validates all arguments and forwards them to `gmlm_tensor_normal`.
 *
 * @param X real valued array, its last axis is the sample axis
 * @param Fy real valued array with the same number of axes and the same
 *  number of samples (extent of the last axis) as `X`
 * @param meanX real valued array with the dimensions of `X` without the
 *  sample axis
 * @param alphas list of real matrices, `alphas[[i]]` is of dimension
 *  `dim(X)[i]` times `dim(Fy)[i]`
 * @param Omegas list of real matrices, `Omegas[[i]]` is of dimension
 *  `dim(X)[i]` times `dim(X)[i]`
 * @param max_iter integer vector, its first element is used
 * @param max_line_iter integer vector, its first element is used
 *
 * @attention The data of `meanX`, `alphas` and `Omegas` is passed to
 *  `gmlm_tensor_normal` without copying, the callee may modify it in place.
 *
 * @returns `R_NilValue`
 */
extern SEXP R_gmlm_tensor_normal(
    /* data */ SEXP X, SEXP Fy,
    /* params */ SEXP meanX, SEXP alphas, SEXP Omegas,
    /* options */ SEXP max_iter, SEXP max_line_iter
) {
    // get dimension attribute of `X` and check if its a real valued tensor
    SEXP dimX = Rf_getAttrib(X, R_DimSymbol);
    if (!Rf_isReal(X) || Rf_isNull(dimX)) {
        Rf_error("Param. `X` need to be a real valued array");
    }
    // subtract 1 from `ord` as the last axis is the sample axis
    int ord = Rf_length(dimX) - 1;
    // at least one mode is required, `gmlm_tensor_normal` accesses the first
    // parameter matrix unconditionally
    if (ord < 1) {
        Rf_error("Param. `X` need to have at least 2 axes, the last axis is "
                 "the sample axis");
    }

    // the same for `Fy` and compare their order (nr. of axis)
    SEXP dimFy = Rf_getAttrib(Fy, R_DimSymbol);
    if (!Rf_isReal(Fy) || Rf_isNull(dimFy)) {
        Rf_error("Param. `Fy` need to be a real valued array");
    }
    if (ord + 1 != Rf_length(dimFy)) {
        Rf_error("Dimension mismatch, order of `X` and `Fy` must be equal");
    }

    // get direct access to dimension of `X` and `Fy`
    int* p = INTEGER(dimX);
    int* q = INTEGER(dimFy);

    // validate that `X` and `Fy` have same number of samples
    if (p[ord] != q[ord]) {
        Rf_error("Nr. samples in `X` is %d and in `Fy` is %d, they must be equal",
            p[ord], q[ord]);
    }
    int sample_size = p[ord];

    // Check if `alphas` and `Omegas` are lists of length `ord`
    if (!Rf_isNewList(alphas) || !Rf_isNewList(Omegas)
            || (Rf_length(alphas) != ord || Rf_length(Omegas) != ord)) {
        Rf_error("Params `alphas` and `Omegas` must be lists of length %d", ord);
    }

    // Validate type and dimensions of `meanX` (accessed via `REAL` below)
    if (!Rf_isReal(meanX)) {
        Rf_error("Param. `meanX` need to be a real valued array");
    }
    SEXP dimmeanX = Rf_getAttrib(meanX, R_DimSymbol);
    if (Rf_length(dimmeanX) != ord) {
        Rf_error("Dimension missmatch, meanX must have dim of `X` samples");
    }
    for (int i = 0; i < ord; ++i) {
        if (INTEGER(dimmeanX)[i] != p[i]) {
            Rf_error("Dimension missmatch, meanX must have dim of `X` samples");
        }
    }

    // Extract initial `alphas` and `Omegas` while checking their dimensions
    double** as = (double**)R_alloc(ord, sizeof(double*));
    double** Os = (double**)R_alloc(ord, sizeof(double*));
    for (int i = 0; i < ord; ++i) {
        // Extract alpha_i and Omega_i from `alphas` and `Omegas`
        SEXP alpha = VECTOR_ELT(alphas, i);
        SEXP Omega = VECTOR_ELT(Omegas, i);
        if (!(Rf_isMatrix(alpha) && Rf_isReal(alpha))
                || !(Rf_isMatrix(Omega) && Rf_isReal(Omega))) {
            Rf_error("Every alpha and Omega must be a real matrix");
        }

        // check dimensions
        int* dim = INTEGER(Rf_getAttrib(alpha, R_DimSymbol));
        if (dim[0] != p[i] || dim[1] != q[i]) {
            Rf_error("Dimension missmatch between dim(alpha_%d) = (%d, %d) "
                     "and dim(X)[%d] = %d and dim(Fy)[%d] = %d",
                i + 1, dim[0], dim[1], i + 1, p[i], i + 1, q[i]);
        }
        dim = INTEGER(Rf_getAttrib(Omega, R_DimSymbol));
        if (dim[0] != p[i] || dim[1] != p[i]) {
            Rf_error("Dimension missmatch between dim(Omega_%d) = (%d, %d) and dim(X)[%d] = %d",
                i + 1, dim[0], dim[1], i + 1, p[i]);
        }

        as[i] = REAL(alpha);
        Os[i] = REAL(Omega);
    }

    // Finally check algorithm options (configuration)
    if (!Rf_isInteger(max_iter) || Rf_length(max_iter) < 1
            || !Rf_isInteger(max_line_iter) || Rf_length(max_line_iter) < 1) {
        Rf_error("Unexpected max iterations, required to be pos int");
    }

    int error = 0;
    gmlm_tensor_normal(
        // both options are read with the same accessor `INTEGER`
        /* options */ INTEGER(max_iter)[0], INTEGER(max_line_iter)[0],
        /* dims */ p, q, ord, sample_size,
        /* data */ REAL(X), REAL(Fy),
        /* params */ REAL(meanX), as, Os,
        &error
    );

    // check if an error occurred which is indicated by a non-zero error code
    if (error) {
        Rf_error("Error in 'gmlm_tensor_normal' with error code %d", error);
    }

    return R_NilValue;
}

View File

@ -0,0 +1,81 @@
#include <limits.h>
#include <stddef.h>

#include "merge_matmul.h"
/**
 * Slice-wise matrix-matrix multiplication of 3D arrays
 *
 * Issues one `dgemm` call per slice `k = 0, ..., nslices - 1` computing
 *
 *      C_k <- alpha * A_k * B_k + beta * C_k
 *
 * where `A_k` is the `nrowA` by `ncolA` matrix starting at element
 * `k * nrowA * ncolA` of `A`, `B_k` the `nrowB` by `ncolB` matrix starting
 * at element `k * nrowB * ncolB` of `B` and `C_k` the `nrowA` by `ncolB`
 * matrix starting at element `k * nrowA * ncolB` of `C`.
 *
 * No argument validation is performed, in particular `ncolA == nrowB` is
 * the responsibility of the caller.
 */
void merge_matmul(
    const int nrowA, const int ncolA, const int nslices,
    const int nrowB, const int ncolB,
    const double alpha,
    const double* A, const double* B,
    const double beta,
    double* C
) {
    // number of elements of a single slice, computed in `size_t` as the
    // slice offsets `k * size` may exceed the range of `int`
    const size_t sizeA = (size_t)nrowA * (size_t)ncolA;
    const size_t sizeB = (size_t)nrowB * (size_t)ncolB;
    const size_t sizeC = (size_t)nrowA * (size_t)ncolB;

    for (int k = 0; k < nslices; ++k) {
        F77_CALL(dgemm)("N", "N", &nrowA, &ncolB, &ncolA,
            &alpha,
            A + k * sizeA, &nrowA,
            B + k * sizeB, &nrowB,
            &beta,
            C + k * sizeC, &nrowA
            FCONE FCONE);
    }
}
/**
 * Apply matrix-matrix multiplication over 3rd axis of 3D arrays
 *
 * Raises an `R` error if `A` or `B` is not a real valued 3D array, if any
 * of the used dimensions is zero, if the dimensions are incompatible or if
 * the number of elements of the result exceeds the range of `int`.
 *
 * @param A real valued 3D array of dimension `(nrowA, ncolA, nslices)`
 * @param B real valued 3D array of dimension `(ncolA, ncolB, nslices)`
 * @returns real valued 3D array of dimension `(nrowA, ncolB, nslices)`
 */
extern SEXP R_merge_matmul(SEXP A, SEXP B) {
    // Check if both are 3D arrays
    if (!(Rf_isArray(A) && Rf_isArray(B))) {
        Rf_error("Both A and B must be 3D arrays");
    }
    // the data of both is accessed via `REAL` below
    if (!(Rf_isReal(A) && Rf_isReal(B))) {
        Rf_error("Both A and B must be real valued (double) arrays");
    }

    // get object dimensions
    SEXP dimA = Rf_getAttrib(A, R_DimSymbol);
    SEXP dimB = Rf_getAttrib(B, R_DimSymbol);

    // Check dimension compatibility
    if ((Rf_length(dimA) != 3) || (Rf_length(dimB) != 3)) {
        Rf_error("Both A and B must be 3D arrays");
    }

    // Extract dimensions
    int nrowA = INTEGER(dimA)[0];
    int ncolA = INTEGER(dimA)[1];
    int nslices = INTEGER(dimA)[2];
    int nrowB = INTEGER(dimB)[0];
    int ncolB = INTEGER(dimB)[1];

    // Validate dimension compatibility
    if (!nrowA || !ncolA || !nslices || !nrowB || !ncolB) {
        Rf_error("Zero dimension detected");
    }
    if (nslices != INTEGER(dimB)[2]) {
        Rf_error("Dimension mismatch dim(A)[3] != dim(B)[3]");
    }
    if (ncolA != nrowB) {
        Rf_error("Dimension mismatch ncol(A) != nrow(B)");
    }

    // the number of elements of the result is stored as `int`, check in
    // floating point arithmetic that it does not overflow
    if ((double)nrowA * (double)ncolB * (double)nslices > (double)INT_MAX) {
        Rf_error("Result of dimension (%d, %d, %d) has too many elements",
            nrowA, ncolB, nslices);
    }

    // create response object C
    int sizeC = nrowA * ncolB * nslices;
    SEXP C = PROTECT(Rf_allocVector(REALSXP, sizeC));
    SEXP dimC = PROTECT(Rf_allocVector(INTSXP, 3));
    INTEGER(dimC)[0] = nrowA;
    INTEGER(dimC)[1] = ncolB;
    INTEGER(dimC)[2] = nslices;
    Rf_setAttrib(C, R_DimSymbol, dimC);

    // Apply matrix-matrix multiplication over 3rd array axis
    // (`beta = 0.0`, the uninitialized content of `C` is not accumulated)
    merge_matmul(nrowA, ncolA, nslices, nrowB, ncolB,
        1.0, REAL(A), REAL(B), 0.0, REAL(C)
    );

    // release C (and its dimensions) to the garbage collector
    UNPROTECT(2);

    return C;
}

View File

@ -0,0 +1,20 @@
#ifndef INCLUDE_GUARD_MERGE_MATMUL_H
#define INCLUDE_GUARD_MERGE_MATMUL_H
#include "R_api.h"
/**
 * Merge matrix-matrix multiplication over the 3rd axis of 3D arrays
 *
 * For every slice `k = 0, ..., nslices - 1` the implementation issues one
 * `dgemm` call computing
 *
 *      C_k <- alpha * A_k * B_k + beta * C_k
 *
 * where the slices `A_k`, `B_k` and `C_k` start at the element offsets
 * `k * nrowA * ncolA`, `k * nrowB * ncolB` and `k * nrowA * ncolB` of `A`,
 * `B` and `C`, respectively.
 *
 * @attention Assumes all parameters to be correct! In particular
 *  `ncolA == nrowB` is not checked.
 *
 * @param nrowA number of rows of every slice of `A` (and of `C`)
 * @param ncolA number of columns of every slice of `A`
 * @param nslices number of slices of `A`, `B` and `C`
 * @param nrowB number of rows of every slice of `B`
 * @param ncolB number of columns of every slice of `B` (and of `C`)
 * @param alpha scaling of the matrix products
 * @param A 3D array of dimension `(nrowA, ncolA, nslices)`
 * @param B 3D array of dimension `(nrowB, ncolB, nslices)`
 * @param beta scaling of the previous content of `C`
 * @param C [in, out] 3D array of dimension `(nrowA, ncolB, nslices)`
 */
void merge_matmul(
/* dims */ const int nrowA, const int ncolA, const int nslices,
const int nrowB, const int ncolB,
/* scalar */ const double alpha,
/* 3D array */ const double* A, const double* B,
/* scalar */ const double beta,
/* 3D array */ double* C
);
#endif /* INCLUDE_GUARD_MERGE_MATMUL_H */