update: data download and fen processing

This commit is contained in:
Daniel Kapla 2025-02-07 13:46:10 +01:00
parent b1f25b89da
commit 917c81f37f
7 changed files with 146 additions and 53 deletions

View File

@ -46,8 +46,8 @@ isInsufficient <- function(positions) {
#' `gmlm_chess()` as data generator to provide random draws from a FEN data set #' `gmlm_chess()` as data generator to provide random draws from a FEN data set
#' with scores filtered to be in in the range `score_min` to `score_max`. #' with scores filtered to be in in the range `score_min` to `score_max`.
#' #'
data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, min_ply_count = 10L, white_only = TRUE) { data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, draw = TRUE, min_ply_count = 10L, white_only = TRUE) {
.Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, min_ply_count, white_only) .Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, draw, min_ply_count, white_only)
} }
#' Human Crafted Evaluation #' Human Crafted Evaluation

View File

@ -72,8 +72,8 @@ BEGIN_RCPP
END_RCPP END_RCPP
} }
// data_gen // data_gen
Rcpp::CharacterVector data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const int min_ply_count, const bool white_only); Rcpp::List data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const bool draw, const int min_ply_count, const bool white_only);
RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP min_ply_countSEXP, SEXP white_onlySEXP) { RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP drawSEXP, SEXP min_ply_countSEXP, SEXP white_onlySEXP) {
BEGIN_RCPP BEGIN_RCPP
Rcpp::RObject rcpp_result_gen; Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::RNGScope rcpp_rngScope_gen;
@ -82,9 +82,10 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< const float >::type score_min(score_minSEXP); Rcpp::traits::input_parameter< const float >::type score_min(score_minSEXP);
Rcpp::traits::input_parameter< const float >::type score_max(score_maxSEXP); Rcpp::traits::input_parameter< const float >::type score_max(score_maxSEXP);
Rcpp::traits::input_parameter< const bool >::type quiet(quietSEXP); Rcpp::traits::input_parameter< const bool >::type quiet(quietSEXP);
Rcpp::traits::input_parameter< const bool >::type draw(drawSEXP);
Rcpp::traits::input_parameter< const int >::type min_ply_count(min_ply_countSEXP); Rcpp::traits::input_parameter< const int >::type min_ply_count(min_ply_countSEXP);
Rcpp::traits::input_parameter< const bool >::type white_only(white_onlySEXP); Rcpp::traits::input_parameter< const bool >::type white_only(white_onlySEXP);
rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, min_ply_count, white_only)); rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, draw, min_ply_count, white_only));
return rcpp_result_gen; return rcpp_result_gen;
END_RCPP END_RCPP
} }
@ -255,7 +256,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_Rchess_isQuiet", (DL_FUNC) &_Rchess_isQuiet, 1}, {"_Rchess_isQuiet", (DL_FUNC) &_Rchess_isQuiet, 1},
{"_Rchess_isTerminal", (DL_FUNC) &_Rchess_isTerminal, 1}, {"_Rchess_isTerminal", (DL_FUNC) &_Rchess_isTerminal, 1},
{"_Rchess_isInsufficient", (DL_FUNC) &_Rchess_isInsufficient, 1}, {"_Rchess_isInsufficient", (DL_FUNC) &_Rchess_isInsufficient, 1},
{"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 7}, {"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 8},
{"_Rchess_eval_psqt", (DL_FUNC) &_Rchess_eval_psqt, 5}, {"_Rchess_eval_psqt", (DL_FUNC) &_Rchess_eval_psqt, 5},
{"_Rchess_fen2int", (DL_FUNC) &_Rchess_fen2int, 1}, {"_Rchess_fen2int", (DL_FUNC) &_Rchess_fen2int, 1},
{"_Rchess_read_cyclic", (DL_FUNC) &_Rchess_read_cyclic, 5}, {"_Rchess_read_cyclic", (DL_FUNC) &_Rchess_read_cyclic, 5},

View File

