// tensor_predictors/mvbernoulli/inst/include/mvbernoulli.h

// Included by Rcpp, through its naming convention, into the generated
// RcppExports.cpp file. This enables the use of custom Rcpp types throughout
// the package.
#ifndef MVBERNOULLI_INCLUDE_GUARD_H
#define MVBERNOULLI_INCLUDE_GUARD_H
#include <vector>
#include <algorithm>
#include <RcppCommon.h>
#include "../../src/types.h"
// Custom type conversion declarations
namespace Rcpp {
// These declarations sit between <RcppCommon.h> and <Rcpp.h>, following the
// Rcpp "extending" pattern, so Rcpp's templates can see the specializations.
//
// Both are declared `inline` because their definitions live in this header.
// An explicit specialization of a function template is NOT implicitly inline.
// Without the specifier, every translation unit that includes this header
// would emit its own definition, and linking would fail with duplicate
// symbols.

// from R to C++: converts an R object into an `MVBinary`.
template <> inline MVBinary as(SEXP);
// from C++ to R: converts an `MVBinary` into an R object.
template <> inline SEXP wrap(const MVBinary&);
} /* namespace Rcpp */
#include <Rcpp.h>
// Custom type implementations
namespace Rcpp {

/**
 * Converts an R object to an `MVBinary` (R -> C++).
 *
 * Two input representations are accepted:
 *  1. A logical or integer matrix with at most 31 columns. Row `i` is packed
 *     into a `uint32_t` whose bit `j` is set iff entry `[i, j]` is non-zero
 *     (TRUE). `NA` entries are rejected instead of being silently treated as
 *     TRUE.
 *  2. An integer vector carrying an integer or double attribute `p` with
 *     `2 <= p <= 31`. The vector elements are passed unchanged to the
 *     `MVBinary` iterator-range constructor together with `p`.
 *
 * `inline` is required because this explicit specialization is defined in a
 * header. Explicit specializations of function templates are not implicitly
 * inline, so without it multiple including translation units would fail to
 * link.
 *
 * @param x R object to convert.
 * @return The converted `MVBinary`.
 * @throws Rcpp::exception (via `Rcpp::stop`) on an unsupported type/shape,
 *         more than 31 columns, an `NA` matrix entry, or a missing/illegal
 *         `p` attribute.
 */
template <>
inline MVBinary as(SEXP x) {
    if ((TYPEOF(x) == LGLSXP || TYPEOF(x) == INTSXP) && Rf_isMatrix(x)) {
        const int nrow = Rf_nrows(x);
        const int ncol = Rf_ncols(x);
        if (31 < ncol) {
            Rcpp::stop("Event dimension too big, max is 31");
        }
        MVBinary binary(nrow, ncol);
        // Logical and integer vectors share the same `int` storage, so both
        // can be read through one pointer.
        const int* data = (TYPEOF(x) == LGLSXP) ? LOGICAL(x) : INTEGER(x);
        for (int i = 0; i < nrow; ++i) {
            uint32_t num = 0;
            for (int j = 0; j < ncol; ++j) {
                // Column-major index. It is computed in R_xlen_t so that
                // `nrow * j` cannot overflow `int` for very tall matrices.
                const int value = data[i + static_cast<R_xlen_t>(nrow) * j];
                // NA_LOGICAL and NA_INTEGER are the same value (R_NaInt). It
                // is non-zero, so without this check it would be packed as a
                // set bit.
                if (value == NA_INTEGER) {
                    Rcpp::stop("Missing values (NA) are not supported");
                }
                if (value != 0) {
                    num |= static_cast<uint32_t>(1) << j;
                }
            }
            binary.push_back(num);
        }
        return binary;
    } else if ((TYPEOF(x) == INTSXP) && Rf_isVector(x)) {
        // The event dimension comes from the `p` attribute, which may be
        // stored as integer or double (e.g. `attr(x, "p") <- 3`). A missing
        // attribute is R_NilValue (NILSXP) and is rejected here.
        SEXP pAttr = Rf_getAttrib(x, Rf_install("p"));
        if (TYPEOF(pAttr) != INTSXP && TYPEOF(pAttr) != REALSXP) {
            Rcpp::stop("Unable to determin ncol (illegal `p` attribute)");
        }
        const int p = Rf_asInteger(pAttr);
        // NA converts to NA_INTEGER (INT_MIN), which is also rejected here.
        if (p < 2 || 31 < p) {
            Rcpp::stop("Unable to determin ncol (illegal `p` attribute)");
        }
        // Rf_xlength (not Rf_length) so that long vectors are not truncated.
        // NOTE(review): element values are not checked to lie in [0, 2^p).
        return MVBinary(INTEGER(x), INTEGER(x) + Rf_xlength(x), p);
    } else {
        Rcpp::stop("Unexpected dim/type");
    }
}

/**
 * Converts an `MVBinary` to R (C++ -> R).
 *
 * Produces an integer vector holding the stored values. Its `class` attribute
 * is set to "mvbinary" and its `p` attribute to `binary.dim()`. This is the
 * vector representation that `as<MVBinary>` accepts.
 *
 * `inline` for the same header-definition/ODR reason as `as<MVBinary>`.
 *
 * @param binary Value to convert.
 * @return The resulting R object.
 */
template <>
inline SEXP wrap(const MVBinary& binary) {
    Rcpp::IntegerVector sexp(binary.begin(), binary.end());
    sexp.attr("class") = Rcpp::CharacterVector::create("mvbinary");
    sexp.attr("p") = binary.dim();
    return sexp;
}

} /* namespace Rcpp */
#endif /* MVBERNOULLI_INCLUDE_GUARD_H */