101 lines
2.8 KiB
C
101 lines
2.8 KiB
C
#include "solve.h"

#include <limits.h>
#include <stddef.h>
#include <string.h>
|
|
|
|
void solve(
|
|
/* dims */ const int dimA, const int nrhs,
|
|
/* matrix */ const double* A, const int ldA,
|
|
/* matrix */ const double* B, const int ldB,
|
|
/* matrix */ double* X, const int ldX,
|
|
double* work_mem, int* info
|
|
) {
|
|
// Compute required working memory size if requested
|
|
if (work_mem == NULL) {
|
|
*info = dimA * (dimA + 1);
|
|
return;
|
|
}
|
|
|
|
// Copy `A` to (continuous) working memory
|
|
for (int i = 0; i < dimA; ++i) {
|
|
memcpy(work_mem + i * dimA, A + i * ldA, dimA * sizeof(double));
|
|
}
|
|
|
|
// Copy `B` to `X` or set `X` to identity
|
|
if (B == NULL) {
|
|
double* X_col = X;
|
|
for (int j = 0; j < dimA; ++j, X_col += ldX) {
|
|
for (int i = 0; i < dimA; ++i) {
|
|
*(X_col + i) = (i == j) ? 1.0 : 0.0;
|
|
}
|
|
}
|
|
} else {
|
|
for (int i = 0; i < nrhs; ++i) {
|
|
memcpy(X + i * ldX, B + i * ldB, dimA * sizeof(double));
|
|
}
|
|
}
|
|
|
|
// Lapack routine DGESV to solve linear system A X = B which writes
|
|
// result into `A`, `B` which are copied into working memory and the result
|
|
// memory `X`
|
|
int error = 0;
|
|
F77_CALL(dgesv)(
|
|
/* dims */ &dimA, &nrhs,
|
|
/* matrix A */ work_mem, &dimA, /* [in,out] A -> P L U */
|
|
/* ipiv */ (int*)(work_mem + dimA * dimA), /* [out] */
|
|
/* matrix B */ X, &ldX, /* [in,out] B -> X */
|
|
&error /* [out] */
|
|
);
|
|
|
|
// update error flag
|
|
*info |= error;
|
|
}
|
|
|
|
/**
|
|
* R binding to `solve` which solves A X = B for X
|
|
*/
|
|
extern SEXP R_solve(SEXP A, SEXP B) {
|
|
// Check types
|
|
if (!(Rf_isReal(A) && Rf_isMatrix(A))
|
|
|| !(Rf_isReal(B) && Rf_isMatrix(B))) {
|
|
Rf_error("All arguments must be real valued matrices");
|
|
}
|
|
|
|
// check dimensions
|
|
if (Rf_nrows(A) != Rf_ncols(A)
|
|
|| Rf_ncols(A) != Rf_nrows(B)) {
|
|
Rf_error("Dimension missmatch");
|
|
}
|
|
|
|
// Allocate result matrix `X`
|
|
SEXP X = PROTECT(Rf_allocMatrix(REALSXP, Rf_nrows(B), Rf_ncols(B)));
|
|
|
|
// Allocate required working memory
|
|
int work_size = 0;
|
|
solve(
|
|
Rf_nrows(A), Rf_ncols(B),
|
|
NULL, Rf_nrows(A),
|
|
NULL, Rf_nrows(B),
|
|
NULL, Rf_nrows(X),
|
|
NULL, &work_size
|
|
);
|
|
double* work_mem = (double*)R_alloc(work_size, sizeof(double));
|
|
|
|
// Solve the system A X = B an write results into `X`
|
|
int error = 0;
|
|
solve(
|
|
Rf_nrows(A), Rf_ncols(B),
|
|
REAL(A), Rf_nrows(A),
|
|
REAL(B), Rf_nrows(B),
|
|
REAL(X), Rf_nrows(X),
|
|
work_mem, &error
|
|
);
|
|
|
|
// release `X` to the garbage collector
|
|
UNPROTECT(1);
|
|
|
|
// check error after unprotect
|
|
if (error) {
|
|
Rf_error("Solve ended with error code %d", error);
|
|
}
|
|
|
|
return X;
|
|
}
|