@ -14,12 +14,13 @@
//' with scores filtered to be in in the range `score_min` to `score_max`. //' with scores filtered to be in in the range `score_min` to `score_max`.
//' //'
// [[Rcpp::export(name = "data.gen", rng = true)]] // [[Rcpp::export(name = "data.gen", rng = true)]]
Rcpp::CharacterVector data_gen( Rcpp::List data_gen(
const std::string& file, const std::string& file,
const int sample_size, const int sample_size,
const float score_min = -5.0, const float score_min = -5.0,
const float score_max = +5.0, const float score_max = +5.0,
const bool quiet = false, const bool quiet = false,
const bool draw = true,
const int min_ply_count = 10, const int min_ply_count = 10,
const bool white_only = true const bool white_only = true
) { ) {
@ -51,7 +52,7 @@ Rcpp::CharacterVector data_gen(
} }
// Allocate output sample // Allocate output sample
Rcpp::CharacterVector _sample(sample_size); Rcpp::CharacterVector _fens(sample_size);
Rcpp::NumericVector _scores(sample_size); Rcpp::NumericVector _scores(sample_size);
// Read and filter lines from FEN data base file // Read and filter lines from FEN data base file
@ -106,15 +107,16 @@ Rcpp::CharacterVector data_gen(
// Reject / Filter samples // Reject / Filter samples
if (((int)pos.plyCount() < min_ply_count) // early positions if (((int)pos.plyCount() < min_ply_count) // early positions
|| (white_only && (pos.sideToMove() == piece::black)) // white to move positions || (white_only && (pos.sideToMove() == piece::black)) // white to move positions
|| (score < score_min || score_max <= score) // scores out of slice || (score < score_min || score_max < score) // scores out of slice
|| (quiet && !pos.isQuiet())) // quiet positions || (quiet && !pos.isQuiet()) // quiet positions
|| (!draw && score == 0.0)) // drawn positions
{ {
reject_count++; reject_count++;
continue; continue;
} }
// Everythings succeeded and ge got an appropriate sample in requested range // Everythings succeeded and ge got an appropriate sample in requested range
_sample[sample_count] = fen; _fens[sample_count] = fen;
_scores[sample_count] = score; _scores[sample_count] = score;
++sample_count; ++sample_count;
@ -130,8 +132,8 @@ Rcpp::CharacterVector data_gen(
} }
} }
// Set scores as attribute to position sample return Rcpp::List::create(
_sample.attr("scores") = _scores; Rcpp::Named("fens") = _fens,
Rcpp::Named("scores") = _scores
return _sample; );
} }

View File

