#include "solve.h"

#include <limits.h>
#include <stddef.h>
#include <string.h>

/**
 * Solves the linear system A X = B for X using LAPACK's DGESV.
 *
 * All matrices are stored column-major: column `j` of a matrix `M` with
 * leading dimension `ldM` starts at `M + j * ldM`.
 *
 * The function has two modes:
 *  - Workspace query (`work_mem == NULL`): nothing is solved. The number of
 *    doubles required for `work_mem` is written to `*info`, or -1 if that
 *    number does not fit into an `int`.
 *  - Solve (`work_mem != NULL`): `A` and `B` are not modified. `A` is copied
 *    into `work_mem`, `B` is copied into `X`, and DGESV overwrites `X` with
 *    the solution.
 *
 * @param dimA     number of rows and columns of the square matrix `A`
 * @param nrhs     number of right hand sides, i.e. columns of `B` and `X`
 * @param A        [in] `dimA x dimA` matrix with leading dimension `ldA`
 * @param ldA      leading dimension of `A` (>= dimA)
 * @param B        [in] `dimA x nrhs` matrix with leading dimension `ldB`, or
 *                 NULL to solve against (the first `nrhs` columns of) the
 *                 identity; with `nrhs == dimA` this yields the inverse of `A`
 * @param ldB      leading dimension of `B` (>= dimA), unused if `B` is NULL
 * @param X        [out] `dimA x nrhs` result matrix with leading dimension `ldX`
 * @param ldX      leading dimension of `X` (>= max(1, dimA))
 * @param work_mem working memory of at least the size reported by a
 *                 workspace query, or NULL to perform that query
 * @param info     [in,out] query mode: receives the workspace size.
 *                 Solve mode: the DGESV status is OR-ed into `*info`
 *                 (0 on success), so initialize it to 0 before solving.
 */
void solve(
    /* dims */ const int dimA, const int nrhs,
    /* matrix */ const double* A, const int ldA,
    /* matrix */ const double* B, const int ldB,
    /* matrix */ double* X, const int ldX,
    double* work_mem, int* info
) {
    // Workspace layout: `dimA * dimA` doubles hold the copy of `A`. The
    // following `dimA` doubles provide room for the `dimA` pivot indices,
    // which are stored as `int`. The size is computed in a wide type because
    // `dimA * (dimA + 1)` overflows `int` for dimA > 46340.
    if (work_mem == NULL) {
        const long long size = (long long)dimA * ((long long)dimA + 1);
        *info = (size > INT_MAX) ? -1 : (int)size;
        return;
    }

    // Empty system: there is nothing to copy or solve. DGESV would reject the
    // leading dimension 0, since it requires LDA >= max(1, N).
    if (dimA == 0) {
        return;
    }

    const size_t col_bytes = (size_t)dimA * sizeof(double);

    // Copy `A` to (contiguous) working memory because DGESV overwrites it.
    // Offsets are computed in `size_t` to avoid `int` overflow on large inputs.
    for (int j = 0; j < dimA; ++j) {
        memcpy(work_mem + (size_t)j * dimA, A + (size_t)j * ldA, col_bytes);
    }

    // Copy `B` to `X`, or set the `nrhs` columns of `X` to the identity
    if (B == NULL) {
        for (int j = 0; j < nrhs; ++j) {
            double* X_col = X + (size_t)j * ldX;
            for (int i = 0; i < dimA; ++i) {
                X_col[i] = (i == j) ? 1.0 : 0.0;
            }
        }
    } else {
        for (int j = 0; j < nrhs; ++j) {
            memcpy(X + (size_t)j * ldX, B + (size_t)j * ldB, col_bytes);
        }
    }

    // Pivot indices are stored right behind the copy of `A`. This region is
    // double-aligned and holds `dimA` doubles, which is enough for `dimA` ints.
    int* ipiv = (int*)(work_mem + (size_t)dimA * dimA);

    // LAPACK routine DGESV solves A X = B. It overwrites the working copy of
    // `A` with its LU factorization and overwrites `X` (holding `B`) with
    // the solution.
    int error = 0;
    F77_CALL(dgesv)(
        /* dims */ &dimA, &nrhs,
        /* matrix A */ work_mem, &dimA, /* [in,out] A -> P L U */
        /* ipiv */ ipiv, /* [out] */
        /* matrix B */ X, &ldX, /* [in,out] B -> X */
        &error /* [out] */
    );

    // update error flag
    *info |= error;
}

/**
 * R binding to `solve` which solves A X = B for X.
 *
 * @param A real valued square matrix
 * @param B real valued matrix with as many rows as `A` has columns
 * @return newly allocated real matrix `X` with the same dimensions as `B`
 *
 * Raises an R error in these cases:
 *  - the arguments are not real matrices;
 *  - the dimensions do not match;
 *  - the working memory size is not representable;
 *  - DGESV reports a non-zero status.
 */
extern SEXP R_solve(SEXP A, SEXP B) {
    // Check types
    if (!(Rf_isReal(A) && Rf_isMatrix(A)) || !(Rf_isReal(B) && Rf_isMatrix(B))) {
        Rf_error("All arguments must be real valued matrices");
    }
    // check dimensions
    if (Rf_nrows(A) != Rf_ncols(A) || Rf_ncols(A) != Rf_nrows(B)) {
        Rf_error("Dimension missmatch");
    }
    const int dimA = Rf_nrows(A);
    const int nrhs = Rf_ncols(B);

    // Allocate result matrix `X`
    SEXP X = PROTECT(Rf_allocMatrix(REALSXP, dimA, nrhs));

    // An empty system has an empty solution. Return it directly: `R_alloc`
    // returns NULL for a zero-size request, and `solve` would treat that
    // NULL as a workspace query.
    if (dimA == 0) {
        UNPROTECT(1);
        return X;
    }

    // Query and allocate the required working memory
    int work_size = 0;
    solve(dimA, nrhs, NULL, dimA, NULL, dimA, NULL, dimA, NULL, &work_size);
    if (work_size < 0) {
        UNPROTECT(1);
        Rf_error("Matrix too large, required working memory exceeds integer range");
    }
    double* work_mem = (double*)R_alloc(work_size, sizeof(double));

    // Solve the system A X = B and write results into `X`
    int error = 0;
    solve(
        dimA, nrhs,
        REAL(A), dimA,
        REAL(B), Rf_nrows(B),
        REAL(X), Rf_nrows(X),
        work_mem, &error
    );

    // release `X` to the garbage collector
    UNPROTECT(1);

    // check error after unprotect
    if (error) {
        Rf_error("Solve ended with error code %d", error);
    }

    return X;
}