tensor_predictors/dataAnalysis/chess/Rchess/src/eval_psqt.cpp

78 lines
2.9 KiB
C++

#include <vector>
#include <Rcpp.h>
#include "SchachHoernchen/Move.h"
#include "SchachHoernchen/Board.h"
//' Human Crafted Evaluation
//'
//' Scores chess positions with Piece Square Tables (PSQT). For every piece on
//' the board the table entry belonging to its square is summed up, separately
//' for white and black (black squares are looked up with flipped ranks). The
//' score of a position is the white sum minus the black sum, divided by 100.
//'
//' @param positions chess positions to evaluate, one score per position.
//' @param psqt list of exactly 6 numeric `8 x 8` matrices, one per piece type
//'  in the order pawn, knight, bishop, rook, queen, king.
//' @param pawn_structure,eval_rooks,eval_king currently not used by the
//'  evaluation.
//'
//' @return numeric vector with one evaluation per element of `positions`.
// [[Rcpp::export(name = "eval.psqt", rng = false)]]
Rcpp::NumericVector eval_psqt(
    const std::vector<Board>& positions,
    const std::vector<Rcpp::NumericMatrix>& psqt,
    const bool pawn_structure = false,
    const bool eval_rooks = false,
    const bool eval_king = false
) {
    // NOTE(review): these flags are part of the exported interface but no
    // evaluation term reads them; mark them as intentionally unused instead
    // of capturing them in the lambda below.
    (void)pawn_structure;
    (void)eval_rooks;
    (void)eval_king;

    // validate Piece Square Table count and sizes
    if (psqt.size() != 6) {
        Rcpp::stop("Expected exactly 6 PSQTs");
    }
    // bind by reference, no need to copy the matrix handles for a size check
    for (const auto& table : psqt) {
        if (table.nrow() != 8 || table.ncol() != 8) {
            Rcpp::stop("PSQT table missmatch, all expected to be `8 x 8`");
        }
    }

    // create numeric vector by evaluating all positions
    return Rcpp::NumericVector(positions.begin(), positions.end(),
        [&psqt](const Board& pos) {
            // Index to piece type mapping, same order as the tables in `psqt`
            const enum piece pieceLookup[6] = {
                pawn, knight, bishop, rook, queen, king
            };

            // Score is the "inner product" of the "one-hot encoded" position
            // and the piece square tables (PSQT)
            double whiteMaterial = 0.0, blackMaterial = 0.0;
            for (int type = 0; type < 6; ++type) {
                u64 piece_bb = pos.bb(pieceLookup[type]);

                // First the white (positive) pieces; `bb &= bb - 1` clears
                // the lowest set bit, visiting every occupied square once
                for (u64 bb = pos.bb(piece::white) & piece_bb; bb; bb &= bb - 1) {
                    // Get piece on bitboard index (Least Significant Bit)
                    int index = bitScanLS(bb);
                    // Transpose to align with PSQT memory layout: swap the
                    // low 3 bits (index & 7) with the high 3 bits (index & 56)
                    index = ((index & 7) << 3) | ((index & 56) >> 3);
                    whiteMaterial += psqt[type][index];
                }

                // Second the black (negative) pieces (with flipped ranks)
                for (u64 bb = pos.bb(piece::black) & piece_bb; bb; bb &= bb - 1) {
                    // Get board index (Least Significant Bit)
                    int index = bitScanLS(bb);
                    // Transpose to align with PSQT memory layout and flip
                    // ranks to convert from white's to black's perspective
                    index = ((index & 7) << 3) | (7 - ((index & 56) >> 3));
                    blackMaterial += psqt[type][index];
                }
            }

            return (whiteMaterial - blackMaterial) / 100.0;
        }
    );
}
/*
devtools::load_all()
save_point <- sort(list.files(
"~/Work/tensorPredictors/dataAnalysis/chess/",
pattern = "save_point.*\\.Rdata",
full.names = TRUE
), decreasing = TRUE)[[1]]
load(save_point)
psqt <- Map(function(parts) matrix(rowSums(kronecker(parts[[2]], parts[[1]])), 8, 8), betas)
psqt <- Map(`-`, psqt[1:6], Map(function(table) table[8:1, ], psqt[7:12]))
eval.psqt("startpos", psqt)
*/