diff --git a/dataAnalysis/chess/Rchess/R/RcppExports.R b/dataAnalysis/chess/Rchess/R/RcppExports.R
index 153c053..55489b3 100644
--- a/dataAnalysis/chess/Rchess/R/RcppExports.R
+++ b/dataAnalysis/chess/Rchess/R/RcppExports.R
@@ -46,8 +46,8 @@ isInsufficient <- function(positions) {
#' `gmlm_chess()` as data generator to provide random draws from a FEN data set
#' with scores filtered to be in in the range `score_min` to `score_max`.
#'
-data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, min_ply_count = 10L, white_only = TRUE) {
- .Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, min_ply_count, white_only)
+data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, draw = TRUE, min_ply_count = 10L, white_only = TRUE) {
+ .Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, draw, min_ply_count, white_only)
}
#' Human Crafted Evaluation
diff --git a/dataAnalysis/chess/Rchess/src/RcppExports.cpp b/dataAnalysis/chess/Rchess/src/RcppExports.cpp
index 8f1251b..43040e9 100644
--- a/dataAnalysis/chess/Rchess/src/RcppExports.cpp
+++ b/dataAnalysis/chess/Rchess/src/RcppExports.cpp
@@ -72,8 +72,8 @@ BEGIN_RCPP
END_RCPP
}
// data_gen
-Rcpp::CharacterVector data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const int min_ply_count, const bool white_only);
-RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP min_ply_countSEXP, SEXP white_onlySEXP) {
+Rcpp::List data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const bool draw, const int min_ply_count, const bool white_only);
+RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP drawSEXP, SEXP min_ply_countSEXP, SEXP white_onlySEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -82,9 +82,10 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< const float >::type score_min(score_minSEXP);
Rcpp::traits::input_parameter< const float >::type score_max(score_maxSEXP);
Rcpp::traits::input_parameter< const bool >::type quiet(quietSEXP);
+ Rcpp::traits::input_parameter< const bool >::type draw(drawSEXP);
Rcpp::traits::input_parameter< const int >::type min_ply_count(min_ply_countSEXP);
Rcpp::traits::input_parameter< const bool >::type white_only(white_onlySEXP);
- rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, min_ply_count, white_only));
+ rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, draw, min_ply_count, white_only));
return rcpp_result_gen;
END_RCPP
}
@@ -255,7 +256,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_Rchess_isQuiet", (DL_FUNC) &_Rchess_isQuiet, 1},
{"_Rchess_isTerminal", (DL_FUNC) &_Rchess_isTerminal, 1},
{"_Rchess_isInsufficient", (DL_FUNC) &_Rchess_isInsufficient, 1},
- {"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 7},
+ {"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 8},
{"_Rchess_eval_psqt", (DL_FUNC) &_Rchess_eval_psqt, 5},
{"_Rchess_fen2int", (DL_FUNC) &_Rchess_fen2int, 1},
{"_Rchess_read_cyclic", (DL_FUNC) &_Rchess_read_cyclic, 5},
diff --git a/dataAnalysis/chess/Rchess/src/data_gen.cpp b/dataAnalysis/chess/Rchess/src/data_gen.cpp
index a8dd3c2..441d481 100644
--- a/dataAnalysis/chess/Rchess/src/data_gen.cpp
+++ b/dataAnalysis/chess/Rchess/src/data_gen.cpp
@@ -14,12 +14,13 @@
//' with scores filtered to be in in the range `score_min` to `score_max`.
//'
// [[Rcpp::export(name = "data.gen", rng = true)]]
-Rcpp::CharacterVector data_gen(
+Rcpp::List data_gen(
const std::string& file,
const int sample_size,
const float score_min = -5.0,
const float score_max = +5.0,
const bool quiet = false,
+ const bool draw = true,
const int min_ply_count = 10,
const bool white_only = true
) {
@@ -51,7 +52,7 @@ Rcpp::CharacterVector data_gen(
}
// Allocate output sample
- Rcpp::CharacterVector _sample(sample_size);
+ Rcpp::CharacterVector _fens(sample_size);
Rcpp::NumericVector _scores(sample_size);
// Read and filter lines from FEN data base file
@@ -106,15 +107,16 @@ Rcpp::CharacterVector data_gen(
// Reject / Filter samples
if (((int)pos.plyCount() < min_ply_count) // early positions
|| (white_only && (pos.sideToMove() == piece::black)) // white to move positions
- || (score < score_min || score_max <= score) // scores out of slice
- || (quiet && !pos.isQuiet())) // quiet positions
+ || (score < score_min || score_max < score) // scores out of slice
+ || (quiet && !pos.isQuiet()) // quiet positions
+ || (!draw && score == 0.0)) // drawn positions
{
reject_count++;
continue;
}
// Everythings succeeded and ge got an appropriate sample in requested range
- _sample[sample_count] = fen;
+ _fens[sample_count] = fen;
_scores[sample_count] = score;
++sample_count;
@@ -130,8 +132,8 @@ Rcpp::CharacterVector data_gen(
}
}
- // Set scores as attribute to position sample
- _sample.attr("scores") = _scores;
-
- return _sample;
+ return Rcpp::List::create(
+ Rcpp::Named("fens") = _fens,
+ Rcpp::Named("scores") = _scores
+ );
}
diff --git a/dataAnalysis/chess/chess.R b/dataAnalysis/chess/chess.R
index 11e7e57..4c798ab 100644
--- a/dataAnalysis/chess/chess.R
+++ b/dataAnalysis/chess/chess.R
@@ -1,6 +1,5 @@
library(tensorPredictors)
library(Rchess)
-library(mgcv) # for `gam()` (Generalized Additive Model)
source("./gmlm_chess.R")
@@ -16,9 +15,12 @@ data_set <- "lichess_db_standard_rated_2023-11.fen"
# Function to draw samples `X` form the chess position `data_set` conditioned on
# `Y` (position scores) to be in the interval `score_min` to `score_max`.
data_gen <- function(batch_size, score_min, score_max) {
- Rchess::fen2int(Rchess::data.gen(data_set, batch_size, score_min, score_max, quiet = TRUE))
+ data <- Rchess::data.gen(data_set, batch_size, score_min, score_max, quiet = TRUE)
+ pos <- Rchess::fen2int(data$fens)
+ structure(pos, scores = data$scores)
}
+
# Invoke specialized GMLM optimization routine for chess data
fit.gmlm <- gmlm_chess(data_gen)
@@ -26,6 +28,7 @@ fit.gmlm <- gmlm_chess(data_gen)
################################################################################
### Reduction Interpretation and Validation ###
################################################################################
+library(mgcv) # for `gam()` (Generalized Additive Model)
# load last save point (includes reduction as `betas`)
save_point <- sort(list.files(
@@ -43,6 +46,10 @@ sample_size <- 100000
fens <- Rchess::data.gen(data_set, sample_size, -20, 20, quiet = TRUE)
# extract stockfish (non-static) position evaluation
y <- attr(fens, "scores")
+# remove poitions with exact draw evalualtion
+draws <- which(y == 0.0)
+y <- y[-draws]
+fens <- fens[-draws]
# Convert position into "One-Hot Encoded" / "Bit Board" tensor
X <- Rchess::fen2int(fens)
@@ -54,7 +61,7 @@ reducedX <- Reduce(rbind, Map(function(piece) {
mlm(X - as.vector(rowMeans(X, dims = 2)), betas[[piece]], transposed = TRUE)
}, 1:12))
# Convert memory layout to contain vectorized observations in rows
-reducedX <- t(`dim<-`(reducedX, c(48, sample_size)))
+reducedX <- t(`dim<-`(reducedX, c(48, length(y))))
# set names for coefficient extraction from linear fit
colnames(reducedX) <- as.vector(outer(
unlist(strsplit("PNBRQKpnbrqk", "")), c(1, "yl", "yu", "y.2"), paste, sep = "."
@@ -86,15 +93,15 @@ psqt[["P"]][c(1, 8), ] <- 0
### Validation by GAM fitted on reduced data
formula <- as.formula(paste("y ~ ", paste("s(", colnames(reducedX), ")", collapse = "+")))
-fit.gam <- mgcv::gam(formula, data = data.frame(y = y, reducedX), subset = y != 0.0)
+fit.gam <- mgcv::gam(formula, data = data.frame(y = y, reducedX))
summary(fit.gam)
# compair estimates with mean as baseline and static human crafted evaluation (HCE)
-rmse.base <- sqrt(mean((mean(y) - y)^2))
+(rmse.base <- sqrt(mean((mean(y) - y)^2)))
y.hce <- Rchess::HCE(fens)
-rmse.hce <- sqrt(mean((y.hce - y)^2))
+(rmse.hce <- sqrt(mean((y.hce - y)^2)))
y.hat <- predict(fit.gam, newdata = data.frame(reducedX))
-rmse.hat <- sqrt(mean((y.hat - y)^2))
+(rmse.hat <- sqrt(mean((y.hat - y)^2)))
# Also extract R^2 (eval by hand or get from models)
(r.sq.lm <- summary(fit)$r.squared)
diff --git a/dataAnalysis/chess/pgn2fen.cpp b/dataAnalysis/chess/pgn2fen.cpp
index 705f4e5..12eb559 100644
--- a/dataAnalysis/chess/pgn2fen.cpp
+++ b/dataAnalysis/chess/pgn2fen.cpp
@@ -10,24 +10,55 @@
#include "search.h"
#include "uci.h"
-static const std::string usage{"usage: pgn2fen [--scored] []"};
+static const std::string usage{"usage: pgn2fen [--scored] [--rating ] [--ply ] []"};
// Convert PGN (Portable Game Notation) input stream to single FENs
// streamed to stdout
-void pgn2fen(std::istream& input, const bool only_scored) {
-
+void pgn2fen(
+ std::istream& input,
+ const bool scored,
+ const unsigned long rating,
+ const unsigned long ply
+) {
// Instantiate Boards, the start of every game as well as the current state
// of the Board while processing a PGN game
Board startpos, pos;
+ // Parse white and black ELO ratings
+ unsigned long whiteElo = 0;
+ unsigned long blackElo = 0;
+
// Read input line by line
std::string line;
while (std::getline(input, line)) {
- // Skip empty and metadata lines (every PGN game starts with ".")
+ // read rating metadata lines
+ if (rating != static_cast(-1)) {
+ // [WhiteElo "1111"]
+ // [BlackElo "999"]
+ try {
+ if (line.rfind("[WhiteElo \"", 0) != std::string::npos) {
+ whiteElo = std::stoul(line.substr(11));
+ } else if (line.rfind("[BlackElo \"", 0) != std::string::npos) {
+ blackElo = std::stoul(line.substr(11));
+ }
+ } catch (...) {
+ std::cerr << "ERROR: Parsing player rating metadata '" << line
+ << "' failed." << std::endl;
+ break;
+ }
+ }
+
+ // Skip empty and further metadata lines (every PGN game starts with ".")
if (line.empty() || line.front() == '[') {
continue;
}
+ // In case of rating requested, only parse game when rating is detected
+ if (rating != static_cast(-1)
+ && (whiteElo < rating || blackElo < rating)) {
+ continue;
+ }
+
// Reset position to the start position, every game starts here!
pos = startpos;
@@ -36,7 +67,7 @@ void pgn2fen(std::istream& input, const bool only_scored) {
std::string count, san, token, eval;
while (game >> count >> san >> token) {
// Consume/Parse PGN comments
- if (only_scored) {
+ if (scored) {
// consume the comment and search for an evaluation
bool has_score = false;
while (game >> token) {
@@ -65,21 +96,40 @@ void pgn2fen(std::istream& input, const bool only_scored) {
bool parseError = false;
Move move = UCI::parseSAN(san, pos, parseError);
if (parseError) {
- std::cerr << "[ Error ] Parsing '" << san << "' at position '"
+ std::cerr << "ERROR: Parsing '" << san << "' at position '"
<< pos.fen() << "' failed." << std::endl;
+ break;
}
move = pos.isLegal(move); // validate legality and extend move info
if (move) {
pos.make(move);
} else {
- std::cerr << "[ Error ] Encountered illegal move '" << san
+ std::cerr << "ERROR: Encountered illegal move '" << san
<< " (" << move
<< ") ' at position '" << pos.fen() << "'." << std::endl;
break;
}
+ // Skip positions with too small ply count
+ if (pos.plyCount() < ply) {
+ continue;
+ }
+
// Write positions
- if (only_scored) {
+ if (scored && rating != static_cast(-1)) {
+ // Ingore "check mate in" scores (not relevant for eval training)
+ // Do this after "make move" in situations where the check mate
+ // was overlooked, leading to new positions
+ if (eval.length() && eval[0] == '#') {
+ continue;
+ }
+ // Otherwise, classic eval score to be parsed in centipawns
+ std::cout << pos.fen() << "; " << eval << "; "
+ << whiteElo << "; " << blackElo << '\n';
+ } else if (rating != static_cast(-1)) {
+ // Otherwise, classic eval score to be parsed in centipawns
+ std::cout << pos.fen() << "; " << whiteElo << "; " << blackElo << '\n';
+ } else if (scored) {
// Ingore "check mate in" scores (not relevant for eval training)
// Do this after "make move" in situations where the check mate
// was overlooked, leading to new positions
@@ -92,52 +142,75 @@ void pgn2fen(std::istream& input, const bool only_scored) {
// Write only the position FEN
std::cout << pos.fen() << '\n';
}
+
}
+
+ // Reset ELO after every game to ensure that games without an elo
+ // metadata tag don't get the wrong rating from a previous game
+ whiteElo = 0;
+ blackElo = 0;
}
}
int main(int argn, char* argv[]) {
// Setup control variables
- bool only_scored = false;
+ bool scored = false;
+ unsigned long rating = -1;
+ unsigned long ply = 0;
+ // unsigned min_rating = 0;
std::string file = "";
// Parse command arguments
- switch (argn) {
- case 1:
- break;
- case 2:
- if (std::string("--scored") == argv[1]) {
- only_scored = true;
+ for (int i = 1; i < argn; ++i) {
+ if (std::string("--scored") == argv[i]) {
+ scored = true;
+ } else if (std::string("--rating") == argv[i]) {
+ if (i + 1 < argn) {
+ try {
+ rating = std::stoul(argv[++i]);
+ } catch (...) {
+ std::cerr << "ERROR: illegal --rating argument " << argv[i] << std::endl;
+ std::cout << usage << std::endl;
+ return 1;
+ }
} else {
- file = argv[1];
- }
- break;
- case 3:
- if (std::string("--scored") != argv[1]) {
std::cout << usage << std::endl;
return 1;
}
- only_scored = true;
- file = argv[2];
- break;
- default:
+ } else if (std::string("--ply") == argv[i]) {
+ if (i + 1 < argn) {
+ try {
+ ply = std::stoul(argv[++i]);
+ } catch (...) {
+ std::cerr << "ERROR: illegal --ply argument " << argv[i] << std::endl;
+ std::cout << usage << std::endl;
+ return 1;
+ }
+ } else {
+ std::cout << usage << std::endl;
+ return 1;
+ }
+ } else if (file != "") {
+ file = argv[i];
+ } else {
std::cout << usage << std::endl;
return 1;
+ }
}
// Invoke converter ether with file input or stdin
if (file == "") {
- pgn2fen(std::cin, only_scored);
+ pgn2fen(std::cin, scored, rating, ply);
} else {
// Open input file
std::ifstream input(file);
if (!input) {
- std::cerr << "Error opening '" << file << "'" << std::endl;
+ std::cerr << "ERROR: opening '" << file << "' failed" << std::endl;
return 1;
}
- pgn2fen(input, only_scored);
+ pgn2fen(input, scored, rating, ply);
}
return 0;
diff --git a/dataAnalysis/chess/preprocessing.sh b/dataAnalysis/chess/preprocessing.sh
index afa8a82..3dc1e64 100755
--- a/dataAnalysis/chess/preprocessing.sh
+++ b/dataAnalysis/chess/preprocessing.sh
@@ -4,6 +4,10 @@
# in November 2023
data=lichess_db_standard_rated_2023-11
+# Minimum "ELO" rating of black and white players
+min_rating=2000
+min_ply=20
+
# Check if file exists and download iff not
if [ -f "${data}.fen" ]; then
echo "File '${data}.fen' already exists, assuming job already done."
@@ -21,11 +25,16 @@ else
# the PGN data base into a list of FEN strings while filtering only
# positions with evaluation. The `--scored` parameter specifies to extract
# a position evaluation from the PGN and ONLY write positions with scores.
- # That is, positions without a score are removed!
+ # That is, positions without a score are removed! Parameter `--rating`
+ # filters games where both players have at least the minimum rating, and
+ # `--ply` specifies to only consider positions from `$min_ply` onward.
if [ -f "${data}.pgn.zst" ]; then
- zstdcat ${data}.pgn.zst | ./pgn2fen --scored > ${data}.fen
+ zstdcat ${data}.pgn.zst \
+ | ./pgn2fen --scored --rating $min_rating --ply $min_ply \
+ > ${data}.fen
else
- wget -qO- https://database.lichess.org/standard/${data}.pgn.zst \
- | zstdcat | ./pgn2fen --scored > ${data}.fen
+ wget -qO- https://database.lichess.org/standard/${data}.pgn.zst | zstdcat \
+ | ./pgn2fen --scored --rating $min_rating --ply $min_ply \
+ > ${data}.fen
fi
fi
diff --git a/tensorPredictors/NAMESPACE b/tensorPredictors/NAMESPACE
index 47be8e1..313b03a 100644
--- a/tensorPredictors/NAMESPACE
+++ b/tensorPredictors/NAMESPACE
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand
+S3method(merge,matmul)
export("%<-%")
export("%x_1%")
export("%x_2%")