update: data download and fen processing

This commit is contained in:
Daniel Kapla 2025-02-07 13:46:10 +01:00
parent b1f25b89da
commit 917c81f37f
7 changed files with 146 additions and 53 deletions

View File

@ -46,8 +46,8 @@ isInsufficient <- function(positions) {
#' `gmlm_chess()` as data generator to provide random draws from a FEN data set
#' with scores filtered to be in in the range `score_min` to `score_max`.
#'
data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, min_ply_count = 10L, white_only = TRUE) {
.Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, min_ply_count, white_only)
data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, draw = TRUE, min_ply_count = 10L, white_only = TRUE) {
.Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, draw, min_ply_count, white_only)
}
#' Human Crafted Evaluation

View File

@ -72,8 +72,8 @@ BEGIN_RCPP
END_RCPP
}
// data_gen
Rcpp::CharacterVector data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const int min_ply_count, const bool white_only);
RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP min_ply_countSEXP, SEXP white_onlySEXP) {
Rcpp::List data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const bool draw, const int min_ply_count, const bool white_only);
RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP drawSEXP, SEXP min_ply_countSEXP, SEXP white_onlySEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
@ -82,9 +82,10 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< const float >::type score_min(score_minSEXP);
Rcpp::traits::input_parameter< const float >::type score_max(score_maxSEXP);
Rcpp::traits::input_parameter< const bool >::type quiet(quietSEXP);
Rcpp::traits::input_parameter< const bool >::type draw(drawSEXP);
Rcpp::traits::input_parameter< const int >::type min_ply_count(min_ply_countSEXP);
Rcpp::traits::input_parameter< const bool >::type white_only(white_onlySEXP);
rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, min_ply_count, white_only));
rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, draw, min_ply_count, white_only));
return rcpp_result_gen;
END_RCPP
}
@ -255,7 +256,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_Rchess_isQuiet", (DL_FUNC) &_Rchess_isQuiet, 1},
{"_Rchess_isTerminal", (DL_FUNC) &_Rchess_isTerminal, 1},
{"_Rchess_isInsufficient", (DL_FUNC) &_Rchess_isInsufficient, 1},
{"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 7},
{"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 8},
{"_Rchess_eval_psqt", (DL_FUNC) &_Rchess_eval_psqt, 5},
{"_Rchess_fen2int", (DL_FUNC) &_Rchess_fen2int, 1},
{"_Rchess_read_cyclic", (DL_FUNC) &_Rchess_read_cyclic, 5},

View File

@ -14,12 +14,13 @@
//' with scores filtered to be in in the range `score_min` to `score_max`.
//'
// [[Rcpp::export(name = "data.gen", rng = true)]]
Rcpp::CharacterVector data_gen(
Rcpp::List data_gen(
const std::string& file,
const int sample_size,
const float score_min = -5.0,
const float score_max = +5.0,
const bool quiet = false,
const bool draw = true,
const int min_ply_count = 10,
const bool white_only = true
) {
@ -51,7 +52,7 @@ Rcpp::CharacterVector data_gen(
}
// Allocate output sample
Rcpp::CharacterVector _sample(sample_size);
Rcpp::CharacterVector _fens(sample_size);
Rcpp::NumericVector _scores(sample_size);
// Read and filter lines from FEN data base file
@ -106,15 +107,16 @@ Rcpp::CharacterVector data_gen(
// Reject / Filter samples
if (((int)pos.plyCount() < min_ply_count) // early positions
|| (white_only && (pos.sideToMove() == piece::black)) // white to move positions
|| (score < score_min || score_max <= score) // scores out of slice
|| (quiet && !pos.isQuiet())) // quiet positions
|| (score < score_min || score_max < score) // scores out of slice
|| (quiet && !pos.isQuiet()) // quiet positions
|| (!draw && score == 0.0)) // drawn positions
{
reject_count++;
continue;
}
// Everythings succeeded and ge got an appropriate sample in requested range
_sample[sample_count] = fen;
_fens[sample_count] = fen;
_scores[sample_count] = score;
++sample_count;
@ -130,8 +132,8 @@ Rcpp::CharacterVector data_gen(
}
}
// Set scores as attribute to position sample
_sample.attr("scores") = _scores;
return _sample;
return Rcpp::List::create(
Rcpp::Named("fens") = _fens,
Rcpp::Named("scores") = _scores
);
}

View File

@ -1,6 +1,5 @@
library(tensorPredictors)
library(Rchess)
library(mgcv) # for `gam()` (Generalized Additive Model)
source("./gmlm_chess.R")
@ -16,9 +15,12 @@ data_set <- "lichess_db_standard_rated_2023-11.fen"
# Function to draw samples `X` form the chess position `data_set` conditioned on
# `Y` (position scores) to be in the interval `score_min` to `score_max`.
data_gen <- function(batch_size, score_min, score_max) {
Rchess::fen2int(Rchess::data.gen(data_set, batch_size, score_min, score_max, quiet = TRUE))
data <- Rchess::data.gen(data_set, batch_size, score_min, score_max, quiet = TRUE)
pos <- Rchess::fen2int(data$fens)
structure(pos, scores = data$scores)
}
# Invoke specialized GMLM optimization routine for chess data
fit.gmlm <- gmlm_chess(data_gen)
@ -26,6 +28,7 @@ fit.gmlm <- gmlm_chess(data_gen)
################################################################################
### Reduction Interpretation and Validation ###
################################################################################
library(mgcv) # for `gam()` (Generalized Additive Model)
# load last save point (includes reduction as `betas`)
save_point <- sort(list.files(
@ -43,6 +46,10 @@ sample_size <- 100000
fens <- Rchess::data.gen(data_set, sample_size, -20, 20, quiet = TRUE)
# extract stockfish (non-static) position evaluation
y <- attr(fens, "scores")
# remove poitions with exact draw evalualtion
draws <- which(y == 0.0)
y <- y[-draws]
fens <- fens[-draws]
# Convert position into "One-Hot Encoded" / "Bit Board" tensor
X <- Rchess::fen2int(fens)
@ -54,7 +61,7 @@ reducedX <- Reduce(rbind, Map(function(piece) {
mlm(X - as.vector(rowMeans(X, dims = 2)), betas[[piece]], transposed = TRUE)
}, 1:12))
# Convert memory layout to contain vectorized observations in rows
reducedX <- t(`dim<-`(reducedX, c(48, sample_size)))
reducedX <- t(`dim<-`(reducedX, c(48, length(y))))
# set names for coefficient extraction from linear fit
colnames(reducedX) <- as.vector(outer(
unlist(strsplit("PNBRQKpnbrqk", "")), c(1, "yl", "yu", "y.2"), paste, sep = "."
@ -86,15 +93,15 @@ psqt[["P"]][c(1, 8), ] <- 0
### Validation by GAM fitted on reduced data
formula <- as.formula(paste("y ~ ", paste("s(", colnames(reducedX), ")", collapse = "+")))
fit.gam <- mgcv::gam(formula, data = data.frame(y = y, reducedX), subset = y != 0.0)
fit.gam <- mgcv::gam(formula, data = data.frame(y = y, reducedX))
summary(fit.gam)
# compair estimates with mean as baseline and static human crafted evaluation (HCE)
rmse.base <- sqrt(mean((mean(y) - y)^2))
(rmse.base <- sqrt(mean((mean(y) - y)^2)))
y.hce <- Rchess::HCE(fens)
rmse.hce <- sqrt(mean((y.hce - y)^2))
(rmse.hce <- sqrt(mean((y.hce - y)^2)))
y.hat <- predict(fit.gam, newdata = data.frame(reducedX))
rmse.hat <- sqrt(mean((y.hat - y)^2))
(rmse.hat <- sqrt(mean((y.hat - y)^2)))
# Also extract R^2 (eval by hand or get from models)
(r.sq.lm <- summary(fit)$r.squared)

View File

@ -10,24 +10,55 @@
#include "search.h"
#include "uci.h"
static const std::string usage{"usage: pgn2fen [--scored] [<input>]"};
static const std::string usage{"usage: pgn2fen [--scored] [--rating <rating>] [--ply <ply>] [<input>]"};
// Convert PGN (Portable Game Notation) input stream to single FENs
// streamed to stdout
void pgn2fen(std::istream& input, const bool only_scored) {
void pgn2fen(
std::istream& input,
const bool scored,
const unsigned long rating,
const unsigned long ply
) {
// Instantiate Boards, the start of every game as well as the current state
// of the Board while processing a PGN game
Board startpos, pos;
// Parse white and black ELO ratings
unsigned long whiteElo = 0;
unsigned long blackElo = 0;
// Read input line by line
std::string line;
while (std::getline(input, line)) {
// Skip empty and metadata lines (every PGN game starts with "<nr>.")
// read rating metadata lines
if (rating != static_cast<unsigned long>(-1)) {
// [WhiteElo "1111"]
// [BlackElo "999"]
try {
if (line.rfind("[WhiteElo \"", 0) != std::string::npos) {
whiteElo = std::stoul(line.substr(11));
} else if (line.rfind("[BlackElo \"", 0) != std::string::npos) {
blackElo = std::stoul(line.substr(11));
}
} catch (...) {
std::cerr << "ERROR: Parsing player rating metadata '" << line
<< "' failed." << std::endl;
break;
}
}
// Skip empty and further metadata lines (every PGN game starts with "<nr>.")
if (line.empty() || line.front() == '[') {
continue;
}
// In case of rating requested, only parse game when rating is detected
if (rating != static_cast<unsigned long>(-1)
&& (whiteElo < rating || blackElo < rating)) {
continue;
}
// Reset position to the start position, every game starts here!
pos = startpos;
@ -36,7 +67,7 @@ void pgn2fen(std::istream& input, const bool only_scored) {
std::string count, san, token, eval;
while (game >> count >> san >> token) {
// Consume/Parse PGN comments
if (only_scored) {
if (scored) {
// consume the comment and search for an evaluation
bool has_score = false;
while (game >> token) {
@ -65,21 +96,40 @@ void pgn2fen(std::istream& input, const bool only_scored) {
bool parseError = false;
Move move = UCI::parseSAN(san, pos, parseError);
if (parseError) {
std::cerr << "[ Error ] Parsing '" << san << "' at position '"
std::cerr << "ERROR: Parsing '" << san << "' at position '"
<< pos.fen() << "' failed." << std::endl;
break;
}
move = pos.isLegal(move); // validate legality and extend move info
if (move) {
pos.make(move);
} else {
std::cerr << "[ Error ] Encountered illegal move '" << san
std::cerr << "ERROR: Encountered illegal move '" << san
<< " (" << move
<< ") ' at position '" << pos.fen() << "'." << std::endl;
break;
}
// Skip positions with too small ply count
if (pos.plyCount() < ply) {
continue;
}
// Write positions
if (only_scored) {
if (scored && rating != static_cast<unsigned long>(-1)) {
// Ingore "check mate in" scores (not relevant for eval training)
// Do this after "make move" in situations where the check mate
// was overlooked, leading to new positions
if (eval.length() && eval[0] == '#') {
continue;
}
// Otherwise, classic eval score to be parsed in centipawns
std::cout << pos.fen() << "; " << eval << "; "
<< whiteElo << "; " << blackElo << '\n';
} else if (rating != static_cast<unsigned long>(-1)) {
// Otherwise, classic eval score to be parsed in centipawns
std::cout << pos.fen() << "; " << whiteElo << "; " << blackElo << '\n';
} else if (scored) {
// Ingore "check mate in" scores (not relevant for eval training)
// Do this after "make move" in situations where the check mate
// was overlooked, leading to new positions
@ -92,52 +142,75 @@ void pgn2fen(std::istream& input, const bool only_scored) {
// Write only the position FEN
std::cout << pos.fen() << '\n';
}
}
// Reset ELO after every game to ensure that games without an elo
// metadata tag don't get the wrong rating from a previous game
whiteElo = 0;
blackElo = 0;
}
}
int main(int argn, char* argv[]) {
// Setup control variables
bool only_scored = false;
bool scored = false;
unsigned long rating = -1;
unsigned long ply = 0;
// unsigned min_rating = 0;
std::string file = "";
// Parse command arguments
switch (argn) {
case 1:
break;
case 2:
if (std::string("--scored") == argv[1]) {
only_scored = true;
for (int i = 1; i < argn; ++i) {
if (std::string("--scored") == argv[i]) {
scored = true;
} else if (std::string("--rating") == argv[i]) {
if (i + 1 < argn) {
try {
rating = std::stoul(argv[++i]);
} catch (...) {
std::cerr << "ERROR: illegal --rating argument " << argv[i] << std::endl;
std::cout << usage << std::endl;
return 1;
}
} else {
file = argv[1];
}
break;
case 3:
if (std::string("--scored") != argv[1]) {
std::cout << usage << std::endl;
return 1;
}
only_scored = true;
file = argv[2];
break;
default:
} else if (std::string("--ply") == argv[i]) {
if (i + 1 < argn) {
try {
ply = std::stoul(argv[++i]);
} catch (...) {
std::cerr << "ERROR: illegal --ply argument " << argv[i] << std::endl;
std::cout << usage << std::endl;
return 1;
}
} else {
std::cout << usage << std::endl;
return 1;
}
} else if (file != "") {
file = argv[i];
} else {
std::cout << usage << std::endl;
return 1;
}
}
// Invoke converter ether with file input or stdin
if (file == "") {
pgn2fen(std::cin, only_scored);
pgn2fen(std::cin, scored, rating, ply);
} else {
// Open input file
std::ifstream input(file);
if (!input) {
std::cerr << "Error opening '" << file << "'" << std::endl;
std::cerr << "ERROR: opening '" << file << "' failed" << std::endl;
return 1;
}
pgn2fen(input, only_scored);
pgn2fen(input, scored, rating, ply);
}
return 0;

View File

@ -4,6 +4,10 @@
# in November 2023
data=lichess_db_standard_rated_2023-11
# Minimum "ELO" rating of black and white players
min_rating=2000
min_ply=20
# Check if file exists and download iff not
if [ -f "${data}.fen" ]; then
echo "File '${data}.fen' already exists, assuming job already done."
@ -21,11 +25,16 @@ else
# the PGN data base into a list of FEN strings while filtering only
# positions with evaluation. The `--scored` parameter specifies to extract
# a position evaluation from the PGN and ONLY write positions with scores.
# That is, positions without a score are removed!
# That is, positions without a score are removed! Parameter `--rating`
# filters games where both players have at least the minimum rating, and
# `--ply` specifies to only consider positions from `$min_ply` onward.
if [ -f "${data}.pgn.zst" ]; then
zstdcat ${data}.pgn.zst | ./pgn2fen --scored > ${data}.fen
zstdcat ${data}.pgn.zst \
| ./pgn2fen --scored --rating $min_rating --ply $min_ply \
> ${data}.fen
else
wget -qO- https://database.lichess.org/standard/${data}.pgn.zst \
| zstdcat | ./pgn2fen --scored > ${data}.fen
wget -qO- https://database.lichess.org/standard/${data}.pgn.zst | zstdcat \
| ./pgn2fen --scored --rating $min_rating --ply $min_ply \
> ${data}.fen
fi
fi

View File

@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand
S3method(merge,matmul)
export("%<-%")
export("%x_1%")
export("%x_2%")