changed PSQTs,
fix: typo in data_gen.cpp inverting quiete position sampling, add: white and black to move support, update: eeg data example, add: position analytics as interace to the schachhoernchen Board class
This commit is contained in:
parent
61bd94bec8
commit
daefd3e7d1
|
@ -6,17 +6,53 @@ HCE <- function(positions) {
|
||||||
.Call(`_Rchess_HCE`, positions)
|
.Call(`_Rchess_HCE`, positions)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#' Given a FEN (position) determines if its whites turn
|
||||||
|
isWhiteTurn <- function(positions) {
|
||||||
|
.Call(`_Rchess_isWhiteTurn`, positions)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Check if current side to move is in check
|
||||||
|
isCheck <- function(positions) {
|
||||||
|
.Call(`_Rchess_isCheck`, positions)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Check if the current position is a quiet position (no piece is attacked)
|
||||||
|
isQuiet <- function(positions) {
|
||||||
|
.Call(`_Rchess_isQuiet`, positions)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Check if position is terminal
|
||||||
|
#'
|
||||||
|
#' Checks if the position is a terminal position, meaning if the game ended
|
||||||
|
#' by mate, stale mate or the 50 modes rule. Three-Fold repetition is NOT
|
||||||
|
#' checked, therefore a seperate game history is required which the board
|
||||||
|
#' does NOT track.
|
||||||
|
#'
|
||||||
|
isTerminal <- function(positions) {
|
||||||
|
.Call(`_Rchess_isTerminal`, positions)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Check if checkmate is possible by material on the board
|
||||||
|
#'
|
||||||
|
#' Checks if there is sufficient mating material on the board, meaning if it
|
||||||
|
#' possible for any side to deliver a check mate. More specifically, it
|
||||||
|
#' checks if the pieces on the board are KK, KNK or KBK.
|
||||||
|
#'
|
||||||
|
isInsufficient <- function(positions) {
|
||||||
|
.Call(`_Rchess_isInsufficient`, positions)
|
||||||
|
}
|
||||||
|
|
||||||
#' Specialized version of `read_cyclic.cpp` taylored to work in conjunction with
|
#' Specialized version of `read_cyclic.cpp` taylored to work in conjunction with
|
||||||
#' `gmlm_chess()` as data generator to provide random draws from a FEN data set
|
#' `gmlm_chess()` as data generator to provide random draws from a FEN data set
|
||||||
#' with scores filtered to be in in the range `score_min` to `score_max`.
|
#' with scores filtered to be in in the range `score_min` to `score_max`.
|
||||||
#'
|
#'
|
||||||
data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, min_ply_count = 10L) {
|
data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, min_ply_count = 10L, white_only = TRUE) {
|
||||||
.Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, min_ply_count)
|
.Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, min_ply_count, white_only)
|
||||||
}
|
}
|
||||||
|
|
||||||
#' Human Crafted Evaluation
|
#' Human Crafted Evaluation
|
||||||
eval.psqt <- function(positions, psqt) {
|
eval.psqt <- function(positions, psqt, pawn_structure = FALSE, eval_rooks = FALSE, eval_king = FALSE) {
|
||||||
.Call(`_Rchess_eval_psqt`, positions, psqt)
|
.Call(`_Rchess_eval_psqt`, positions, psqt, pawn_structure, eval_rooks, eval_king)
|
||||||
}
|
}
|
||||||
|
|
||||||
#' Convert a legal FEN string to a 3D binary (integer with 0-1 entries) array
|
#' Convert a legal FEN string to a 3D binary (integer with 0-1 entries) array
|
||||||
|
|
|
@ -946,7 +946,7 @@ Score Board::evalPawns(enum piece color) const {
|
||||||
if (color == white) {
|
if (color == white) {
|
||||||
for (u64 sq = pawns; sq; sq &= sq - 1) {
|
for (u64 sq = pawns; sq; sq &= sq - 1) {
|
||||||
Index i = bitScanLS(sq);
|
Index i = bitScanLS(sq);
|
||||||
score += pieceSquareTables[pawn][i];
|
score += PSQT[pawn][i];
|
||||||
}
|
}
|
||||||
// Backwards pawns (not isolated but behind all adjacent friendly pawns)
|
// Backwards pawns (not isolated but behind all adjacent friendly pawns)
|
||||||
u64 backwards = pawns & ~isolated;
|
u64 backwards = pawns & ~isolated;
|
||||||
|
@ -963,7 +963,7 @@ Score Board::evalPawns(enum piece color) const {
|
||||||
} else { // color == black
|
} else { // color == black
|
||||||
for (u64 sq = pawns; sq; sq &= sq - 1) {
|
for (u64 sq = pawns; sq; sq &= sq - 1) {
|
||||||
Index i = bitScanLS(sq);
|
Index i = bitScanLS(sq);
|
||||||
score += pieceSquareTables[pawn][63 - i];
|
score += PSQT[pawn][63 - i];
|
||||||
}
|
}
|
||||||
// Backwards pawns (not isolated but behind all adjacent friendly pawns)
|
// Backwards pawns (not isolated but behind all adjacent friendly pawns)
|
||||||
u64 backwards = pawns & ~isolated;
|
u64 backwards = pawns & ~isolated;
|
||||||
|
@ -988,7 +988,7 @@ Score Board::evalKingSafety(enum piece color) const {
|
||||||
? bitScanLS( _bitBoard[white] & _bitBoard[king])
|
? bitScanLS( _bitBoard[white] & _bitBoard[king])
|
||||||
: bitScanLS(bitFlip<Rank>(_bitBoard[black] & _bitBoard[king]));
|
: bitScanLS(bitFlip<Rank>(_bitBoard[black] & _bitBoard[king]));
|
||||||
|
|
||||||
Score score = pieceSquareTables[king][kingSq];
|
Score score = PSQT[king][kingSq];
|
||||||
|
|
||||||
if ((fileIndex(kingSq) < 3) || (4 < fileIndex(kingSq))) { // King is castled
|
if ((fileIndex(kingSq) < 3) || (4 < fileIndex(kingSq))) { // King is castled
|
||||||
// Pawn shields are the least advanced pawns per file
|
// Pawn shields are the least advanced pawns per file
|
||||||
|
@ -1043,7 +1043,7 @@ Score Board::evalRooks(enum piece color) const {
|
||||||
for (u64 sq = rooks; sq; sq &= sq - 1) {
|
for (u64 sq = rooks; sq; sq &= sq - 1) {
|
||||||
Index sqIndex = bitScanLS(sq);
|
Index sqIndex = bitScanLS(sq);
|
||||||
// Piece square table (accounts for rook on seventh bonus)
|
// Piece square table (accounts for rook on seventh bonus)
|
||||||
score += pieceSquareTables[rook][sqIndex];
|
score += PSQT[rook][sqIndex];
|
||||||
|
|
||||||
// Add bonuses for semi-open and open files
|
// Add bonuses for semi-open and open files
|
||||||
if (bitMask<Square>(sqIndex) & openFiles) {
|
if (bitMask<Square>(sqIndex) & openFiles) {
|
||||||
|
@ -1063,22 +1063,10 @@ Score Board::evalRooks(enum piece color) const {
|
||||||
// position fen r1bq1rk1/pp2p1bp/2np4/5B2/nP3P2/N1P2N2/6PP/R1B1QRK1 b - - 0 4
|
// position fen r1bq1rk1/pp2p1bp/2np4/5B2/nP3P2/N1P2N2/6PP/R1B1QRK1 b - - 0 4
|
||||||
Score Board::evaluate() const {
|
Score Board::evaluate() const {
|
||||||
|
|
||||||
constexpr Score pstKingEndgame[64] = { // TODO: Proper parameters file, ...
|
|
||||||
0, 10, 20, 30, 30, 20, 10, 0,
|
|
||||||
10, 20, 30, 40, 40, 30, 20, 10,
|
|
||||||
20, 30, 40, 50, 50, 40, 30, 20,
|
|
||||||
30, 40, 50, 60, 60, 50, 40, 30,
|
|
||||||
30, 40, 50, 60, 60, 50, 40, 30,
|
|
||||||
20, 30, 40, 50, 50, 40, 30, 20,
|
|
||||||
10, 20, 30, 40, 40, 30, 20, 10,
|
|
||||||
0, 10, 20, 30, 30, 20, 10, 0
|
|
||||||
};
|
|
||||||
|
|
||||||
constexpr Score maxMaterial = 8 * pieceValues[pawn]
|
constexpr Score maxMaterial = 8 * pieceValues[pawn]
|
||||||
+ 2 * (pieceValues[rook] + pieceValues[knight] + pieceValues[bishop])
|
+ 2 * (pieceValues[rook] + pieceValues[knight] + pieceValues[bishop])
|
||||||
+ pieceValues[queen];
|
+ pieceValues[queen];
|
||||||
|
|
||||||
|
|
||||||
// Start score with material values
|
// Start score with material values
|
||||||
Score whiteMaterial = evalMaterial(white);
|
Score whiteMaterial = evalMaterial(white);
|
||||||
Score blackMaterial = evalMaterial(black);
|
Score blackMaterial = evalMaterial(black);
|
||||||
|
@ -1089,11 +1077,11 @@ Score Board::evaluate() const {
|
||||||
for (enum piece type : { queen, bishop, knight }) {
|
for (enum piece type : { queen, bishop, knight }) {
|
||||||
// White pieces
|
// White pieces
|
||||||
for (u64 sq = _bitBoard[white] & _bitBoard[type]; sq; sq &= sq - 1) {
|
for (u64 sq = _bitBoard[white] & _bitBoard[type]; sq; sq &= sq - 1) {
|
||||||
score += pieceSquareTables[type][bitScanLS(sq)];
|
score += PSQT[type][bitScanLS(sq)];
|
||||||
}
|
}
|
||||||
// and black pieces
|
// and black pieces
|
||||||
for (u64 sq = _bitBoard[black] & _bitBoard[type]; sq; sq &= sq - 1) {
|
for (u64 sq = _bitBoard[black] & _bitBoard[type]; sq; sq &= sq - 1) {
|
||||||
score -= pieceSquareTables[type][63 - bitScanLS(sq)];
|
score -= PSQT[type][63 - bitScanLS(sq)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1107,7 +1095,7 @@ Score Board::evaluate() const {
|
||||||
if (blackMaterial <= 1200) {
|
if (blackMaterial <= 1200) {
|
||||||
// Endgame:
|
// Endgame:
|
||||||
// No king safety, but more king mobility in the center
|
// No king safety, but more king mobility in the center
|
||||||
score += pstKingEndgame[bitScanLS(_bitBoard[white] & _bitBoard[king])];
|
score += PSQT[kingEG][bitScanLS(_bitBoard[white] & _bitBoard[king])];
|
||||||
} else {
|
} else {
|
||||||
// Middle Game:
|
// Middle Game:
|
||||||
// King safety weighted by opponents material (The less pieces the enemy
|
// King safety weighted by opponents material (The less pieces the enemy
|
||||||
|
@ -1117,7 +1105,7 @@ Score Board::evaluate() const {
|
||||||
}
|
}
|
||||||
// and the same for the black king with opposite sign
|
// and the same for the black king with opposite sign
|
||||||
if (whiteMaterial <= 1200) {
|
if (whiteMaterial <= 1200) {
|
||||||
score -= pstKingEndgame[bitScanLS(_bitBoard[black] & _bitBoard[king])];
|
score -= PSQT[kingEG][63 - bitScanLS(_bitBoard[black] & _bitBoard[king])];
|
||||||
} else {
|
} else {
|
||||||
score -= (5 * whiteMaterial * evalKingSafety(black)) / (4 * maxMaterial);
|
score -= (5 * whiteMaterial * evalKingSafety(black)) / (4 * maxMaterial);
|
||||||
}
|
}
|
||||||
|
|
|
@ -94,7 +94,7 @@ public:
|
||||||
// add 128 to ensure the PST values are positive
|
// add 128 to ensure the PST values are positive
|
||||||
const Index t = color() == white ? to() : 63 - to();
|
const Index t = color() == white ? to() : 63 - to();
|
||||||
const Index f = color() == white ? from() : 63 - from();
|
const Index f = color() == white ? from() : 63 - from();
|
||||||
const uint32_t pst = pieceSquareTables[piece()][t] - pieceSquareTables[piece()][f] + 128;
|
const uint32_t pst = PSQT[piece()][t] - PSQT[piece()][f] + 128;
|
||||||
|
|
||||||
return ((static_cast<bool>(victim()) * mvv_lva) << (14 + winning * 6)) + pst;
|
return ((static_cast<bool>(victim()) * mvv_lva) << (14 + winning * 6)) + pst;
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,7 +39,8 @@ enum piece {
|
||||||
bishop = 4,
|
bishop = 4,
|
||||||
rook = 5,
|
rook = 5,
|
||||||
queen = 6,
|
queen = 6,
|
||||||
king = 7
|
king = 7,
|
||||||
|
kingEG = 8 // Lookup index for king end game PSQT
|
||||||
};
|
};
|
||||||
|
|
||||||
enum square : Index {
|
enum square : Index {
|
||||||
|
@ -69,10 +70,10 @@ enum location {
|
||||||
constexpr Score pieceValues[8] = {
|
constexpr Score pieceValues[8] = {
|
||||||
0, 0, // white, black (irrelevant)
|
0, 0, // white, black (irrelevant)
|
||||||
100, // pawn
|
100, // pawn
|
||||||
300, // knight
|
295, // knight
|
||||||
300, // bishop
|
315, // bishop
|
||||||
500, // rook
|
450, // rook
|
||||||
900, // queen
|
870, // queen
|
||||||
0 // king (irrelevant, always 2 opposite kings)
|
0 // king (irrelevant, always 2 opposite kings)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -132,64 +133,76 @@ constexpr u64 kingMoveLookup[64] = {
|
||||||
using std::cerr;
|
using std::cerr;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Piece Square tables (from TSCP)
|
// Piece SQuare Tables (partially automated tuned tables via supervised
|
||||||
// see: https://www.chessprogramming.org/Simplified_Evaluation_Function
|
// optimization using stockfish [https://stockfishchess.org/] evaluated positions
|
||||||
constexpr Score pieceSquareTables[8][64] = {
|
// from the lichess database [https://database.lichess.org/])
|
||||||
|
// endgame table: https://www.chessprogramming.org/Simplified_Evaluation_Function
|
||||||
|
// Which is addapted by adding 50. then scaled by 2 / 3 and rounded.
|
||||||
|
constexpr Score PSQT[9][64] = {
|
||||||
{ }, { }, // white, black (empty)
|
{ }, { }, // white, black (empty)
|
||||||
{ // pawn (white)
|
{ // pawn (white)
|
||||||
0, 0, 0, 0, 0, 0, 0, 0,
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
5, 10, 15, 20, 20, 15, 10, 5,
|
109, 82, 89, 25, 25, 89, 82, 109,
|
||||||
4, 8, 12, 16, 16, 12, 8, 4,
|
21, 18, -3, 18, 18, -3, 18, 21,
|
||||||
3, 6, 9, 12, 12, 9, 6, 3,
|
-12, -1, -19, 6, 6, -19, -1, -12,
|
||||||
2, 4, 6, 8, 8, 6, 4, 2,
|
-25, -15, -22, 9, 9, -22, -15, -25,
|
||||||
1, 2, 3, -10, -10, 3, 2, 1,
|
-25, -11, -27, -23, -23, -27, -11, -25,
|
||||||
0, 0, 0, -40, -40, 0, 0, 0,
|
-25, -13, -23, -29, -29, -23, -13, -25,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0 },
|
0, 0, 0, 0, 0, 0, 0, 0 },
|
||||||
{ // knight (white)
|
{ // knight (white)
|
||||||
-10, -10, -10, -10, -10, -10, -10, -10,
|
-90, -80, -18, 26, 26, -18, -80, -90,
|
||||||
-10, 0, 0, 0, 0, 0, 0, -10,
|
-40, -13, 21, -22, -22, 21, -13, -40,
|
||||||
-10, 0, 5, 5, 5, 5, 0, -10,
|
6, 2, 32, 38, 38, 32, 2, 6,
|
||||||
-10, 0, 5, 10, 10, 5, 0, -10,
|
-9, -11, 22, 20, 20, 22, -11, -9,
|
||||||
-10, 0, 5, 10, 10, 5, 0, -10,
|
-13, -11, 14, 2, 2, 14, -11, -13,
|
||||||
-10, 0, 5, 5, 5, 5, 0, -10,
|
-25, -10, 2, 3, 3, 2, -10, -25,
|
||||||
-10, 0, 0, 0, 0, 0, 0, -10,
|
-21, -54, -12, -8, -8, -12, -54, -21,
|
||||||
-10, -30, -10, -10, -10, -10, -30, -10 },
|
-76, -21, -38, -34, -34, -38, -21, -76 },
|
||||||
{ // bishop (white)
|
{ // bishop (white)
|
||||||
-10, -10, -10, -10, -10, -10, -10, -10,
|
-7, 19, 3, -21, -21, 3, 19, -7,
|
||||||
-10, 0, 0, 0, 0, 0, 0, -10,
|
-15, -5, 6, 40, 40, 6, -5, -15,
|
||||||
-10, 0, 5, 5, 5, 5, 0, -10,
|
12, 14, 18, 32, 32, 18, 14, 12,
|
||||||
-10, 0, 5, 10, 10, 5, 0, -10,
|
-5, -2, 17, 26, 26, 17, -2, -5,
|
||||||
-10, 0, 5, 10, 10, 5, 0, -10,
|
-19, -2, 2, 8, 8, 2, -2, -19,
|
||||||
-10, 0, 5, 5, 5, 5, 0, -10,
|
2, 4, 2, 8, 8, 2, 4, 2,
|
||||||
-10, 0, 0, 0, 0, 0, 0, -10,
|
-4, 8, 3, 1, 1, 3, 8, -4,
|
||||||
-10, -10, -20, -10, -10, -20, -10, -10 },
|
-31, -13, -7, -20, -20, -7, -13, -31 },
|
||||||
{ // rook (white)
|
{ // rook (white)
|
||||||
0, 0, 0, 0, 0, 0, 0, 0,
|
-5, -2, 23, 40, 40, 23, -2, -5,
|
||||||
20, 20, 20, 20, 20, 20, 20, 20, // rook on seventh bonus
|
18, 17, 42, 25, 25, 42, 17, 18,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0,
|
22, 14, 33, 40, 40, 33, 14, 22,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0,
|
21, 16, 20, 28, 28, 20, 16, 21,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0,
|
-4, -13, -5, 3, 3, -5, -13, -4,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0,
|
-20, -2, -3, -2, -2, -3, -2, -20,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0,
|
-11, -13, 0, -6, -6, 0, -13, -11,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0 },
|
-17, -4, 0, 7, 7, 0, -4, -17 },
|
||||||
{ // queen (white)
|
{ // queen (white)
|
||||||
0, 0, 0, 0, 0, 0, 0, 0,
|
-55, -29, 59, 19, 19, 59, -29, -55,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0,
|
12, -18, 34, 85, 85, 34, -18, 12,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0,
|
33, 17, 31, 34, 34, 31, 17, 33,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0,
|
51, 16, 21, 18, 18, 21, 16, 51,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0,
|
-3, 24, 18, 26, 26, 18, 24, -3,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0,
|
11, 14, 24, 2, 2, 24, 14, 11,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0,
|
28, 5, 17, 15, 15, 17, 5, 28,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0 },
|
1, -10, -14, 18, 18, -14, -10, 1 },
|
||||||
{ // king middle game (white)
|
{ // king middle game (white)
|
||||||
-40, -40, -40, -40, -40, -40, -40, -40,
|
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||||
-40, -40, -40, -40, -40, -40, -40, -40,
|
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||||
-40, -40, -40, -40, -40, -40, -40, -40,
|
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||||
-40, -40, -40, -40, -40, -40, -40, -40,
|
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||||
-40, -40, -40, -40, -40, -40, -40, -40,
|
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||||
-40, -40, -40, -40, -40, -40, -40, -40,
|
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||||
-20, -20, -20, -20, -20, -20, -20, -20,
|
-4, -4, -4, -4, -4, -4, -4, -4,
|
||||||
0, 20, 40, -20, 0, -20, 40, 20 }
|
24, 13, 3, -28, 2, -14, 15, 1 },
|
||||||
|
{ // king end game (white) // TODO: self/supervised tuning
|
||||||
|
0, 7, 13, 20, 20, 13, 7, 0,
|
||||||
|
13, 20, 27, 33, 33, 27, 20, 13,
|
||||||
|
13, 27, 47, 53, 53, 47, 27, 13,
|
||||||
|
13, 27, 53, 60, 60, 53, 27, 13,
|
||||||
|
13, 27, 53, 60, 60, 53, 27, 13,
|
||||||
|
13, 27, 47, 53, 53, 47, 27, 13,
|
||||||
|
13, 13, 33, 33, 33, 33, 13, 13,
|
||||||
|
0, 13, 13, 13, 13, 13, 13, 0 }
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif /* INCLUDE_GUARD_TYPES_H */
|
#endif /* INCLUDE_GUARD_TYPES_H */
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
#include <vector>
|
||||||
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
#include "SchachHoernchen/Move.h"
|
||||||
|
#include "SchachHoernchen/Board.h"
|
||||||
|
|
||||||
|
//' Human Crafted Evaluation
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Rcpp::NumericVector HCE(const std::vector<Board>& positions) {
|
||||||
|
// Iterate all positions and call the static board evaluation
|
||||||
|
return Rcpp::NumericVector(positions.begin(), positions.end(),
|
||||||
|
[](const Board& pos) {
|
||||||
|
return (double)pos.evaluate() / 100.0;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
|
@ -21,9 +21,59 @@ BEGIN_RCPP
|
||||||
return rcpp_result_gen;
|
return rcpp_result_gen;
|
||||||
END_RCPP
|
END_RCPP
|
||||||
}
|
}
|
||||||
|
// isWhiteTurn
|
||||||
|
Rcpp::LogicalVector isWhiteTurn(const std::vector<Board>& positions);
|
||||||
|
RcppExport SEXP _Rchess_isWhiteTurn(SEXP positionsSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(isWhiteTurn(positions));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// isCheck
|
||||||
|
Rcpp::LogicalVector isCheck(const std::vector<Board>& positions);
|
||||||
|
RcppExport SEXP _Rchess_isCheck(SEXP positionsSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(isCheck(positions));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// isQuiet
|
||||||
|
Rcpp::LogicalVector isQuiet(const std::vector<Board>& positions);
|
||||||
|
RcppExport SEXP _Rchess_isQuiet(SEXP positionsSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(isQuiet(positions));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// isTerminal
|
||||||
|
Rcpp::LogicalVector isTerminal(const std::vector<Board>& positions);
|
||||||
|
RcppExport SEXP _Rchess_isTerminal(SEXP positionsSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(isTerminal(positions));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// isInsufficient
|
||||||
|
Rcpp::LogicalVector isInsufficient(const std::vector<Board>& positions);
|
||||||
|
RcppExport SEXP _Rchess_isInsufficient(SEXP positionsSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(isInsufficient(positions));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
// data_gen
|
// data_gen
|
||||||
Rcpp::CharacterVector data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const int min_ply_count);
|
Rcpp::CharacterVector data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const int min_ply_count, const bool white_only);
|
||||||
RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP min_ply_countSEXP) {
|
RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP min_ply_countSEXP, SEXP white_onlySEXP) {
|
||||||
BEGIN_RCPP
|
BEGIN_RCPP
|
||||||
Rcpp::RObject rcpp_result_gen;
|
Rcpp::RObject rcpp_result_gen;
|
||||||
Rcpp::RNGScope rcpp_rngScope_gen;
|
Rcpp::RNGScope rcpp_rngScope_gen;
|
||||||
|
@ -33,18 +83,22 @@ BEGIN_RCPP
|
||||||
Rcpp::traits::input_parameter< const float >::type score_max(score_maxSEXP);
|
Rcpp::traits::input_parameter< const float >::type score_max(score_maxSEXP);
|
||||||
Rcpp::traits::input_parameter< const bool >::type quiet(quietSEXP);
|
Rcpp::traits::input_parameter< const bool >::type quiet(quietSEXP);
|
||||||
Rcpp::traits::input_parameter< const int >::type min_ply_count(min_ply_countSEXP);
|
Rcpp::traits::input_parameter< const int >::type min_ply_count(min_ply_countSEXP);
|
||||||
rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, min_ply_count));
|
Rcpp::traits::input_parameter< const bool >::type white_only(white_onlySEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, min_ply_count, white_only));
|
||||||
return rcpp_result_gen;
|
return rcpp_result_gen;
|
||||||
END_RCPP
|
END_RCPP
|
||||||
}
|
}
|
||||||
// eval_psqt
|
// eval_psqt
|
||||||
Rcpp::NumericVector eval_psqt(const std::vector<Board>& positions, const std::vector<Rcpp::NumericMatrix>& psqt);
|
Rcpp::NumericVector eval_psqt(const std::vector<Board>& positions, const std::vector<Rcpp::NumericMatrix>& psqt, const bool pawn_structure, const bool eval_rooks, const bool eval_king);
|
||||||
RcppExport SEXP _Rchess_eval_psqt(SEXP positionsSEXP, SEXP psqtSEXP) {
|
RcppExport SEXP _Rchess_eval_psqt(SEXP positionsSEXP, SEXP psqtSEXP, SEXP pawn_structureSEXP, SEXP eval_rooksSEXP, SEXP eval_kingSEXP) {
|
||||||
BEGIN_RCPP
|
BEGIN_RCPP
|
||||||
Rcpp::RObject rcpp_result_gen;
|
Rcpp::RObject rcpp_result_gen;
|
||||||
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||||
Rcpp::traits::input_parameter< const std::vector<Rcpp::NumericMatrix>& >::type psqt(psqtSEXP);
|
Rcpp::traits::input_parameter< const std::vector<Rcpp::NumericMatrix>& >::type psqt(psqtSEXP);
|
||||||
rcpp_result_gen = Rcpp::wrap(eval_psqt(positions, psqt));
|
Rcpp::traits::input_parameter< const bool >::type pawn_structure(pawn_structureSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const bool >::type eval_rooks(eval_rooksSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const bool >::type eval_king(eval_kingSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(eval_psqt(positions, psqt, pawn_structure, eval_rooks, eval_king));
|
||||||
return rcpp_result_gen;
|
return rcpp_result_gen;
|
||||||
END_RCPP
|
END_RCPP
|
||||||
}
|
}
|
||||||
|
@ -196,8 +250,13 @@ END_RCPP
|
||||||
|
|
||||||
static const R_CallMethodDef CallEntries[] = {
|
static const R_CallMethodDef CallEntries[] = {
|
||||||
{"_Rchess_HCE", (DL_FUNC) &_Rchess_HCE, 1},
|
{"_Rchess_HCE", (DL_FUNC) &_Rchess_HCE, 1},
|
||||||
{"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 6},
|
{"_Rchess_isWhiteTurn", (DL_FUNC) &_Rchess_isWhiteTurn, 1},
|
||||||
{"_Rchess_eval_psqt", (DL_FUNC) &_Rchess_eval_psqt, 2},
|
{"_Rchess_isCheck", (DL_FUNC) &_Rchess_isCheck, 1},
|
||||||
|
{"_Rchess_isQuiet", (DL_FUNC) &_Rchess_isQuiet, 1},
|
||||||
|
{"_Rchess_isTerminal", (DL_FUNC) &_Rchess_isTerminal, 1},
|
||||||
|
{"_Rchess_isInsufficient", (DL_FUNC) &_Rchess_isInsufficient, 1},
|
||||||
|
{"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 7},
|
||||||
|
{"_Rchess_eval_psqt", (DL_FUNC) &_Rchess_eval_psqt, 5},
|
||||||
{"_Rchess_fen2int", (DL_FUNC) &_Rchess_fen2int, 1},
|
{"_Rchess_fen2int", (DL_FUNC) &_Rchess_fen2int, 1},
|
||||||
{"_Rchess_read_cyclic", (DL_FUNC) &_Rchess_read_cyclic, 5},
|
{"_Rchess_read_cyclic", (DL_FUNC) &_Rchess_read_cyclic, 5},
|
||||||
{"_Rchess_sample_move", (DL_FUNC) &_Rchess_sample_move, 1},
|
{"_Rchess_sample_move", (DL_FUNC) &_Rchess_sample_move, 1},
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
#include <vector>
|
||||||
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
#include "SchachHoernchen/Board.h"
|
||||||
|
|
||||||
|
//' Given a FEN (position) determines if its whites turn
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Rcpp::LogicalVector isWhiteTurn(const std::vector<Board>& positions) {
|
||||||
|
// Iterate all positions and call the static board evaluation
|
||||||
|
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||||
|
[](const Board& pos) { return pos.isWhiteTurn(); }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
//' Check if current side to move is in check
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Rcpp::LogicalVector isCheck(const std::vector<Board>& positions) {
|
||||||
|
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||||
|
[](const Board& pos) { return pos.isCheck(); }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
//' Check if the current position is a quiet position (no piece is attacked)
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Rcpp::LogicalVector isQuiet(const std::vector<Board>& positions) {
|
||||||
|
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||||
|
[](const Board& pos) { return pos.isQuiet(); }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
//' Check if position is terminal
|
||||||
|
//'
|
||||||
|
//' Checks if the position is a terminal position, meaning if the game ended
|
||||||
|
//' by mate, stale mate or the 50 modes rule. Three-Fold repetition is NOT
|
||||||
|
//' checked, therefore a seperate game history is required which the board
|
||||||
|
//' does NOT track.
|
||||||
|
//'
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Rcpp::LogicalVector isTerminal(const std::vector<Board>& positions) {
|
||||||
|
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||||
|
[](const Board& pos) { return pos.isTerminal(); }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
//' Check if checkmate is possible by material on the board
|
||||||
|
//'
|
||||||
|
//' Checks if there is sufficient mating material on the board, meaning if it
|
||||||
|
//' possible for any side to deliver a check mate. More specifically, it
|
||||||
|
//' checks if the pieces on the board are KK, KNK or KBK.
|
||||||
|
//'
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Rcpp::LogicalVector isInsufficient(const std::vector<Board>& positions) {
|
||||||
|
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||||
|
[](const Board& pos) { return pos.isInsufficient(); }
|
||||||
|
);
|
||||||
|
}
|
|
@ -20,7 +20,8 @@ Rcpp::CharacterVector data_gen(
|
||||||
const float score_min = -5.0,
|
const float score_min = -5.0,
|
||||||
const float score_max = +5.0,
|
const float score_max = +5.0,
|
||||||
const bool quiet = false,
|
const bool quiet = false,
|
||||||
const int min_ply_count = 10
|
const int min_ply_count = 10,
|
||||||
|
const bool white_only = true
|
||||||
) {
|
) {
|
||||||
// Check parames
|
// Check parames
|
||||||
if (sample_size < 1) {
|
if (sample_size < 1) {
|
||||||
|
@ -103,10 +104,10 @@ Rcpp::CharacterVector data_gen(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reject / Filter samples
|
// Reject / Filter samples
|
||||||
if (((int)pos.plyCount() < min_ply_count) // Filter early positions
|
if (((int)pos.plyCount() < min_ply_count) // early positions
|
||||||
|| (pos.sideToMove() == piece::black) // Filter white to move positions
|
|| (white_only && (pos.sideToMove() == piece::black)) // white to move positions
|
||||||
|| (score < score_min || score_max <= score) // filter scores out of slice
|
|| (score < score_min || score_max <= score) // scores out of slice
|
||||||
|| (quiet && pos.isQuiet())) // filter quiet positions (iff requested)
|
|| (quiet && !pos.isQuiet())) // quiet positions
|
||||||
{
|
{
|
||||||
reject_count++;
|
reject_count++;
|
||||||
continue;
|
continue;
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
#include <vector>
|
||||||
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
#include "SchachHoernchen/Move.h"
|
||||||
|
#include "SchachHoernchen/Board.h"
|
||||||
|
|
||||||
|
//' Human Crafted Evaluation
|
||||||
|
// [[Rcpp::export(name = "eval.psqt", rng = false)]]
|
||||||
|
Rcpp::NumericVector eval_psqt(
|
||||||
|
const std::vector<Board>& positions,
|
||||||
|
const std::vector<Rcpp::NumericMatrix>& psqt,
|
||||||
|
const bool pawn_structure = false,
|
||||||
|
const bool eval_rooks = false,
|
||||||
|
const bool eval_king = false
|
||||||
|
) {
|
||||||
|
// validate Piece Square Table count and sizes
|
||||||
|
if (psqt.size() != 6) {
|
||||||
|
Rcpp::stop("Expected exactly 6 PSQTs");
|
||||||
|
}
|
||||||
|
for (const auto table : psqt) {
|
||||||
|
if (table.nrow() != 8 || table.ncol() != 8) {
|
||||||
|
Rcpp::stop("PSQT table missmatch, all expected to be `8 x 8`");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// create numeric vector by evaluating all positions
|
||||||
|
return Rcpp::NumericVector(positions.begin(), positions.end(),
|
||||||
|
[&psqt, pawn_structure, eval_rooks, eval_king](
|
||||||
|
const Board& pos
|
||||||
|
) {
|
||||||
|
// Index to color/piece mapping (more robust)
|
||||||
|
enum piece colorLoopup[2] = { white, black };
|
||||||
|
enum piece pieceLookup[6] = { pawn, knight, bishop, rook, queen, king };
|
||||||
|
|
||||||
|
// Score is the "inner product" of the "one-hot encoded" position
|
||||||
|
// and the piece square tables (PSQT)
|
||||||
|
double whiteMaterial = 0.0, blackMaterial = 0.0;
|
||||||
|
for (int piece = 0; piece < 6; ++piece) {
|
||||||
|
u64 piece_bb = pos.bb(pieceLookup[piece]);
|
||||||
|
// First the White (positive) pieces
|
||||||
|
for (u64 bb = pos.bb(piece::white) & piece_bb; bb; bb &= bb - 1) {
|
||||||
|
// Get piece on bitboard index (Least Significant Bit)
|
||||||
|
int index = bitScanLS(bb);
|
||||||
|
// Transpose to align with PSQT memory layout
|
||||||
|
index = ((index & 7) << 3) | ((index & 56) >> 3);
|
||||||
|
whiteMaterial += psqt[piece][index];
|
||||||
|
}
|
||||||
|
// Second the black (negative) pieces (with flipped Ranks)
|
||||||
|
for (u64 bb = pos.bb(piece::black) & piece_bb; bb; bb &= bb - 1) {
|
||||||
|
// Get fliped board index
|
||||||
|
int index = bitScanLS(bb);
|
||||||
|
// Transpose to align with PSQT memory layout and flip ranks
|
||||||
|
// convert from whites perspective to blacks persepective
|
||||||
|
index = ((index & 7) << 3) | (7 - ((index & 56) >> 3));
|
||||||
|
blackMaterial += psqt[piece][index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return (whiteMaterial - blackMaterial) / 100.0;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
devtools::load_all()
|
||||||
|
save_point <- sort(list.files(
|
||||||
|
"~/Work/tensorPredictors/dataAnalysis/chess/",
|
||||||
|
pattern = "save_point.*\\.Rdata",
|
||||||
|
full.names = TRUE
|
||||||
|
), decreasing = TRUE)[[1]]
|
||||||
|
load(save_point)
|
||||||
|
|
||||||
|
psqt <- Map(function(parts) matrix(rowSums(kronecker(parts[[2]], parts[[1]])), 8, 8), betas)
|
||||||
|
psqt <- Map(`-`, psqt[1:6], Map(function(table) table[8:1, ], psqt[7:12]))
|
||||||
|
|
||||||
|
eval.psqt("startpos", psqt)
|
||||||
|
|
||||||
|
*/
|
|
@ -15,8 +15,12 @@ Rcpp::IntegerVector fen2int(const std::vector<Board>& boards) {
|
||||||
auto dims = Rcpp::IntegerVector({ 8, 8, 12, (int)boards.size() });
|
auto dims = Rcpp::IntegerVector({ 8, 8, 12, (int)boards.size() });
|
||||||
bitboards.attr("dim") = dims;
|
bitboards.attr("dim") = dims;
|
||||||
bitboards.attr("dimnames") = Rcpp::List::create(
|
bitboards.attr("dimnames") = Rcpp::List::create(
|
||||||
Rcpp::Named("rank") = Rcpp::CharacterVector::create("8", "7", "6", "5", "4", "3", "2", "1"),
|
Rcpp::Named("rank") = Rcpp::CharacterVector::create(
|
||||||
Rcpp::Named("file") = Rcpp::CharacterVector::create("a", "b", "c", "d", "e", "f", "g", "h"),
|
"8", "7", "6", "5", "4", "3", "2", "1"
|
||||||
|
),
|
||||||
|
Rcpp::Named("file") = Rcpp::CharacterVector::create(
|
||||||
|
"a", "b", "c", "d", "e", "f", "g", "h"
|
||||||
|
),
|
||||||
Rcpp::Named("piece") = Rcpp::CharacterVector::create(
|
Rcpp::Named("piece") = Rcpp::CharacterVector::create(
|
||||||
"P", "N", "B", "R", "Q", "K", // White Pieces (Upper Case)
|
"P", "N", "B", "R", "Q", "K", // White Pieces (Upper Case)
|
||||||
"p", "n", "b", "r", "q", "k" // Black Pieces (Lower Case)
|
"p", "n", "b", "r", "q", "k" // Black Pieces (Lower Case)
|
||||||
|
@ -42,6 +46,8 @@ Rcpp::IntegerVector fen2int(const std::vector<Board>& boards) {
|
||||||
int index = bitScanLS(bb);
|
int index = bitScanLS(bb);
|
||||||
// Transpose to align with printing as a Chess Board
|
// Transpose to align with printing as a Chess Board
|
||||||
index = ((index & 7) << 3) | ((index & 56) >> 3);
|
index = ((index & 7) << 3) | ((index & 56) >> 3);
|
||||||
|
// Flip black to move positions to whites point of view
|
||||||
|
index ^= pos.isWhiteTurn() ? 0 : 7;
|
||||||
bitboards[768 * i + 64 * slice + index] = 1;
|
bitboards[768 * i + 64 * slice + index] = 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,71 +0,0 @@
|
||||||
// #include <Rcpp.h>
|
|
||||||
|
|
||||||
// #include "SchachHoernchen/utils.h"
|
|
||||||
// #include "SchachHoernchen/Board.h"
|
|
||||||
// #include "SchachHoernchen/uci.h"
|
|
||||||
|
|
||||||
// // [[Rcpp::export(name = "print.board", rng = false)]]
|
|
||||||
// void print_board(
|
|
||||||
// const Board& board,
|
|
||||||
// const bool check = true,
|
|
||||||
// const bool attacked = false,
|
|
||||||
// const bool pinned = false,
|
|
||||||
// const bool checkers = false
|
|
||||||
// ) {
|
|
||||||
// using Rcpp::Rcout;
|
|
||||||
|
|
||||||
// // Extract some properties
|
|
||||||
// piece color = board.sideToMove();
|
|
||||||
// piece enemy = static_cast<piece>(!color);
|
|
||||||
// u64 empty = ~(board.bb(piece::white) | board.bb(piece::black));
|
|
||||||
// u64 cKing = board.bb(king) & board.bb(color);
|
|
||||||
// // Construct highlight mask
|
|
||||||
// u64 attackMask = board.attacks(enemy, empty);
|
|
||||||
// u64 mask = 0;
|
|
||||||
// mask |= check ? attackMask & board.bb(color) & board.bb(king) : 0;
|
|
||||||
// mask |= attacked ? attackMask & ~board.bb(color) : 0;
|
|
||||||
// mask |= pinned ? board.pinned(enemy, cKing, empty) : 0;
|
|
||||||
// mask |= checkers ? board.checkers(enemy, cKing, empty) : 0;
|
|
||||||
|
|
||||||
// // print the board to console
|
|
||||||
// Rcout << "FEN: " << board.fen() << '\n';
|
|
||||||
// for (Index line = 0; line < 17; line++) {
|
|
||||||
// if (line % 2) {
|
|
||||||
// Index rankIndex = line / 2;
|
|
||||||
// Rcout << static_cast<char>('8' - rankIndex);
|
|
||||||
// for (Index fileIndex = 0; fileIndex < 8; fileIndex++) {
|
|
||||||
// Index squareIndex = 8 * rankIndex + fileIndex;
|
|
||||||
// if (bitMask<Square>(squareIndex) & mask) {
|
|
||||||
// if (board.piece(squareIndex)) {
|
|
||||||
// if (board.color(squareIndex) == black) {
|
|
||||||
// // bold + italic + underline + black (blue)
|
|
||||||
// Rcout << " | \033[1;3;4;94m";
|
|
||||||
// } else {
|
|
||||||
// // bold + italic + underline (+ white)
|
|
||||||
// Rcout << " | \033[1;3;4m";
|
|
||||||
// }
|
|
||||||
// Rcout << UCI::formatPiece(board.piece(squareIndex))
|
|
||||||
// << "\033[0m";
|
|
||||||
// } else {
|
|
||||||
// Rcout << " | .";
|
|
||||||
// }
|
|
||||||
// } else if (board.color(squareIndex) == black) {
|
|
||||||
// Rcout << " | \033[1m\033[94m" // bold + blue (black)
|
|
||||||
// << UCI::formatPiece(board.piece(squareIndex))
|
|
||||||
// << "\033[0m";
|
|
||||||
// } else if (board.color(squareIndex) == white) {
|
|
||||||
// Rcout << " | \033[1m\033[97m" // bold + white
|
|
||||||
// << UCI::formatPiece(board.piece(squareIndex))
|
|
||||||
// << "\033[0m";
|
|
||||||
// } else {
|
|
||||||
// Rcout << " | ";
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// Rcout << " |";
|
|
||||||
// } else {
|
|
||||||
// Rcout << " +---+---+---+---+---+---+---+---+";
|
|
||||||
// }
|
|
||||||
// Rcout << "\033[0K\n"; // clear rest of line (remove potential leftovers)
|
|
||||||
// }
|
|
||||||
// Rcout << " a b c d e f g h" << std::endl;
|
|
||||||
// }
|
|
|
@ -1,8 +1,6 @@
|
||||||
options(keep.source = TRUE, keep.source.pkgs = TRUE)
|
|
||||||
|
|
||||||
library(tensorPredictors)
|
library(tensorPredictors)
|
||||||
|
|
||||||
# Load as 3D predictors `X` and flat response `y`
|
# Load as 3D predictors `X` and flat response `y` and `F = y` with per person dim. 1 x 1
|
||||||
c(X, F, y) %<-% local({
|
c(X, F, y) %<-% local({
|
||||||
# Load from file
|
# Load from file
|
||||||
ds <- readRDS("eeg_data.rds")
|
ds <- readRDS("eeg_data.rds")
|
||||||
|
@ -30,42 +28,7 @@ c(X, F, y) %<-% local({
|
||||||
# fit a tensor normal model to the data sample axis 1 indexes persons)
|
# fit a tensor normal model to the data sample axis 1 indexes persons)
|
||||||
fit.gmlm <- gmlm_tensor_normal(X, F, sample.axis = 1L)
|
fit.gmlm <- gmlm_tensor_normal(X, F, sample.axis = 1L)
|
||||||
|
|
||||||
# Performa a LOO prediction
|
# plot the fitted mode wise reductions (for time and sensor axis)
|
||||||
y.hat <- sapply(seq_len(dim(X)[1L]), function(i) {
|
|
||||||
# Fit with i'th observation removes
|
|
||||||
fit <- gmlm_tensor_normal(X[-i, , ], F[-i, , , drop = FALSE], sample.axis = 1L)
|
|
||||||
|
|
||||||
# Reduce the entire data set
|
|
||||||
r <- as.vector(mlm(X, fit$betas, modes = 2:3, transpose = TRUE))
|
|
||||||
# Fit a logit model on reduced data with i'th observation removed
|
|
||||||
logit <- glm(y ~ r, family = binomial(link = "logit"),
|
|
||||||
data = data.frame(y = y[-i], r = r[-i])
|
|
||||||
)
|
|
||||||
# predict i'th response given i'th reduced observation
|
|
||||||
y.hat <- predict(logit, newdata = data.frame(r = r[i]), type = "response")
|
|
||||||
# report progress
|
|
||||||
cat(sprintf("%3d/%d\n", i, dim(X)[1L]))
|
|
||||||
|
|
||||||
y.hat
|
|
||||||
})
|
|
||||||
|
|
||||||
### Classification performance measures
|
|
||||||
# acc: Accuracy. P(Yhat = Y). Estimated as: (TP+TN)/(P+N).
|
|
||||||
(acc <- mean(round(y.hat) == y)) # 0.7868852
|
|
||||||
# err: Error rate. P(Yhat != Y). Estimated as: (FP+FN)/(P+N).
|
|
||||||
(err <- mean(round(y.hat) != y)) # 0.2131148
|
|
||||||
# fpr: False positive rate. P(Yhat = + | Y = -). aliases: Fallout.
|
|
||||||
(fpr <- mean((round(y.hat) == 1)[y == 0])) # 0.4
|
|
||||||
# tpr: True positive rate. P(Yhat = + | Y = +). aliases: Sensitivity, Recall.
|
|
||||||
(tpr <- mean((round(y.hat) == 1)[y == 1])) # 0.8961039
|
|
||||||
# fnr: False negative rate. P(Yhat = - | Y = +). aliases: Miss.
|
|
||||||
(fnr <- mean((round(y.hat) == 0)[y == 1])) # 0.1038961
|
|
||||||
# tnr: True negative rate. P(Yhat = - | Y = -).
|
|
||||||
(tnr <- mean((round(y.hat) == 0)[y == 0])) # 0.6
|
|
||||||
# auc: Area Under the Curve
|
|
||||||
(auc <- pROC::roc(y, y.hat, quiet = TRUE)$auc) # 0.838961
|
|
||||||
|
|
||||||
|
|
||||||
with(fit.gmlm, {
|
with(fit.gmlm, {
|
||||||
par.reset <- par(mfrow = c(2, 1))
|
par.reset <- par(mfrow = c(2, 1))
|
||||||
plot(seq(0, 1, len = 256), betas[[1]], main = "Time", xlab = "Time [s]", ylab = expression(beta[1]))
|
plot(seq(0, 1, len = 256), betas[[1]], main = "Time", xlab = "Time [s]", ylab = expression(beta[1]))
|
||||||
|
@ -74,20 +37,83 @@ with(fit.gmlm, {
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
# dimX <- c(4, 3, 5)
|
#' (2D)^2 PCA preprocessing
|
||||||
# Omegas <- Map(function(p) {
|
#'
|
||||||
# O <- matrix(rnorm(p^2), p)
|
#' @param tpc Number of "t"ime "p"rincipal "c"omponents.
|
||||||
# O %*% t(O)
|
#' @param ppc Number of "p"redictor "p"rincipal "c"omponents.
|
||||||
# }, dimX)
|
preprocess <- function(X, tpc, ppc) {
|
||||||
|
# Mode covariances (for predictor and time point modes)
|
||||||
|
c(Sigma_t, Sigma_p) %<-% mcov(X, sample.axis = 1L)
|
||||||
|
|
||||||
# # Numerically more stable version of `sum(log(mapply(det, Omegas)) / dimX)`
|
# "predictor" (sensor) and time point principal components
|
||||||
# # which is itself equivalent to `log(det(Omega)) / prod(nrow(Omega))` where
|
V_t <- svd(Sigma_t, tpc, 0L)$u
|
||||||
# # `Omega <- Reduce(kronecker, rev(Omegas))`.
|
V_p <- svd(Sigma_p, ppc, 0L)$u
|
||||||
|
|
||||||
# Omega <- Reduce(kronecker, rev(Omegas))
|
# reduce with mode wise PCs
|
||||||
# log(det(Omega)) / prod(nrow(Omega))
|
mlm(X, list(V_t, V_p), modes = 2:3, transposed = TRUE)
|
||||||
|
}
|
||||||
|
|
||||||
# (det.Omega <- sum(log(mapply(det, Omegas)) / dimX))
|
|
||||||
# sum(mapply(function(Omega) {
|
#' Leave-one-out prediction
|
||||||
# sum(log(eigen(Omega, TRUE, TRUE)$values))
|
#'
|
||||||
# }, Omegas) / dimX)
|
#' @param X 3D EEG data (preprocessed or not)
|
||||||
|
#' @param F binary responce `y` as a 3D tensor, every obs. is a 1 x 1 matrix
|
||||||
|
loo.predict <- function(X, F) {
|
||||||
|
sapply(seq_len(dim(X)[1L]), function(i) {
|
||||||
|
# Fit with i'th observation removes
|
||||||
|
fit <- gmlm_tensor_normal(X[-i, , ], F[-i, , , drop = FALSE], sample.axis = 1L)
|
||||||
|
|
||||||
|
# Reduce the entire data set
|
||||||
|
r <- as.vector(mlm(X, fit$betas, modes = 2:3, transpose = TRUE))
|
||||||
|
# Fit a logit model on reduced data with i'th observation removed
|
||||||
|
logit <- glm(y ~ r, family = binomial(link = "logit"),
|
||||||
|
data = data.frame(y = y[-i], r = r[-i])
|
||||||
|
)
|
||||||
|
# predict i'th response given i'th reduced observation
|
||||||
|
y.hat <- predict(logit, newdata = data.frame(r = r[i]), type = "response")
|
||||||
|
# report progress
|
||||||
|
cat(sprintf("dim: (%d, %d) - %3d/%d\n", dim(X)[2L], dim(X)[3L], i, dim(X)[1L]))
|
||||||
|
|
||||||
|
y.hat
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
### Classification performance measures
|
||||||
|
# acc: Accuracy. P(Yhat = Y). Estimated as: (TP+TN)/(P+N).
|
||||||
|
acc <- function(y.true, y.pred) mean(round(y.pred) == y.true)
|
||||||
|
# err: Error rate. P(Yhat != Y). Estimated as: (FP+FN)/(P+N).
|
||||||
|
err <- function(y.true, y.pred) mean(round(y.pred) != y.true)
|
||||||
|
# fpr: False positive rate. P(Yhat = + | Y = -). aliases: Fallout.
|
||||||
|
fpr <- function(y.true, y.pred) mean((round(y.pred) == 1)[y.true == 0])
|
||||||
|
# tpr: True positive rate. P(Yhat = + | Y = +). aliases: Sensitivity, Recall.
|
||||||
|
tpr <- function(y.true, y.pred) mean((round(y.pred) == 1)[y.true == 1])
|
||||||
|
# fnr: False negative rate. P(Yhat = - | Y = +). aliases: Miss.
|
||||||
|
fnr <- function(y.true, y.pred) mean((round(y.pred) == 0)[y.true == 1])
|
||||||
|
# tnr: True negative rate. P(Yhat = - | Y = -).
|
||||||
|
tnr <- function(y.true, y.pred) mean((round(y.pred) == 0)[y.true == 0])
|
||||||
|
# auc: Area Under the Curve
|
||||||
|
auc <- function(y.true, y.pred) as.numeric(pROC::roc(y.true, y.pred, quiet = TRUE)$auc)
|
||||||
|
auc.sd <- function(y.true, y.pred) sqrt(pROC::var(pROC::roc(y.true, y.pred, quiet = TRUE)))
|
||||||
|
|
||||||
|
|
||||||
|
# perform preprocessed (reduced) and raw (not reduced) leave-one-out prediction
|
||||||
|
y.hat.3.4 <- loo.predict(preprocess(X, 3, 4), F)
|
||||||
|
y.hat.15.15 <- loo.predict(preprocess(X, 15, 15), F)
|
||||||
|
y.hat.20.30 <- loo.predict(preprocess(X, 20, 30), F)
|
||||||
|
y.hat <- loo.predict(X, F)
|
||||||
|
|
||||||
|
# classification performance measures table by leave-one-out cross-validation
|
||||||
|
(loo.cv <- apply(cbind(y.hat.3.4, y.hat.15.15, y.hat.20.30, y.hat), 2, function(y.pred) {
|
||||||
|
sapply(c("acc", "err", "fpr", "tpr", "fnr", "tnr", "auc", "auc.sd"),
|
||||||
|
function(FUN) { match.fun(FUN)(y, y.pred) })
|
||||||
|
}))
|
||||||
|
#> y.hat.3.4 y.hat.15.15 y.hat.20.30 y.hat
|
||||||
|
#> acc 0.79508197 0.78688525 0.78688525 0.78688525
|
||||||
|
#> err 0.20491803 0.21311475 0.21311475 0.21311475
|
||||||
|
#> fpr 0.35555556 0.40000000 0.40000000 0.40000000
|
||||||
|
#> tpr 0.88311688 0.89610390 0.89610390 0.89610390
|
||||||
|
#> fnr 0.11688312 0.10389610 0.10389610 0.10389610
|
||||||
|
#> tnr 0.64444444 0.60000000 0.60000000 0.60000000
|
||||||
|
#> auc 0.85108225 0.83838384 0.83924964 0.83896104
|
||||||
|
#> auc.sd 0.03584791 0.03760531 0.03751307 0.03754553
|
||||||
|
|
Loading…
Reference in New Issue