tensor_predictors/tensorPredictors/src/solve.c

101 lines
2.8 KiB
C

#include "solve.h"
/**
 * Solves the linear system `A X = B` for `X` using LAPACK's DGESV
 * (LU decomposition with partial pivoting).
 *
 * All matrices are column-major with the given leading dimensions. `A` and
 * `B` are only read: `A` is copied into `work_mem` (where DGESV overwrites it
 * with its LU factors) and `B` is copied into `X`, which DGESV overwrites
 * with the solution.
 *
 * @param dimA     number of rows and columns of the square matrix `A`
 * @param nrhs     number of right-hand sides (columns of `B` and `X`);
 *                 ignored if `B` is NULL, in which case `dimA` is used
 * @param A        square `dimA x dimA` matrix, leading dimension `ldA`
 * @param B        right-hand sides `dimA x nrhs`, leading dimension `ldB`;
 *                 if NULL, `X` is set to the identity (so `X = A^-1`)
 * @param X        output matrix, leading dimension `ldX` (DGESV requires
 *                 `ldX >= max(1, dimA)`)
 * @param work_mem working memory of at least `dimA * (dimA + 1)` doubles;
 *                 if NULL, the required size is written to `*info` and the
 *                 function returns without solving
 * @param info     size query: receives the required working memory size.
 *                 Otherwise the DGESV status (0 success, < 0 illegal
 *                 argument, > 0 `A` is singular) is OR-ed into `*info`.
 */
void solve(
    /* dims */ const int dimA, const int nrhs,
    /* matrix */ const double* A, const int ldA,
    /* matrix */ const double* B, const int ldB,
    /* matrix */ double* X, const int ldX,
    double* work_mem, int* info
) {
    // Size query: `dimA * dimA` doubles hold the copy of `A`, the remaining
    // `dimA` doubles give (more than) enough room for the `dimA` int pivots.
    if (work_mem == NULL) {
        *info = dimA * (dimA + 1);
        return;
    }
    // An empty system has nothing to solve. Returning early also avoids
    // calling DGESV with a leading dimension of 0, which it rejects as an
    // illegal argument (LDA >= max(1, N)).
    if (dimA == 0) {
        return;
    }
    // Without `B` the right-hand side is the identity, which always has
    // exactly `dimA` columns.
    const int n_rhs = (B == NULL) ? dimA : nrhs;

    // Copy `A` column by column into (contiguous) working memory. Offsets
    // are computed in `size_t` to avoid `int` overflow for large matrices.
    for (int j = 0; j < dimA; ++j) {
        memcpy(work_mem + (size_t)j * (size_t)dimA,
               A + (size_t)j * (size_t)ldA,
               (size_t)dimA * sizeof(double));
    }
    // Copy `B` to `X` or set `X` to identity
    if (B == NULL) {
        for (int j = 0; j < dimA; ++j) {
            double* X_col = X + (size_t)j * (size_t)ldX;
            for (int i = 0; i < dimA; ++i) {
                X_col[i] = (i == j) ? 1.0 : 0.0;
            }
        }
    } else {
        for (int j = 0; j < n_rhs; ++j) {
            memcpy(X + (size_t)j * (size_t)ldX,
                   B + (size_t)j * (size_t)ldB,
                   (size_t)dimA * sizeof(double));
        }
    }
    // LAPACK routine DGESV solves A X = B in place: the copy of `A` in the
    // working memory is overwritten by its LU factors, `X` (holding B) by
    // the solution.
    int error = 0;
    F77_CALL(dgesv)(
        /* dims */ &dimA, &n_rhs,
        /* matrix A */ work_mem, &dimA, /* [in,out] A -> P L U */
        /* ipiv */ (int*)(work_mem + (size_t)dimA * (size_t)dimA), /* [out] */
        /* matrix B */ X, &ldX, /* [in,out] B -> X */
        &error /* [out] */
    );
    // Accumulate the status into the caller's error flag
    *info |= error;
}
/**
* R binding to `solve` which solves A X = B for X
*/
/**
 * R binding to `solve` which solves `A X = B` for `X`.
 *
 * @param A real valued square `n x n` matrix
 * @param B real valued matrix with `n` rows (the right-hand sides)
 * @return newly allocated real `n x ncol(B)` matrix `X`
 *
 * Raises an R error if an argument is not a real valued matrix, if the
 * dimensions do not conform, or if `solve` reports a non-zero status
 * (for example when `A` is singular).
 */
extern SEXP R_solve(SEXP A, SEXP B) {
    // Check types
    if (!(Rf_isReal(A) && Rf_isMatrix(A))
        || !(Rf_isReal(B) && Rf_isMatrix(B))) {
        Rf_error("All arguments must be real valued matrices");
    }
    const int n = Rf_nrows(A);
    const int nrhs = Rf_ncols(B);
    // `A` must be square and conform with the rows of `B`
    if (n != Rf_ncols(A) || n != Rf_nrows(B)) {
        Rf_error("Dimension missmatch");
    }
    // Allocate result matrix `X`
    SEXP X = PROTECT(Rf_allocMatrix(REALSXP, n, nrhs));
    // An empty system has nothing to solve; the (empty) `X` is the result.
    // Handled explicitly because a zero work size makes `R_alloc` return
    // NULL, which `solve` would treat as a working memory size query.
    if (n == 0) {
        UNPROTECT(1);
        return X;
    }
    // Query the required working memory size
    int work_size = 0;
    solve(
        n, nrhs,
        NULL, n,
        NULL, n,
        NULL, n,
        NULL, &work_size
    );
    // `R_alloc` memory is reclaimed by R at the end of the `.Call`
    double* work_mem = (double*)R_alloc(work_size, sizeof(double));
    // Solve the system A X = B and write the results into `X`
    int error = 0;
    solve(
        n, nrhs,
        REAL(A), n,
        REAL(B), n,
        REAL(X), n,
        work_mem, &error
    );
    // release `X` to the garbage collector
    UNPROTECT(1);
    // check error after unprotect
    if (error) {
        Rf_error("Solve ended with error code %d", error);
    }
    return X;
}