changed PSQTs,

fix: typo in data_gen.cpp inverting quiete position sampling,
add: white and black to move support,
update: eeg data example,
add: position analytics as interace to the schachhoernchen Board class
This commit is contained in:
Daniel Kapla 2024-01-10 17:28:55 +01:00
parent 61bd94bec8
commit daefd3e7d1
12 changed files with 426 additions and 218 deletions

View File

@ -6,17 +6,53 @@ HCE <- function(positions) {
.Call(`_Rchess_HCE`, positions)
}
#' Given a FEN (position) determines if its whites turn
isWhiteTurn <- function(positions) {
.Call(`_Rchess_isWhiteTurn`, positions)
}
#' Check if current side to move is in check
isCheck <- function(positions) {
.Call(`_Rchess_isCheck`, positions)
}
#' Check if the current position is a quiet position (no piece is attacked)
isQuiet <- function(positions) {
.Call(`_Rchess_isQuiet`, positions)
}
#' Check if position is terminal
#'
#' Checks if the position is a terminal position, meaning if the game ended
#' by mate, stale mate or the 50 modes rule. Three-Fold repetition is NOT
#' checked, therefore a seperate game history is required which the board
#' does NOT track.
#'
isTerminal <- function(positions) {
.Call(`_Rchess_isTerminal`, positions)
}
#' Check if checkmate is possible by material on the board
#'
#' Checks if there is sufficient mating material on the board, meaning if it
#' possible for any side to deliver a check mate. More specifically, it
#' checks if the pieces on the board are KK, KNK or KBK.
#'
isInsufficient <- function(positions) {
.Call(`_Rchess_isInsufficient`, positions)
}
#' Specialized version of `read_cyclic.cpp` taylored to work in conjunction with
#' `gmlm_chess()` as data generator to provide random draws from a FEN data set
#' with scores filtered to be in in the range `score_min` to `score_max`.
#'
data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, min_ply_count = 10L) {
.Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, min_ply_count)
data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, min_ply_count = 10L, white_only = TRUE) {
.Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, min_ply_count, white_only)
}
#' Human Crafted Evaluation
eval.psqt <- function(positions, psqt) {
.Call(`_Rchess_eval_psqt`, positions, psqt)
eval.psqt <- function(positions, psqt, pawn_structure = FALSE, eval_rooks = FALSE, eval_king = FALSE) {
.Call(`_Rchess_eval_psqt`, positions, psqt, pawn_structure, eval_rooks, eval_king)
}
#' Convert a legal FEN string to a 3D binary (integer with 0-1 entries) array

View File

@ -946,7 +946,7 @@ Score Board::evalPawns(enum piece color) const {
if (color == white) {
for (u64 sq = pawns; sq; sq &= sq - 1) {
Index i = bitScanLS(sq);
score += pieceSquareTables[pawn][i];
score += PSQT[pawn][i];
}
// Backwards pawns (not isolated but behind all adjacent friendly pawns)
u64 backwards = pawns & ~isolated;
@ -963,7 +963,7 @@ Score Board::evalPawns(enum piece color) const {
} else { // color == black
for (u64 sq = pawns; sq; sq &= sq - 1) {
Index i = bitScanLS(sq);
score += pieceSquareTables[pawn][63 - i];
score += PSQT[pawn][63 - i];
}
// Backwards pawns (not isolated but behind all adjacent friendly pawns)
u64 backwards = pawns & ~isolated;
@ -988,7 +988,7 @@ Score Board::evalKingSafety(enum piece color) const {
? bitScanLS( _bitBoard[white] & _bitBoard[king])
: bitScanLS(bitFlip<Rank>(_bitBoard[black] & _bitBoard[king]));
Score score = pieceSquareTables[king][kingSq];
Score score = PSQT[king][kingSq];
if ((fileIndex(kingSq) < 3) || (4 < fileIndex(kingSq))) { // King is castled
// Pawn shields are the least advanced pawns per file
@ -1043,7 +1043,7 @@ Score Board::evalRooks(enum piece color) const {
for (u64 sq = rooks; sq; sq &= sq - 1) {
Index sqIndex = bitScanLS(sq);
// Piece square table (accounts for rook on seventh bonus)
score += pieceSquareTables[rook][sqIndex];
score += PSQT[rook][sqIndex];
// Add bonuses for semi-open and open files
if (bitMask<Square>(sqIndex) & openFiles) {
@ -1063,22 +1063,10 @@ Score Board::evalRooks(enum piece color) const {
// position fen r1bq1rk1/pp2p1bp/2np4/5B2/nP3P2/N1P2N2/6PP/R1B1QRK1 b - - 0 4
Score Board::evaluate() const {
constexpr Score pstKingEndgame[64] = { // TODO: Proper parameters file, ...
0, 10, 20, 30, 30, 20, 10, 0,
10, 20, 30, 40, 40, 30, 20, 10,
20, 30, 40, 50, 50, 40, 30, 20,
30, 40, 50, 60, 60, 50, 40, 30,
30, 40, 50, 60, 60, 50, 40, 30,
20, 30, 40, 50, 50, 40, 30, 20,
10, 20, 30, 40, 40, 30, 20, 10,
0, 10, 20, 30, 30, 20, 10, 0
};
constexpr Score maxMaterial = 8 * pieceValues[pawn]
+ 2 * (pieceValues[rook] + pieceValues[knight] + pieceValues[bishop])
+ pieceValues[queen];
// Start score with material values
Score whiteMaterial = evalMaterial(white);
Score blackMaterial = evalMaterial(black);
@ -1089,11 +1077,11 @@ Score Board::evaluate() const {
for (enum piece type : { queen, bishop, knight }) {
// White pieces
for (u64 sq = _bitBoard[white] & _bitBoard[type]; sq; sq &= sq - 1) {
score += pieceSquareTables[type][bitScanLS(sq)];
score += PSQT[type][bitScanLS(sq)];
}
// and black pieces
for (u64 sq = _bitBoard[black] & _bitBoard[type]; sq; sq &= sq - 1) {
score -= pieceSquareTables[type][63 - bitScanLS(sq)];
score -= PSQT[type][63 - bitScanLS(sq)];
}
}
@ -1107,7 +1095,7 @@ Score Board::evaluate() const {
if (blackMaterial <= 1200) {
// Endgame:
// No king safety, but more king mobility in the center
score += pstKingEndgame[bitScanLS(_bitBoard[white] & _bitBoard[king])];
score += PSQT[kingEG][bitScanLS(_bitBoard[white] & _bitBoard[king])];
} else {
// Middle Game:
// King safety weighted by opponents material (The less pieces the enemy
@ -1117,7 +1105,7 @@ Score Board::evaluate() const {
}
// and the same for the black king with opposite sign
if (whiteMaterial <= 1200) {
score -= pstKingEndgame[bitScanLS(_bitBoard[black] & _bitBoard[king])];
score -= PSQT[kingEG][63 - bitScanLS(_bitBoard[black] & _bitBoard[king])];
} else {
score -= (5 * whiteMaterial * evalKingSafety(black)) / (4 * maxMaterial);
}

View File

@ -94,7 +94,7 @@ public:
// add 128 to ensure the PST values are positive
const Index t = color() == white ? to() : 63 - to();
const Index f = color() == white ? from() : 63 - from();
const uint32_t pst = pieceSquareTables[piece()][t] - pieceSquareTables[piece()][f] + 128;
const uint32_t pst = PSQT[piece()][t] - PSQT[piece()][f] + 128;
return ((static_cast<bool>(victim()) * mvv_lva) << (14 + winning * 6)) + pst;
}

View File

@ -39,7 +39,8 @@ enum piece {
bishop = 4,
rook = 5,
queen = 6,
king = 7
king = 7,
kingEG = 8 // Lookup index for king end game PSQT
};
enum square : Index {
@ -69,10 +70,10 @@ enum location {
constexpr Score pieceValues[8] = {
0, 0, // white, black (irrelevant)
100, // pawn
300, // knight
300, // bishop
500, // rook
900, // queen
295, // knight
315, // bishop
450, // rook
870, // queen
0 // king (irrelevant, always 2 opposite kings)
};
@ -132,64 +133,76 @@ constexpr u64 kingMoveLookup[64] = {
using std::cerr;
#endif
// Piece Square tables (from TSCP)
// see: https://www.chessprogramming.org/Simplified_Evaluation_Function
constexpr Score pieceSquareTables[8][64] = {
// Piece SQuare Tables (partially automated tuned tables via supervised
// optimization using stockfish [https://stockfishchess.org/] evaluated positions
// from the lichess database [https://database.lichess.org/])
// endgame table: https://www.chessprogramming.org/Simplified_Evaluation_Function
// Which is addapted by adding 50. then scaled by 2 / 3 and rounded.
constexpr Score PSQT[9][64] = {
{ }, { }, // white, black (empty)
{ // pawn (white)
0, 0, 0, 0, 0, 0, 0, 0,
5, 10, 15, 20, 20, 15, 10, 5,
4, 8, 12, 16, 16, 12, 8, 4,
3, 6, 9, 12, 12, 9, 6, 3,
2, 4, 6, 8, 8, 6, 4, 2,
1, 2, 3, -10, -10, 3, 2, 1,
0, 0, 0, -40, -40, 0, 0, 0,
109, 82, 89, 25, 25, 89, 82, 109,
21, 18, -3, 18, 18, -3, 18, 21,
-12, -1, -19, 6, 6, -19, -1, -12,
-25, -15, -22, 9, 9, -22, -15, -25,
-25, -11, -27, -23, -23, -27, -11, -25,
-25, -13, -23, -29, -29, -23, -13, -25,
0, 0, 0, 0, 0, 0, 0, 0 },
{ // knight (white)
-10, -10, -10, -10, -10, -10, -10, -10,
-10, 0, 0, 0, 0, 0, 0, -10,
-10, 0, 5, 5, 5, 5, 0, -10,
-10, 0, 5, 10, 10, 5, 0, -10,
-10, 0, 5, 10, 10, 5, 0, -10,
-10, 0, 5, 5, 5, 5, 0, -10,
-10, 0, 0, 0, 0, 0, 0, -10,
-10, -30, -10, -10, -10, -10, -30, -10 },
-90, -80, -18, 26, 26, -18, -80, -90,
-40, -13, 21, -22, -22, 21, -13, -40,
6, 2, 32, 38, 38, 32, 2, 6,
-9, -11, 22, 20, 20, 22, -11, -9,
-13, -11, 14, 2, 2, 14, -11, -13,
-25, -10, 2, 3, 3, 2, -10, -25,
-21, -54, -12, -8, -8, -12, -54, -21,
-76, -21, -38, -34, -34, -38, -21, -76 },
{ // bishop (white)
-10, -10, -10, -10, -10, -10, -10, -10,
-10, 0, 0, 0, 0, 0, 0, -10,
-10, 0, 5, 5, 5, 5, 0, -10,
-10, 0, 5, 10, 10, 5, 0, -10,
-10, 0, 5, 10, 10, 5, 0, -10,
-10, 0, 5, 5, 5, 5, 0, -10,
-10, 0, 0, 0, 0, 0, 0, -10,
-10, -10, -20, -10, -10, -20, -10, -10 },
-7, 19, 3, -21, -21, 3, 19, -7,
-15, -5, 6, 40, 40, 6, -5, -15,
12, 14, 18, 32, 32, 18, 14, 12,
-5, -2, 17, 26, 26, 17, -2, -5,
-19, -2, 2, 8, 8, 2, -2, -19,
2, 4, 2, 8, 8, 2, 4, 2,
-4, 8, 3, 1, 1, 3, 8, -4,
-31, -13, -7, -20, -20, -7, -13, -31 },
{ // rook (white)
0, 0, 0, 0, 0, 0, 0, 0,
20, 20, 20, 20, 20, 20, 20, 20, // rook on seventh bonus
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0 },
-5, -2, 23, 40, 40, 23, -2, -5,
18, 17, 42, 25, 25, 42, 17, 18,
22, 14, 33, 40, 40, 33, 14, 22,
21, 16, 20, 28, 28, 20, 16, 21,
-4, -13, -5, 3, 3, -5, -13, -4,
-20, -2, -3, -2, -2, -3, -2, -20,
-11, -13, 0, -6, -6, 0, -13, -11,
-17, -4, 0, 7, 7, 0, -4, -17 },
{ // queen (white)
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0 },
-55, -29, 59, 19, 19, 59, -29, -55,
12, -18, 34, 85, 85, 34, -18, 12,
33, 17, 31, 34, 34, 31, 17, 33,
51, 16, 21, 18, 18, 21, 16, 51,
-3, 24, 18, 26, 26, 18, 24, -3,
11, 14, 24, 2, 2, 24, 14, 11,
28, 5, 17, 15, 15, 17, 5, 28,
1, -10, -14, 18, 18, -14, -10, 1 },
{ // king middle game (white)
-40, -40, -40, -40, -40, -40, -40, -40,
-40, -40, -40, -40, -40, -40, -40, -40,
-40, -40, -40, -40, -40, -40, -40, -40,
-40, -40, -40, -40, -40, -40, -40, -40,
-40, -40, -40, -40, -40, -40, -40, -40,
-40, -40, -40, -40, -40, -40, -40, -40,
-20, -20, -20, -20, -20, -20, -20, -20,
0, 20, 40, -20, 0, -20, 40, 20 }
-5, -5, -5, -5, -5, -5, -5, -5,
-5, -5, -5, -5, -5, -5, -5, -5,
-5, -5, -5, -5, -5, -5, -5, -5,
-5, -5, -5, -5, -5, -5, -5, -5,
-5, -5, -5, -5, -5, -5, -5, -5,
-5, -5, -5, -5, -5, -5, -5, -5,
-4, -4, -4, -4, -4, -4, -4, -4,
24, 13, 3, -28, 2, -14, 15, 1 },
{ // king end game (white) // TODO: self/supervised tuning
0, 7, 13, 20, 20, 13, 7, 0,
13, 20, 27, 33, 33, 27, 20, 13,
13, 27, 47, 53, 53, 47, 27, 13,
13, 27, 53, 60, 60, 53, 27, 13,
13, 27, 53, 60, 60, 53, 27, 13,
13, 27, 47, 53, 53, 47, 27, 13,
13, 13, 33, 33, 33, 33, 13, 13,
0, 13, 13, 13, 13, 13, 13, 0 }
};
#endif /* INCLUDE_GUARD_TYPES_H */

View File

@ -0,0 +1,16 @@
#include <vector>
#include <Rcpp.h>
#include "SchachHoernchen/Move.h"
#include "SchachHoernchen/Board.h"
//' Human Crafted Evaluation
// [[Rcpp::export(rng = false)]]
Rcpp::NumericVector HCE(const std::vector<Board>& positions) {
// Iterate all positions and call the static board evaluation
return Rcpp::NumericVector(positions.begin(), positions.end(),
[](const Board& pos) {
return (double)pos.evaluate() / 100.0;
}
);
}

View File

@ -21,9 +21,59 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// isWhiteTurn
Rcpp::LogicalVector isWhiteTurn(const std::vector<Board>& positions);
RcppExport SEXP _Rchess_isWhiteTurn(SEXP positionsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
rcpp_result_gen = Rcpp::wrap(isWhiteTurn(positions));
return rcpp_result_gen;
END_RCPP
}
// isCheck
Rcpp::LogicalVector isCheck(const std::vector<Board>& positions);
RcppExport SEXP _Rchess_isCheck(SEXP positionsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
rcpp_result_gen = Rcpp::wrap(isCheck(positions));
return rcpp_result_gen;
END_RCPP
}
// isQuiet
Rcpp::LogicalVector isQuiet(const std::vector<Board>& positions);
RcppExport SEXP _Rchess_isQuiet(SEXP positionsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
rcpp_result_gen = Rcpp::wrap(isQuiet(positions));
return rcpp_result_gen;
END_RCPP
}
// isTerminal
Rcpp::LogicalVector isTerminal(const std::vector<Board>& positions);
RcppExport SEXP _Rchess_isTerminal(SEXP positionsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
rcpp_result_gen = Rcpp::wrap(isTerminal(positions));
return rcpp_result_gen;
END_RCPP
}
// isInsufficient
Rcpp::LogicalVector isInsufficient(const std::vector<Board>& positions);
RcppExport SEXP _Rchess_isInsufficient(SEXP positionsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
rcpp_result_gen = Rcpp::wrap(isInsufficient(positions));
return rcpp_result_gen;
END_RCPP
}
// data_gen
Rcpp::CharacterVector data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const int min_ply_count);
RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP min_ply_countSEXP) {
Rcpp::CharacterVector data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const int min_ply_count, const bool white_only);
RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP min_ply_countSEXP, SEXP white_onlySEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
@ -33,18 +83,22 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< const float >::type score_max(score_maxSEXP);
Rcpp::traits::input_parameter< const bool >::type quiet(quietSEXP);
Rcpp::traits::input_parameter< const int >::type min_ply_count(min_ply_countSEXP);
rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, min_ply_count));
Rcpp::traits::input_parameter< const bool >::type white_only(white_onlySEXP);
rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, min_ply_count, white_only));
return rcpp_result_gen;
END_RCPP
}
// eval_psqt
Rcpp::NumericVector eval_psqt(const std::vector<Board>& positions, const std::vector<Rcpp::NumericMatrix>& psqt);
RcppExport SEXP _Rchess_eval_psqt(SEXP positionsSEXP, SEXP psqtSEXP) {
Rcpp::NumericVector eval_psqt(const std::vector<Board>& positions, const std::vector<Rcpp::NumericMatrix>& psqt, const bool pawn_structure, const bool eval_rooks, const bool eval_king);
RcppExport SEXP _Rchess_eval_psqt(SEXP positionsSEXP, SEXP psqtSEXP, SEXP pawn_structureSEXP, SEXP eval_rooksSEXP, SEXP eval_kingSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
Rcpp::traits::input_parameter< const std::vector<Rcpp::NumericMatrix>& >::type psqt(psqtSEXP);
rcpp_result_gen = Rcpp::wrap(eval_psqt(positions, psqt));
Rcpp::traits::input_parameter< const bool >::type pawn_structure(pawn_structureSEXP);
Rcpp::traits::input_parameter< const bool >::type eval_rooks(eval_rooksSEXP);
Rcpp::traits::input_parameter< const bool >::type eval_king(eval_kingSEXP);
rcpp_result_gen = Rcpp::wrap(eval_psqt(positions, psqt, pawn_structure, eval_rooks, eval_king));
return rcpp_result_gen;
END_RCPP
}
@ -196,8 +250,13 @@ END_RCPP
static const R_CallMethodDef CallEntries[] = {
{"_Rchess_HCE", (DL_FUNC) &_Rchess_HCE, 1},
{"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 6},
{"_Rchess_eval_psqt", (DL_FUNC) &_Rchess_eval_psqt, 2},
{"_Rchess_isWhiteTurn", (DL_FUNC) &_Rchess_isWhiteTurn, 1},
{"_Rchess_isCheck", (DL_FUNC) &_Rchess_isCheck, 1},
{"_Rchess_isQuiet", (DL_FUNC) &_Rchess_isQuiet, 1},
{"_Rchess_isTerminal", (DL_FUNC) &_Rchess_isTerminal, 1},
{"_Rchess_isInsufficient", (DL_FUNC) &_Rchess_isInsufficient, 1},
{"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 7},
{"_Rchess_eval_psqt", (DL_FUNC) &_Rchess_eval_psqt, 5},
{"_Rchess_fen2int", (DL_FUNC) &_Rchess_fen2int, 1},
{"_Rchess_read_cyclic", (DL_FUNC) &_Rchess_read_cyclic, 5},
{"_Rchess_sample_move", (DL_FUNC) &_Rchess_sample_move, 1},

View File

@ -0,0 +1,56 @@
#include <vector>
#include <Rcpp.h>
#include "SchachHoernchen/Board.h"
//' Given a FEN (position) determines if its whites turn
// [[Rcpp::export(rng = false)]]
Rcpp::LogicalVector isWhiteTurn(const std::vector<Board>& positions) {
// Iterate all positions and call the static board evaluation
return Rcpp::LogicalVector(positions.begin(), positions.end(),
[](const Board& pos) { return pos.isWhiteTurn(); }
);
}
//' Check if current side to move is in check
// [[Rcpp::export(rng = false)]]
Rcpp::LogicalVector isCheck(const std::vector<Board>& positions) {
return Rcpp::LogicalVector(positions.begin(), positions.end(),
[](const Board& pos) { return pos.isCheck(); }
);
}
//' Check if the current position is a quiet position (no piece is attacked)
// [[Rcpp::export(rng = false)]]
Rcpp::LogicalVector isQuiet(const std::vector<Board>& positions) {
return Rcpp::LogicalVector(positions.begin(), positions.end(),
[](const Board& pos) { return pos.isQuiet(); }
);
}
//' Check if position is terminal
//'
//' Checks if the position is a terminal position, meaning if the game ended
//' by mate, stale mate or the 50 modes rule. Three-Fold repetition is NOT
//' checked, therefore a seperate game history is required which the board
//' does NOT track.
//'
// [[Rcpp::export(rng = false)]]
Rcpp::LogicalVector isTerminal(const std::vector<Board>& positions) {
return Rcpp::LogicalVector(positions.begin(), positions.end(),
[](const Board& pos) { return pos.isTerminal(); }
);
}
//' Check if checkmate is possible by material on the board
//'
//' Checks if there is sufficient mating material on the board, meaning if it
//' possible for any side to deliver a check mate. More specifically, it
//' checks if the pieces on the board are KK, KNK or KBK.
//'
// [[Rcpp::export(rng = false)]]
Rcpp::LogicalVector isInsufficient(const std::vector<Board>& positions) {
return Rcpp::LogicalVector(positions.begin(), positions.end(),
[](const Board& pos) { return pos.isInsufficient(); }
);
}

View File

@ -20,7 +20,8 @@ Rcpp::CharacterVector data_gen(
const float score_min = -5.0,
const float score_max = +5.0,
const bool quiet = false,
const int min_ply_count = 10
const int min_ply_count = 10,
const bool white_only = true
) {
// Check parames
if (sample_size < 1) {
@ -103,10 +104,10 @@ Rcpp::CharacterVector data_gen(
}
// Reject / Filter samples
if (((int)pos.plyCount() < min_ply_count) // Filter early positions
|| (pos.sideToMove() == piece::black) // Filter white to move positions
|| (score < score_min || score_max <= score) // filter scores out of slice
|| (quiet && pos.isQuiet())) // filter quiet positions (iff requested)
if (((int)pos.plyCount() < min_ply_count) // early positions
|| (white_only && (pos.sideToMove() == piece::black)) // white to move positions
|| (score < score_min || score_max <= score) // scores out of slice
|| (quiet && !pos.isQuiet())) // quiet positions
{
reject_count++;
continue;

View File

@ -0,0 +1,78 @@
#include <vector>
#include <Rcpp.h>
#include "SchachHoernchen/Move.h"
#include "SchachHoernchen/Board.h"
//' Human Crafted Evaluation
// [[Rcpp::export(name = "eval.psqt", rng = false)]]
Rcpp::NumericVector eval_psqt(
const std::vector<Board>& positions,
const std::vector<Rcpp::NumericMatrix>& psqt,
const bool pawn_structure = false,
const bool eval_rooks = false,
const bool eval_king = false
) {
// validate Piece Square Table count and sizes
if (psqt.size() != 6) {
Rcpp::stop("Expected exactly 6 PSQTs");
}
for (const auto table : psqt) {
if (table.nrow() != 8 || table.ncol() != 8) {
Rcpp::stop("PSQT table missmatch, all expected to be `8 x 8`");
}
}
// create numeric vector by evaluating all positions
return Rcpp::NumericVector(positions.begin(), positions.end(),
[&psqt, pawn_structure, eval_rooks, eval_king](
const Board& pos
) {
// Index to color/piece mapping (more robust)
enum piece colorLoopup[2] = { white, black };
enum piece pieceLookup[6] = { pawn, knight, bishop, rook, queen, king };
// Score is the "inner product" of the "one-hot encoded" position
// and the piece square tables (PSQT)
double whiteMaterial = 0.0, blackMaterial = 0.0;
for (int piece = 0; piece < 6; ++piece) {
u64 piece_bb = pos.bb(pieceLookup[piece]);
// First the White (positive) pieces
for (u64 bb = pos.bb(piece::white) & piece_bb; bb; bb &= bb - 1) {
// Get piece on bitboard index (Least Significant Bit)
int index = bitScanLS(bb);
// Transpose to align with PSQT memory layout
index = ((index & 7) << 3) | ((index & 56) >> 3);
whiteMaterial += psqt[piece][index];
}
// Second the black (negative) pieces (with flipped Ranks)
for (u64 bb = pos.bb(piece::black) & piece_bb; bb; bb &= bb - 1) {
// Get fliped board index
int index = bitScanLS(bb);
// Transpose to align with PSQT memory layout and flip ranks
// convert from whites perspective to blacks persepective
index = ((index & 7) << 3) | (7 - ((index & 56) >> 3));
blackMaterial += psqt[piece][index];
}
}
return (whiteMaterial - blackMaterial) / 100.0;
}
);
}
/*
devtools::load_all()
save_point <- sort(list.files(
"~/Work/tensorPredictors/dataAnalysis/chess/",
pattern = "save_point.*\\.Rdata",
full.names = TRUE
), decreasing = TRUE)[[1]]
load(save_point)
psqt <- Map(function(parts) matrix(rowSums(kronecker(parts[[2]], parts[[1]])), 8, 8), betas)
psqt <- Map(`-`, psqt[1:6], Map(function(table) table[8:1, ], psqt[7:12]))
eval.psqt("startpos", psqt)
*/

View File

@ -15,8 +15,12 @@ Rcpp::IntegerVector fen2int(const std::vector<Board>& boards) {
auto dims = Rcpp::IntegerVector({ 8, 8, 12, (int)boards.size() });
bitboards.attr("dim") = dims;
bitboards.attr("dimnames") = Rcpp::List::create(
Rcpp::Named("rank") = Rcpp::CharacterVector::create("8", "7", "6", "5", "4", "3", "2", "1"),
Rcpp::Named("file") = Rcpp::CharacterVector::create("a", "b", "c", "d", "e", "f", "g", "h"),
Rcpp::Named("rank") = Rcpp::CharacterVector::create(
"8", "7", "6", "5", "4", "3", "2", "1"
),
Rcpp::Named("file") = Rcpp::CharacterVector::create(
"a", "b", "c", "d", "e", "f", "g", "h"
),
Rcpp::Named("piece") = Rcpp::CharacterVector::create(
"P", "N", "B", "R", "Q", "K", // White Pieces (Upper Case)
"p", "n", "b", "r", "q", "k" // Black Pieces (Lower Case)
@ -42,6 +46,8 @@ Rcpp::IntegerVector fen2int(const std::vector<Board>& boards) {
int index = bitScanLS(bb);
// Transpose to align with printing as a Chess Board
index = ((index & 7) << 3) | ((index & 56) >> 3);
// Flip black to move positions to whites point of view
index ^= pos.isWhiteTurn() ? 0 : 7;
bitboards[768 * i + 64 * slice + index] = 1;
}
}

View File

@ -1,71 +0,0 @@
// #include <Rcpp.h>
// #include "SchachHoernchen/utils.h"
// #include "SchachHoernchen/Board.h"
// #include "SchachHoernchen/uci.h"
// // [[Rcpp::export(name = "print.board", rng = false)]]
// void print_board(
// const Board& board,
// const bool check = true,
// const bool attacked = false,
// const bool pinned = false,
// const bool checkers = false
// ) {
// using Rcpp::Rcout;
// // Extract some properties
// piece color = board.sideToMove();
// piece enemy = static_cast<piece>(!color);
// u64 empty = ~(board.bb(piece::white) | board.bb(piece::black));
// u64 cKing = board.bb(king) & board.bb(color);
// // Construct highlight mask
// u64 attackMask = board.attacks(enemy, empty);
// u64 mask = 0;
// mask |= check ? attackMask & board.bb(color) & board.bb(king) : 0;
// mask |= attacked ? attackMask & ~board.bb(color) : 0;
// mask |= pinned ? board.pinned(enemy, cKing, empty) : 0;
// mask |= checkers ? board.checkers(enemy, cKing, empty) : 0;
// // print the board to console
// Rcout << "FEN: " << board.fen() << '\n';
// for (Index line = 0; line < 17; line++) {
// if (line % 2) {
// Index rankIndex = line / 2;
// Rcout << static_cast<char>('8' - rankIndex);
// for (Index fileIndex = 0; fileIndex < 8; fileIndex++) {
// Index squareIndex = 8 * rankIndex + fileIndex;
// if (bitMask<Square>(squareIndex) & mask) {
// if (board.piece(squareIndex)) {
// if (board.color(squareIndex) == black) {
// // bold + italic + underline + black (blue)
// Rcout << " | \033[1;3;4;94m";
// } else {
// // bold + italic + underline (+ white)
// Rcout << " | \033[1;3;4m";
// }
// Rcout << UCI::formatPiece(board.piece(squareIndex))
// << "\033[0m";
// } else {
// Rcout << " | .";
// }
// } else if (board.color(squareIndex) == black) {
// Rcout << " | \033[1m\033[94m" // bold + blue (black)
// << UCI::formatPiece(board.piece(squareIndex))
// << "\033[0m";
// } else if (board.color(squareIndex) == white) {
// Rcout << " | \033[1m\033[97m" // bold + white
// << UCI::formatPiece(board.piece(squareIndex))
// << "\033[0m";
// } else {
// Rcout << " | ";
// }
// }
// Rcout << " |";
// } else {
// Rcout << " +---+---+---+---+---+---+---+---+";
// }
// Rcout << "\033[0K\n"; // clear rest of line (remove potential leftovers)
// }
// Rcout << " a b c d e f g h" << std::endl;
// }

View File

@ -1,8 +1,6 @@
options(keep.source = TRUE, keep.source.pkgs = TRUE)
library(tensorPredictors)
# Load as 3D predictors `X` and flat response `y`
# Load as 3D predictors `X` and flat response `y` and `F = y` with per person dim. 1 x 1
c(X, F, y) %<-% local({
# Load from file
ds <- readRDS("eeg_data.rds")
@ -30,8 +28,38 @@ c(X, F, y) %<-% local({
# fit a tensor normal model to the data sample axis 1 indexes persons)
fit.gmlm <- gmlm_tensor_normal(X, F, sample.axis = 1L)
# Performa a LOO prediction
y.hat <- sapply(seq_len(dim(X)[1L]), function(i) {
# plot the fitted mode wise reductions (for time and sensor axis)
with(fit.gmlm, {
par.reset <- par(mfrow = c(2, 1))
plot(seq(0, 1, len = 256), betas[[1]], main = "Time", xlab = "Time [s]", ylab = expression(beta[1]))
plot(betas[[2]], main = "Sensors", xlab = "Sensor Index", ylab = expression(beta[2]))
par(par.reset)
})
#' (2D)^2 PCA preprocessing
#'
#' @param tpc Number of "t"ime "p"rincipal "c"omponents.
#' @param ppc Number of "p"redictor "p"rincipal "c"omponents.
preprocess <- function(X, tpc, ppc) {
# Mode covariances (for predictor and time point modes)
c(Sigma_t, Sigma_p) %<-% mcov(X, sample.axis = 1L)
# "predictor" (sensor) and time point principal components
V_t <- svd(Sigma_t, tpc, 0L)$u
V_p <- svd(Sigma_p, ppc, 0L)$u
# reduce with mode wise PCs
mlm(X, list(V_t, V_p), modes = 2:3, transposed = TRUE)
}
#' Leave-one-out prediction
#'
#' @param X 3D EEG data (preprocessed or not)
#' @param F binary responce `y` as a 3D tensor, every obs. is a 1 x 1 matrix
loo.predict <- function(X, F) {
sapply(seq_len(dim(X)[1L]), function(i) {
# Fit with i'th observation removes
fit <- gmlm_tensor_normal(X[-i, , ], F[-i, , , drop = FALSE], sample.axis = 1L)
@ -44,50 +72,48 @@ y.hat <- sapply(seq_len(dim(X)[1L]), function(i) {
# predict i'th response given i'th reduced observation
y.hat <- predict(logit, newdata = data.frame(r = r[i]), type = "response")
# report progress
cat(sprintf("%3d/%d\n", i, dim(X)[1L]))
cat(sprintf("dim: (%d, %d) - %3d/%d\n", dim(X)[2L], dim(X)[3L], i, dim(X)[1L]))
y.hat
})
})
}
### Classification performance measures
# acc: Accuracy. P(Yhat = Y). Estimated as: (TP+TN)/(P+N).
(acc <- mean(round(y.hat) == y)) # 0.7868852
acc <- function(y.true, y.pred) mean(round(y.pred) == y.true)
# err: Error rate. P(Yhat != Y). Estimated as: (FP+FN)/(P+N).
(err <- mean(round(y.hat) != y)) # 0.2131148
err <- function(y.true, y.pred) mean(round(y.pred) != y.true)
# fpr: False positive rate. P(Yhat = + | Y = -). aliases: Fallout.
(fpr <- mean((round(y.hat) == 1)[y == 0])) # 0.4
fpr <- function(y.true, y.pred) mean((round(y.pred) == 1)[y.true == 0])
# tpr: True positive rate. P(Yhat = + | Y = +). aliases: Sensitivity, Recall.
(tpr <- mean((round(y.hat) == 1)[y == 1])) # 0.8961039
tpr <- function(y.true, y.pred) mean((round(y.pred) == 1)[y.true == 1])
# fnr: False negative rate. P(Yhat = - | Y = +). aliases: Miss.
(fnr <- mean((round(y.hat) == 0)[y == 1])) # 0.1038961
fnr <- function(y.true, y.pred) mean((round(y.pred) == 0)[y.true == 1])
# tnr: True negative rate. P(Yhat = - | Y = -).
(tnr <- mean((round(y.hat) == 0)[y == 0])) # 0.6
tnr <- function(y.true, y.pred) mean((round(y.pred) == 0)[y.true == 0])
# auc: Area Under the Curve
(auc <- pROC::roc(y, y.hat, quiet = TRUE)$auc) # 0.838961
auc <- function(y.true, y.pred) as.numeric(pROC::roc(y.true, y.pred, quiet = TRUE)$auc)
auc.sd <- function(y.true, y.pred) sqrt(pROC::var(pROC::roc(y.true, y.pred, quiet = TRUE)))
with(fit.gmlm, {
par.reset <- par(mfrow = c(2, 1))
plot(seq(0, 1, len = 256), betas[[1]], main = "Time", xlab = "Time [s]", ylab = expression(beta[1]))
plot(betas[[2]], main = "Sensors", xlab = "Sensor Index", ylab = expression(beta[2]))
par(par.reset)
})
# perform preprocessed (reduced) and raw (not reduced) leave-one-out prediction
y.hat.3.4 <- loo.predict(preprocess(X, 3, 4), F)
y.hat.15.15 <- loo.predict(preprocess(X, 15, 15), F)
y.hat.20.30 <- loo.predict(preprocess(X, 20, 30), F)
y.hat <- loo.predict(X, F)
# dimX <- c(4, 3, 5)
# Omegas <- Map(function(p) {
# O <- matrix(rnorm(p^2), p)
# O %*% t(O)
# }, dimX)
# # Numerically more stable version of `sum(log(mapply(det, Omegas)) / dimX)`
# # which is itself equivalent to `log(det(Omega)) / prod(nrow(Omega))` where
# # `Omega <- Reduce(kronecker, rev(Omegas))`.
# Omega <- Reduce(kronecker, rev(Omegas))
# log(det(Omega)) / prod(nrow(Omega))
# (det.Omega <- sum(log(mapply(det, Omegas)) / dimX))
# sum(mapply(function(Omega) {
# sum(log(eigen(Omega, TRUE, TRUE)$values))
# }, Omegas) / dimX)
# classification performance measures table by leave-one-out cross-validation
(loo.cv <- apply(cbind(y.hat.3.4, y.hat.15.15, y.hat.20.30, y.hat), 2, function(y.pred) {
sapply(c("acc", "err", "fpr", "tpr", "fnr", "tnr", "auc", "auc.sd"),
function(FUN) { match.fun(FUN)(y, y.pred) })
}))
#> y.hat.3.4 y.hat.15.15 y.hat.20.30 y.hat
#> acc 0.79508197 0.78688525 0.78688525 0.78688525
#> err 0.20491803 0.21311475 0.21311475 0.21311475
#> fpr 0.35555556 0.40000000 0.40000000 0.40000000
#> tpr 0.88311688 0.89610390 0.89610390 0.89610390
#> fnr 0.11688312 0.10389610 0.10389610 0.10389610
#> tnr 0.64444444 0.60000000 0.60000000 0.60000000
#> auc 0.85108225 0.83838384 0.83924964 0.83896104
#> auc.sd 0.03584791 0.03760531 0.03751307 0.03754553