From 917c81f37fd3b86d4b05ca17ab6c8c71b6651de4 Mon Sep 17 00:00:00 2001 From: daniel Date: Fri, 7 Feb 2025 13:46:10 +0100 Subject: [PATCH] update: data download and fen processing --- dataAnalysis/chess/Rchess/R/RcppExports.R | 4 +- dataAnalysis/chess/Rchess/src/RcppExports.cpp | 9 +- dataAnalysis/chess/Rchess/src/data_gen.cpp | 20 +-- dataAnalysis/chess/chess.R | 21 ++- dataAnalysis/chess/pgn2fen.cpp | 127 ++++++++++++++---- dataAnalysis/chess/preprocessing.sh | 17 ++- tensorPredictors/NAMESPACE | 1 + 7 files changed, 146 insertions(+), 53 deletions(-) diff --git a/dataAnalysis/chess/Rchess/R/RcppExports.R b/dataAnalysis/chess/Rchess/R/RcppExports.R index 153c053..55489b3 100644 --- a/dataAnalysis/chess/Rchess/R/RcppExports.R +++ b/dataAnalysis/chess/Rchess/R/RcppExports.R @@ -46,8 +46,8 @@ isInsufficient <- function(positions) { #' `gmlm_chess()` as data generator to provide random draws from a FEN data set #' with scores filtered to be in in the range `score_min` to `score_max`. #' -data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, min_ply_count = 10L, white_only = TRUE) { - .Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, min_ply_count, white_only) +data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, draw = TRUE, min_ply_count = 10L, white_only = TRUE) { + .Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, draw, min_ply_count, white_only) } #' Human Crafted Evaluation diff --git a/dataAnalysis/chess/Rchess/src/RcppExports.cpp b/dataAnalysis/chess/Rchess/src/RcppExports.cpp index 8f1251b..43040e9 100644 --- a/dataAnalysis/chess/Rchess/src/RcppExports.cpp +++ b/dataAnalysis/chess/Rchess/src/RcppExports.cpp @@ -72,8 +72,8 @@ BEGIN_RCPP END_RCPP } // data_gen -Rcpp::CharacterVector data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const int min_ply_count, const bool white_only); -RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP min_ply_countSEXP, SEXP white_onlySEXP) { +Rcpp::List data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const bool draw, const int min_ply_count, const bool white_only); +RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP drawSEXP, SEXP min_ply_countSEXP, SEXP white_onlySEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -82,9 +82,10 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const float >::type score_min(score_minSEXP); Rcpp::traits::input_parameter< const float >::type score_max(score_maxSEXP); Rcpp::traits::input_parameter< const bool >::type quiet(quietSEXP); + Rcpp::traits::input_parameter< const bool >::type draw(drawSEXP); Rcpp::traits::input_parameter< const int >::type min_ply_count(min_ply_countSEXP); Rcpp::traits::input_parameter< const bool >::type white_only(white_onlySEXP); - rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, min_ply_count, white_only)); + rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, draw, min_ply_count, white_only)); return rcpp_result_gen; END_RCPP } @@ -255,7 +256,7 @@ static const R_CallMethodDef CallEntries[] = { {"_Rchess_isQuiet", (DL_FUNC) &_Rchess_isQuiet, 1}, {"_Rchess_isTerminal", (DL_FUNC) &_Rchess_isTerminal, 1}, {"_Rchess_isInsufficient", (DL_FUNC) &_Rchess_isInsufficient, 1}, - {"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 7}, + {"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 8}, {"_Rchess_eval_psqt", (DL_FUNC) &_Rchess_eval_psqt, 5}, {"_Rchess_fen2int", (DL_FUNC) &_Rchess_fen2int, 1}, {"_Rchess_read_cyclic", (DL_FUNC) &_Rchess_read_cyclic, 5}, diff --git a/dataAnalysis/chess/Rchess/src/data_gen.cpp b/dataAnalysis/chess/Rchess/src/data_gen.cpp index a8dd3c2..441d481 100644 --- a/dataAnalysis/chess/Rchess/src/data_gen.cpp +++ b/dataAnalysis/chess/Rchess/src/data_gen.cpp @@ -14,12 +14,13 @@ //' with scores filtered to be in in the range `score_min` to `score_max`. //' // [[Rcpp::export(name = "data.gen", rng = true)]] -Rcpp::CharacterVector data_gen( +Rcpp::List data_gen( const std::string& file, const int sample_size, const float score_min = -5.0, const float score_max = +5.0, const bool quiet = false, + const bool draw = true, const int min_ply_count = 10, const bool white_only = true ) { @@ -51,7 +52,7 @@ Rcpp::CharacterVector data_gen( } // Allocate output sample - Rcpp::CharacterVector _sample(sample_size); + Rcpp::CharacterVector _fens(sample_size); Rcpp::NumericVector _scores(sample_size); // Read and filter lines from FEN data base file @@ -106,15 +107,16 @@ Rcpp::CharacterVector data_gen( // Reject / Filter samples if (((int)pos.plyCount() < min_ply_count) // early positions || (white_only && (pos.sideToMove() == piece::black)) // white to move positions - || (score < score_min || score_max <= score) // scores out of slice - || (quiet && !pos.isQuiet())) // quiet positions + || (score < score_min || score_max < score) // scores out of slice + || (quiet && !pos.isQuiet()) // quiet positions + || (!draw && score == 0.0)) // drawn positions { reject_count++; continue; } // Everythings succeeded and ge got an appropriate sample in requested range - _sample[sample_count] = fen; + _fens[sample_count] = fen; _scores[sample_count] = score; ++sample_count; @@ -130,8 +132,8 @@ Rcpp::CharacterVector data_gen( } } - // Set scores as attribute to position sample - _sample.attr("scores") = _scores; - - return _sample; + return Rcpp::List::create( + Rcpp::Named("fens") = _fens, + Rcpp::Named("scores") = _scores + ); } diff --git a/dataAnalysis/chess/chess.R b/dataAnalysis/chess/chess.R index 11e7e57..4c798ab 100644 --- a/dataAnalysis/chess/chess.R +++ b/dataAnalysis/chess/chess.R @@ -1,6 +1,5 @@ library(tensorPredictors) library(Rchess) -library(mgcv) # for `gam()` (Generalized Additive Model) source("./gmlm_chess.R") @@ -16,9 +15,12 @@ data_set <- "lichess_db_standard_rated_2023-11.fen" # Function to draw samples `X` form the chess position `data_set` conditioned on # `Y` (position scores) to be in the interval `score_min` to `score_max`. data_gen <- function(batch_size, score_min, score_max) { - Rchess::fen2int(Rchess::data.gen(data_set, batch_size, score_min, score_max, quiet = TRUE)) + data <- Rchess::data.gen(data_set, batch_size, score_min, score_max, quiet = TRUE) + pos <- Rchess::fen2int(data$fens) + structure(pos, scores = data$scores) } + # Invoke specialized GMLM optimization routine for chess data fit.gmlm <- gmlm_chess(data_gen) @@ -26,6 +28,7 @@ fit.gmlm <- gmlm_chess(data_gen) ################################################################################ ### Reduction Interpretation and Validation ### ################################################################################ +library(mgcv) # for `gam()` (Generalized Additive Model) # load last save point (includes reduction as `betas`) save_point <- sort(list.files( @@ -43,6 +46,10 @@ sample_size <- 100000 fens <- Rchess::data.gen(data_set, sample_size, -20, 20, quiet = TRUE) # extract stockfish (non-static) position evaluation y <- attr(fens, "scores") +# remove poitions with exact draw evalualtion +draws <- which(y == 0.0) +y <- y[-draws] +fens <- fens[-draws] # Convert position into "One-Hot Encoded" / "Bit Board" tensor X <- Rchess::fen2int(fens) @@ -54,7 +61,7 @@ reducedX <- Reduce(rbind, Map(function(piece) { mlm(X - as.vector(rowMeans(X, dims = 2)), betas[[piece]], transposed = TRUE) }, 1:12)) # Convert memory layout to contain vectorized observations in rows -reducedX <- t(`dim<-`(reducedX, c(48, sample_size))) +reducedX <- t(`dim<-`(reducedX, c(48, length(y)))) # set names for coefficient extraction from linear fit colnames(reducedX) <- as.vector(outer( unlist(strsplit("PNBRQKpnbrqk", "")), c(1, "yl", "yu", "y.2"), paste, sep = "." @@ -86,15 +93,15 @@ psqt[["P"]][c(1, 8), ] <- 0 ### Validation by GAM fitted on reduced data formula <- as.formula(paste("y ~ ", paste("s(", colnames(reducedX), ")", collapse = "+"))) -fit.gam <- mgcv::gam(formula, data = data.frame(y = y, reducedX), subset = y != 0.0) +fit.gam <- mgcv::gam(formula, data = data.frame(y = y, reducedX)) summary(fit.gam) # compair estimates with mean as baseline and static human crafted evaluation (HCE) -rmse.base <- sqrt(mean((mean(y) - y)^2)) +(rmse.base <- sqrt(mean((mean(y) - y)^2))) y.hce <- Rchess::HCE(fens) -rmse.hce <- sqrt(mean((y.hce - y)^2)) +(rmse.hce <- sqrt(mean((y.hce - y)^2))) y.hat <- predict(fit.gam, newdata = data.frame(reducedX)) -rmse.hat <- sqrt(mean((y.hat - y)^2)) +(rmse.hat <- sqrt(mean((y.hat - y)^2))) # Also extract R^2 (eval by hand or get from models) (r.sq.lm <- summary(fit)$r.squared) diff --git a/dataAnalysis/chess/pgn2fen.cpp b/dataAnalysis/chess/pgn2fen.cpp index 705f4e5..12eb559 100644 --- a/dataAnalysis/chess/pgn2fen.cpp +++ b/dataAnalysis/chess/pgn2fen.cpp @@ -10,24 +10,55 @@ #include "search.h" #include "uci.h" -static const std::string usage{"usage: pgn2fen [--scored] []"}; +static const std::string usage{"usage: pgn2fen [--scored] [--rating ] [--ply ] []"}; // Convert PGN (Portable Game Notation) input stream to single FENs // streamed to stdout -void pgn2fen(std::istream& input, const bool only_scored) { - +void pgn2fen( + std::istream& input, + const bool scored, + const unsigned long rating, + const unsigned long ply +) { // Instantiate Boards, the start of every game as well as the current state // of the Board while processing a PGN game Board startpos, pos; + // Parse white and black ELO ratings + unsigned long whiteElo = 0; + unsigned long blackElo = 0; + // Read input line by line std::string line; while (std::getline(input, line)) { - // Skip empty and metadata lines (every PGN game starts with ".") + // read rating metadata lines + if (rating != static_cast(-1)) { + // [WhiteElo "1111"] + // [BlackElo "999"] + try { + if (line.rfind("[WhiteElo \"", 0) != std::string::npos) { + whiteElo = std::stoul(line.substr(11)); + } else if (line.rfind("[BlackElo \"", 0) != std::string::npos) { + blackElo = std::stoul(line.substr(11)); + } + } catch (...) { + std::cerr << "ERROR: Parsing player rating metadata '" << line + << "' failed." << std::endl; + break; + } + } + + // Skip empty and further metadata lines (every PGN game starts with ".") if (line.empty() || line.front() == '[') { continue; } + // In case of rating requested, only parse game when rating is detected + if (rating != static_cast(-1) + && (whiteElo < rating || blackElo < rating)) { + continue; + } + // Reset position to the start position, every game starts here! pos = startpos; @@ -36,7 +67,7 @@ void pgn2fen(std::istream& input, const bool only_scored) { std::string count, san, token, eval; while (game >> count >> san >> token) { // Consume/Parse PGN comments - if (only_scored) { + if (scored) { // consume the comment and search for an evaluation bool has_score = false; while (game >> token) { @@ -65,21 +96,40 @@ void pgn2fen(std::istream& input, const bool only_scored) { bool parseError = false; Move move = UCI::parseSAN(san, pos, parseError); if (parseError) { - std::cerr << "[ Error ] Parsing '" << san << "' at position '" + std::cerr << "ERROR: Parsing '" << san << "' at position '" << pos.fen() << "' failed." << std::endl; + break; } move = pos.isLegal(move); // validate legality and extend move info if (move) { pos.make(move); } else { - std::cerr << "[ Error ] Encountered illegal move '" << san + std::cerr << "ERROR: Encountered illegal move '" << san << " (" << move << ") ' at position '" << pos.fen() << "'." << std::endl; break; } + // Skip positions with too small ply count + if (pos.plyCount() < ply) { + continue; + } + // Write positions - if (only_scored) { + if (scored && rating != static_cast(-1)) { + // Ingore "check mate in" scores (not relevant for eval training) + // Do this after "make move" in situations where the check mate + // was overlooked, leading to new positions + if (eval.length() && eval[0] == '#') { + continue; + } + // Otherwise, classic eval score to be parsed in centipawns + std::cout << pos.fen() << "; " << eval << "; " + << whiteElo << "; " << blackElo << '\n'; + } else if (rating != static_cast(-1)) { + // Otherwise, classic eval score to be parsed in centipawns + std::cout << pos.fen() << "; " << whiteElo << "; " << blackElo << '\n'; + } else if (scored) { // Ingore "check mate in" scores (not relevant for eval training) // Do this after "make move" in situations where the check mate // was overlooked, leading to new positions @@ -92,52 +142,75 @@ void pgn2fen(std::istream& input, const bool only_scored) { // Write only the position FEN std::cout << pos.fen() << '\n'; } + } + + // Reset ELO after every game to ensure that games without an elo + // metadata tag don't get the wrong rating from a previous game + whiteElo = 0; + blackElo = 0; } } int main(int argn, char* argv[]) { // Setup control variables - bool only_scored = false; + bool scored = false; + unsigned long rating = -1; + unsigned long ply = 0; + // unsigned min_rating = 0; std::string file = ""; // Parse command arguments - switch (argn) { - case 1: - break; - case 2: - if (std::string("--scored") == argv[1]) { - only_scored = true; + for (int i = 1; i < argn; ++i) { + if (std::string("--scored") == argv[i]) { + scored = true; + } else if (std::string("--rating") == argv[i]) { + if (i + 1 < argn) { + try { + rating = std::stoul(argv[++i]); + } catch (...) { + std::cerr << "ERROR: illegal --rating argument " << argv[i] << std::endl; + std::cout << usage << std::endl; + return 1; + } } else { - file = argv[1]; - } - break; - case 3: - if (std::string("--scored") != argv[1]) { std::cout << usage << std::endl; return 1; } - only_scored = true; - file = argv[2]; - break; - default: + } else if (std::string("--ply") == argv[i]) { + if (i + 1 < argn) { + try { + ply = std::stoul(argv[++i]); + } catch (...) { + std::cerr << "ERROR: illegal --ply argument " << argv[i] << std::endl; + std::cout << usage << std::endl; + return 1; + } + } else { + std::cout << usage << std::endl; + return 1; + } + } else if (file != "") { + file = argv[i]; + } else { std::cout << usage << std::endl; return 1; + } } // Invoke converter ether with file input or stdin if (file == "") { - pgn2fen(std::cin, only_scored); + pgn2fen(std::cin, scored, rating, ply); } else { // Open input file std::ifstream input(file); if (!input) { - std::cerr << "Error opening '" << file << "'" << std::endl; + std::cerr << "ERROR: opening '" << file << "' failed" << std::endl; return 1; } - pgn2fen(input, only_scored); + pgn2fen(input, scored, rating, ply); } return 0; diff --git a/dataAnalysis/chess/preprocessing.sh b/dataAnalysis/chess/preprocessing.sh index afa8a82..3dc1e64 100755 --- a/dataAnalysis/chess/preprocessing.sh +++ b/dataAnalysis/chess/preprocessing.sh @@ -4,6 +4,10 @@ # in November 2023 data=lichess_db_standard_rated_2023-11 +# Minimum "ELO" rating of black and white players +min_rating=2000 +min_ply=20 + # Check if file exists and download iff not if [ -f "${data}.fen" ]; then echo "File '${data}.fen' already exists, assuming job already done." @@ -21,11 +25,16 @@ else # the PGN data base into a list of FEN strings while filtering only # positions with evaluation. The `--scored` parameter specifies to extract # a position evaluation from the PGN and ONLY write positions with scores. - # That is, positions without a score are removed! + # That is, positions without a score are removed! Parameter `--rating` + # filters games where both players have at least the minimum rating, and + # `--ply` specifies to only consider positions from `$min_ply` onward. if [ -f "${data}.pgn.zst" ]; then - zstdcat ${data}.pgn.zst | ./pgn2fen --scored > ${data}.fen + zstdcat ${data}.pgn.zst \ + | ./pgn2fen --scored --rating $min_rating --ply $min_ply \ + > ${data}.fen else - wget -qO- https://database.lichess.org/standard/${data}.pgn.zst \ - | zstdcat | ./pgn2fen --scored > ${data}.fen + wget -qO- https://database.lichess.org/standard/${data}.pgn.zst | zstdcat \ + | ./pgn2fen --scored --rating $min_rating --ply $min_ply \ + > ${data}.fen fi fi diff --git a/tensorPredictors/NAMESPACE b/tensorPredictors/NAMESPACE index 47be8e1..313b03a 100644 --- a/tensorPredictors/NAMESPACE +++ b/tensorPredictors/NAMESPACE @@ -1,5 +1,6 @@ # Generated by roxygen2: do not edit by hand +S3method(merge,matmul) export("%<-%") export("%x_1%") export("%x_2%")