tensor_predictors/tensorPredictors/src/ising_sample.c

73 lines
2.3 KiB
C

#ifndef INCLUDE_GUARD_ISING_SAMPLE_H
#define INCLUDE_GUARD_ISING_SAMPLE_H
#include <limits.h>

#include "R_api.h"
#include "int_utils.h"
#include "ising_MCMC.h"
// .Call interface to draw from sample from the Ising model
// .Call interface to draw samples from the Ising model.
//
// Arguments (R objects as received through `.Call`):
//   _nr_samples  number of samples to draw; must be a positive integer not
//                exceeding INT_MAX (the column limit of `Rf_allocMatrix`)
//   _params      model parameters, either a square matrix or a vector; for a
//                vector the dimension is derived from its length via
//                `invTriag` (NOTE(review): presumably the half-vectorization
//                expected by `ising_mcmc_vech` -- confirm)
//                Non-double input is coerced to double.
//   _warmup      non-negative integer forwarded as the first argument of the
//                MCMC routines (presumably the number of burn-in steps)
//
// Returns an integer matrix with `dim` rows and `nr_samples` columns; column
// `k` is filled by one call to `ising_mcmc_mat` / `ising_mcmc_vech`.
// Invalid arguments raise an R error through `Rf_error`.
extern SEXP R_ising_sample(SEXP _nr_samples, SEXP _params, SEXP _warmup) {
    // Counts number of protected SEXP's to give them back to the garbage collector
    int protect_count = 0;

    // Parse and validate arguments
    const size_t nr_samples = asUnsigned(_nr_samples);
    if (nr_samples == 0 || nr_samples == NA_UNSIGNED) {
        Rf_error("Invalid 'nr_samples' value, expected pos. integer");
    }
    // `Rf_allocMatrix` takes the column count as `int`. A larger value would be
    // narrowed, allocating fewer columns than the sampling loop below writes
    // to (heap buffer overflow), so reject it up front.
    if (nr_samples > (size_t)INT_MAX) {
        Rf_error("Invalid 'nr_samples' value, exceeds maximum of %d", INT_MAX);
    }
    const size_t warmup = asUnsigned(_warmup);
    if (warmup == NA_UNSIGNED) {
        Rf_error("Invalid 'warmup' value, expected non-negative integer");
    }

    // Determine parameter mode (natural parameter vector or symmetric matrix),
    // either `m` for "Matrix" or `v` for "Vector"
    const char param_type = Rf_isMatrix(_params) ? 'm' : 'v';
    // In case of matrix parameters check for square matrix
    if (param_type == 'm' && (Rf_nrows(_params) != Rf_ncols(_params))) {
        Rf_error("Invalid 'params' value, expected square matrix");
    }

    // Get problem dimension from parameter size; a zero result (empty matrix,
    // or a vector length `invTriag` could not invert) is rejected.
    const size_t dim = (param_type == 'm')
        ? (size_t)Rf_nrows(_params)
        : invTriag(Rf_length(_params));
    if (!dim) {
        Rf_error("Error determining dimension.");
    }

    // Ensure parameters are numeric (coercion keeps the matrix attributes)
    if (!Rf_isReal(_params)) {
        _params = PROTECT(Rf_coerceVector(_params, REALSXP));
        ++protect_count;
    }
    double* params = REAL(_params);

    // Allocate result sample, one column per sample. Both extents fit into an
    // `int`: `dim` stems from an `int` length/row count, and `nr_samples` was
    // bounded above.
    SEXP _X = PROTECT(Rf_allocMatrix(INTSXP, (int)dim, (int)nr_samples));
    ++protect_count;
    int* X = INTEGER(_X);

    // Call appropriate sampling routine for every sample to generate; the R
    // RNG state is loaded before and written back after sampling.
    GetRNGstate();
    if (param_type == 'm') {
        for (size_t sample = 0; sample < nr_samples; ++sample) {
            ising_mcmc_mat(warmup, dim, dim, params, &X[sample * dim]);
        }
    } else {
        for (size_t sample = 0; sample < nr_samples; ++sample) {
            ising_mcmc_vech(warmup, dim, params, &X[sample * dim]);
        }
    }
    PutRNGstate();

    // Release protected SEXPs to the garbage collector
    UNPROTECT(protect_count);
    return _X;
}
#endif /* INCLUDE_GUARD_ISING_SAMPLE_H */