@ -1,6 +1,5 @@
library(tensorPredictors) library(tensorPredictors)
library(Rchess) library(Rchess)
library(mgcv) # for `gam()` (Generalized Additive Model)
source("./gmlm_chess.R") source("./gmlm_chess.R")
@ -16,9 +15,12 @@ data_set <- "lichess_db_standard_rated_2023-11.fen"
# Draw a batch of chess positions `X` from the FEN file `data_set`, conditioned
# on the position scores `Y` lying in the range `score_min` to `score_max`.
# Only quiet positions are requested (`quiet = TRUE`). The drawn FENs are
# converted with `Rchess::fen2int()` and the matching scores are attached to
# the converted positions as attribute `scores`.
data_gen <- function(batch_size, score_min, score_max) {
  batch <- Rchess::data.gen(
    data_set, batch_size, score_min, score_max,
    quiet = TRUE
  )
  encoded <- Rchess::fen2int(batch$fens)
  attr(encoded, "scores") <- batch$scores
  encoded
}
# Invoke specialized GMLM optimization routine for chess data # Invoke specialized GMLM optimization routine for chess data
fit.gmlm <- gmlm_chess(data_gen) fit.gmlm <- gmlm_chess(data_gen)
@ -26,6 +28,7 @@ fit.gmlm <- gmlm_chess(data_gen)
################################################################################ ################################################################################
### Reduction Interpretation and Validation ### ### Reduction Interpretation and Validation ###
################################################################################ ################################################################################
library(mgcv) # for `gam()` (Generalized Additive Model)
# load last save point (includes reduction as `betas`) # load last save point (includes reduction as `betas`)
save_point <- sort(list.files( save_point <- sort(list.files(
@ -43,6 +46,10 @@ sample_size <- 100000
# Draw the validation sample. `Rchess::data.gen()` returns a list with the
# elements `fens` (FEN strings) and `scores`; the scores are NOT an attribute
# of the FEN vector, so both are taken from the list explicitly.
positions <- Rchess::data.gen(data_set, sample_size, -20, 20, quiet = TRUE)
# Remove positions with an exact draw evaluation. A logical mask is used on
# purpose: negative indexing with an empty `which()` result (a sample without
# any draw) would drop every observation instead of none.
keep <- positions$scores != 0.0
fens <- positions$fens[keep]
# extract stockfish (non-static) position evaluation
y <- positions$scores[keep]
# Convert position into "One-Hot Encoded" / "Bit Board" tensor
X <- Rchess::fen2int(fens)
@ -54,7 +61,7 @@ reducedX <- Reduce(rbind, Map(function(piece) {
mlm(X - as.vector(rowMeans(X, dims = 2)), betas[[piece]], transposed = TRUE) mlm(X - as.vector(rowMeans(X, dims = 2)), betas[[piece]], transposed = TRUE)
}, 1:12)) }, 1:12))
# Convert memory layout to contain vectorized observations in rows # Convert memory layout to contain vectorized observations in rows
reducedX <- t(`dim<-`(reducedX, c(48, sample_size))) reducedX <- t(`dim<-`(reducedX, c(48, length(y))))
# set names for coefficient extraction from linear fit # set names for coefficient extraction from linear fit
colnames(reducedX) <- as.vector(outer( colnames(reducedX) <- as.vector(outer(
unlist(strsplit("PNBRQKpnbrqk", "")), c(1, "yl", "yu", "y.2"), paste, sep = "." unlist(strsplit("PNBRQKpnbrqk", "")), c(1, "yl", "yu", "y.2"), paste, sep = "."
@ -86,15 +93,15 @@ psqt[["P"]][c(1, 8), ] <- 0
### Validation by GAM fitted on reduced data ### Validation by GAM fitted on reduced data
formula <- as.formula(paste("y ~ ", paste("s(", colnames(reducedX), ")", collapse = "+"))) formula <- as.formula(paste("y ~ ", paste("s(", colnames(reducedX), ")", collapse = "+")))
fit.gam <- mgcv::gam(formula, data = data.frame(y = y, reducedX), subset = y != 0.0) fit.gam <- mgcv::gam(formula, data = data.frame(y = y, reducedX))
summary(fit.gam) summary(fit.gam)
# compair estimates with mean as baseline and static human crafted evaluation (HCE) # compair estimates with mean as baseline and static human crafted evaluation (HCE)
rmse.base <- sqrt(mean((mean(y) - y)^2)) (rmse.base <- sqrt(mean((mean(y) - y)^2)))
y.hce <- Rchess::HCE(fens) y.hce <- Rchess::HCE(fens)
rmse.hce <- sqrt(mean((y.hce - y)^2)) (rmse.hce <- sqrt(mean((y.hce - y)^2)))
y.hat <- predict(fit.gam, newdata = data.frame(reducedX)) y.hat <- predict(fit.gam, newdata = data.frame(reducedX))
rmse.hat <- sqrt(mean((y.hat - y)^2)) (rmse.hat <- sqrt(mean((y.hat - y)^2)))
# Also extract R^2 (eval by hand or get from models) # Also extract R^2 (eval by hand or get from models)
(r.sq.lm <- summary(fit)$r.squared) (r.sq.lm <- summary(fit)$r.squared)

View File

@ -10,24 +10,55 @@
#include "search.h" #include "search.h"
#include "uci.h" #include "uci.h"
static const std::string usage{"usage: pgn2fen [--scored] [<input>]"}; static const std::string usage{"usage: pgn2fen [--scored] [--rating <rating>] [--ply <ply>] [<input>]"};
// Convert PGN (Portable Game Notation) input stream to single FENs // Convert PGN (Portable Game Notation) input stream to single FENs
// streamed to stdout // streamed to stdout
void pgn2fen(std::istream& input, const bool only_scored) { void pgn2fen(
std::istream& input,
const bool scored,
const unsigned long rating,
const unsigned long ply
) {
// Instantiate Boards, the start of every game as well as the current state // Instantiate Boards, the start of every game as well as the current state
// of the Board while processing a PGN game // of the Board while processing a PGN game
Board startpos, pos; Board startpos, pos;
// Parse white and black ELO ratings
unsigned long whiteElo = 0;
unsigned long blackElo = 0;
// Read input line by line // Read input line by line
std::string line; std::string line;
while (std::getline(input, line)) { while (std::getline(input, line)) {
// Skip empty and metadata lines (every PGN game starts with "<nr>.") // read rating metadata lines
if (rating != static_cast<unsigned long>(-1)) {
// [WhiteElo "1111"]
// [BlackElo "999"]
try {
if (line.rfind("[WhiteElo \"", 0) != std::string::npos) {
whiteElo = std::stoul(line.substr(11));
} else if (line.rfind("[BlackElo \"", 0) != std::string::npos) {
blackElo = std::stoul(line.substr(11));
}
} catch (...) {
std::cerr << "ERROR: Parsing player rating metadata '" << line
<< "' failed." << std::endl;
break;
}
}
// Skip empty and further metadata lines (every PGN game starts with "<nr>.")
if (line.empty() || line.front() == '[') { if (line.empty() || line.front() == '[') {
continue; continue;
} }
// In case of rating requested, only parse game when rating is detected
if (rating != static_cast<unsigned long>(-1)
&& (whiteElo < rating || blackElo < rating)) {
continue;
}
// Reset position to the start position, every game starts here! // Reset position to the start position, every game starts here!
pos = startpos; pos = startpos;
@ -36,7 +67,7 @@ void pgn2fen(std::istream& input, const bool only_scored) {
std::string count, san, token, eval; std::string count, san, token, eval;
while (game >> count >> san >> token) { while (game >> count >> san >> token) {
// Consume/Parse PGN comments // Consume/Parse PGN comments
if (only_scored) { if (scored) {
// consume the comment and search for an evaluation // consume the comment and search for an evaluation
bool has_score = false; bool has_score = false;
while (game >> token) { while (game >> token) {
@ -65,21 +96,40 @@ void pgn2fen(std::istream& input, const bool only_scored) {
bool parseError = false; bool parseError = false;
Move move = UCI::parseSAN(san, pos, parseError); Move move = UCI::parseSAN(san, pos, parseError);
if (parseError) { if (parseError) {
std::cerr << "[ Error ] Parsing '" << san << "' at position '" std::cerr << "ERROR: Parsing '" << san << "' at position '"
<< pos.fen() << "' failed." << std::endl; << pos.fen() << "' failed." << std::endl;
break;
} }
move = pos.isLegal(move); // validate legality and extend move info move = pos.isLegal(move); // validate legality and extend move info
if (move) { if (move) {
pos.make(move); pos.make(move);
} else { } else {
std::cerr << "[ Error ] Encountered illegal move '" << san std::cerr << "ERROR: Encountered illegal move '" << san
<< " (" << move << " (" << move
<< ") ' at position '" << pos.fen() << "'." << std::endl; << ") ' at position '" << pos.fen() << "'." << std::endl;
break; break;
} }
// Skip positions with too small ply count
if (pos.plyCount() < ply) {
continue;
}
// Write positions // Write positions
if (only_scored) { if (scored && rating != static_cast<unsigned long>(-1)) {
// Ingore "check mate in" scores (not relevant for eval training)
// Do this after "make move" in situations where the check mate
// was overlooked, leading to new positions
if (eval.length() && eval[0] == '#') {
continue;
}
// Otherwise, classic eval score to be parsed in centipawns
std::cout << pos.fen() << "; " << eval << "; "
<< whiteElo << "; " << blackElo << '\n';
} else if (rating != static_cast<unsigned long>(-1)) {
// Otherwise, classic eval score to be parsed in centipawns
std::cout << pos.fen() << "; " << whiteElo << "; " << blackElo << '\n';
} else if (scored) {
// Ingore "check mate in" scores (not relevant for eval training) // Ingore "check mate in" scores (not relevant for eval training)
// Do this after "make move" in situations where the check mate // Do this after "make move" in situations where the check mate
// was overlooked, leading to new positions // was overlooked, leading to new positions
@ -92,52 +142,75 @@ void pgn2fen(std::istream& input, const bool only_scored) {
// Write only the position FEN // Write only the position FEN
std::cout << pos.fen() << '\n'; std::cout << pos.fen() << '\n';
} }
} }
// Reset ELO after every game to ensure that games without an elo
// metadata tag don't get the wrong rating from a previous game
whiteElo = 0;
blackElo = 0;
} }
} }
int main(int argn, char* argv[]) { int main(int argn, char* argv[]) {
// Setup control variables // Setup control variables
bool only_scored = false; bool scored = false;
unsigned long rating = -1;
unsigned long ply = 0;
// unsigned min_rating = 0;
std::string file = ""; std::string file = "";
// Parse command arguments // Parse command arguments
switch (argn) { for (int i = 1; i < argn; ++i) {
case 1: if (std::string("--scored") == argv[i]) {
break; scored = true;
case 2: } else if (std::string("--rating") == argv[i]) {
if (std::string("--scored") == argv[1]) { if (i + 1 < argn) {
only_scored = true; try {
rating = std::stoul(argv[++i]);
} catch (...) {
std::cerr << "ERROR: illegal --rating argument " << argv[i] << std::endl;
std::cout << usage << std::endl;
return 1;
}
} else { } else {
file = argv[1];
}
break;
case 3:
if (std::string("--scored") != argv[1]) {
std::cout << usage << std::endl; std::cout << usage << std::endl;
return 1; return 1;
} }
only_scored = true; } else if (std::string("--ply") == argv[i]) {
file = argv[2]; if (i + 1 < argn) {
break; try {
default: ply = std::stoul(argv[++i]);
} catch (...) {
std::cerr << "ERROR: illegal --ply argument " << argv[i] << std::endl;
std::cout << usage << std::endl;
return 1;
}
} else {
std::cout << usage << std::endl;
return 1;
}
} else if (file != "") {
file = argv[i];
} else {
std::cout << usage << std::endl; std::cout << usage << std::endl;
return 1; return 1;
}
} }
// Invoke converter ether with file input or stdin // Invoke converter ether with file input or stdin
if (file == "") { if (file == "") {
pgn2fen(std::cin, only_scored); pgn2fen(std::cin, scored, rating, ply);
} else { } else {
// Open input file // Open input file
std::ifstream input(file); std::ifstream input(file);
if (!input) { if (!input) {
std::cerr << "Error opening '" << file << "'" << std::endl; std::cerr << "ERROR: opening '" << file << "' failed" << std::endl;
return 1; return 1;
} }
pgn2fen(input, only_scored); pgn2fen(input, scored, rating, ply);
} }
return 0; return 0;

View File

@ -4,6 +4,10 @@
# in November 2023 # in November 2023
data=lichess_db_standard_rated_2023-11 data=lichess_db_standard_rated_2023-11
# Minimum "ELO" rating of black and white players
min_rating=2000
min_ply=20
# Check if file exists and download iff not # Check if file exists and download iff not
if [ -f "${data}.fen" ]; then if [ -f "${data}.fen" ]; then
echo "File '${data}.fen' already exists, assuming job already done." echo "File '${data}.fen' already exists, assuming job already done."
@ -21,11 +25,16 @@ else
# the PGN data base into a list of FEN strings while filtering only # the PGN data base into a list of FEN strings while filtering only
# positions with evaluation. The `--scored` parameter specifies to extract # positions with evaluation. The `--scored` parameter specifies to extract
# a position evaluation from the PGN and ONLY write positions with scores. # a position evaluation from the PGN and ONLY write positions with scores.
# That is, positions without a score are removed! # That is, positions without a score are removed! Parameter `--rating`
# filters games where both players have at least the minimum rating, and
# `--ply` specifies to only consider positions from `$min_ply` onward.
if [ -f "${data}.pgn.zst" ]; then if [ -f "${data}.pgn.zst" ]; then
zstdcat ${data}.pgn.zst | ./pgn2fen --scored > ${data}.fen zstdcat ${data}.pgn.zst \
| ./pgn2fen --scored --rating $min_rating --ply $min_ply \
> ${data}.fen
else else
wget -qO- https://database.lichess.org/standard/${data}.pgn.zst \ wget -qO- https://database.lichess.org/standard/${data}.pgn.zst | zstdcat \
| zstdcat | ./pgn2fen --scored > ${data}.fen | ./pgn2fen --scored --rating $min_rating --ply $min_ply \
> ${data}.fen
fi fi
fi fi

View File

@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand # Generated by roxygen2: do not edit by hand
S3method(merge,matmul)
export("%<-%") export("%<-%")
export("%x_1%") export("%x_1%")
export("%x_2%") export("%x_2%")