update: data download and fen processing
This commit is contained in:
parent
b1f25b89da
commit
917c81f37f
|
@ -46,8 +46,8 @@ isInsufficient <- function(positions) {
|
|||
#' `gmlm_chess()` as data generator to provide random draws from a FEN data set
|
||||
#' with scores filtered to be in in the range `score_min` to `score_max`.
|
||||
#'
|
||||
data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, min_ply_count = 10L, white_only = TRUE) {
|
||||
.Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, min_ply_count, white_only)
|
||||
data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, draw = TRUE, min_ply_count = 10L, white_only = TRUE) {
|
||||
.Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, draw, min_ply_count, white_only)
|
||||
}
|
||||
|
||||
#' Human Crafted Evaluation
|
||||
|
|
|
@ -72,8 +72,8 @@ BEGIN_RCPP
|
|||
END_RCPP
|
||||
}
|
||||
// data_gen
|
||||
Rcpp::CharacterVector data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const int min_ply_count, const bool white_only);
|
||||
RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP min_ply_countSEXP, SEXP white_onlySEXP) {
|
||||
Rcpp::List data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const bool draw, const int min_ply_count, const bool white_only);
|
||||
RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP drawSEXP, SEXP min_ply_countSEXP, SEXP white_onlySEXP) {
|
||||
BEGIN_RCPP
|
||||
Rcpp::RObject rcpp_result_gen;
|
||||
Rcpp::RNGScope rcpp_rngScope_gen;
|
||||
|
@ -82,9 +82,10 @@ BEGIN_RCPP
|
|||
Rcpp::traits::input_parameter< const float >::type score_min(score_minSEXP);
|
||||
Rcpp::traits::input_parameter< const float >::type score_max(score_maxSEXP);
|
||||
Rcpp::traits::input_parameter< const bool >::type quiet(quietSEXP);
|
||||
Rcpp::traits::input_parameter< const bool >::type draw(drawSEXP);
|
||||
Rcpp::traits::input_parameter< const int >::type min_ply_count(min_ply_countSEXP);
|
||||
Rcpp::traits::input_parameter< const bool >::type white_only(white_onlySEXP);
|
||||
rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, min_ply_count, white_only));
|
||||
rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, draw, min_ply_count, white_only));
|
||||
return rcpp_result_gen;
|
||||
END_RCPP
|
||||
}
|
||||
|
@ -255,7 +256,7 @@ static const R_CallMethodDef CallEntries[] = {
|
|||
{"_Rchess_isQuiet", (DL_FUNC) &_Rchess_isQuiet, 1},
|
||||
{"_Rchess_isTerminal", (DL_FUNC) &_Rchess_isTerminal, 1},
|
||||
{"_Rchess_isInsufficient", (DL_FUNC) &_Rchess_isInsufficient, 1},
|
||||
{"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 7},
|
||||
{"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 8},
|
||||
{"_Rchess_eval_psqt", (DL_FUNC) &_Rchess_eval_psqt, 5},
|
||||
{"_Rchess_fen2int", (DL_FUNC) &_Rchess_fen2int, 1},
|
||||
{"_Rchess_read_cyclic", (DL_FUNC) &_Rchess_read_cyclic, 5},
|
||||
|
|
|
@ -14,12 +14,13 @@
|
|||
//' with scores filtered to be in in the range `score_min` to `score_max`.
|
||||
//'
|
||||
// [[Rcpp::export(name = "data.gen", rng = true)]]
|
||||
Rcpp::CharacterVector data_gen(
|
||||
Rcpp::List data_gen(
|
||||
const std::string& file,
|
||||
const int sample_size,
|
||||
const float score_min = -5.0,
|
||||
const float score_max = +5.0,
|
||||
const bool quiet = false,
|
||||
const bool draw = true,
|
||||
const int min_ply_count = 10,
|
||||
const bool white_only = true
|
||||
) {
|
||||
|
@ -51,7 +52,7 @@ Rcpp::CharacterVector data_gen(
|
|||
}
|
||||
|
||||
// Allocate output sample
|
||||
Rcpp::CharacterVector _sample(sample_size);
|
||||
Rcpp::CharacterVector _fens(sample_size);
|
||||
Rcpp::NumericVector _scores(sample_size);
|
||||
|
||||
// Read and filter lines from FEN data base file
|
||||
|
@ -106,15 +107,16 @@ Rcpp::CharacterVector data_gen(
|
|||
// Reject / Filter samples
|
||||
if (((int)pos.plyCount() < min_ply_count) // early positions
|
||||
|| (white_only && (pos.sideToMove() == piece::black)) // white to move positions
|
||||
|| (score < score_min || score_max <= score) // scores out of slice
|
||||
|| (quiet && !pos.isQuiet())) // quiet positions
|
||||
|| (score < score_min || score_max < score) // scores out of slice
|
||||
|| (quiet && !pos.isQuiet()) // quiet positions
|
||||
|| (!draw && score == 0.0)) // drawn positions
|
||||
{
|
||||
reject_count++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Everythings succeeded and ge got an appropriate sample in requested range
|
||||
_sample[sample_count] = fen;
|
||||
_fens[sample_count] = fen;
|
||||
_scores[sample_count] = score;
|
||||
++sample_count;
|
||||
|
||||
|
@ -130,8 +132,8 @@ Rcpp::CharacterVector data_gen(
|
|||
}
|
||||
}
|
||||
|
||||
// Set scores as attribute to position sample
|
||||
_sample.attr("scores") = _scores;
|
||||
|
||||
return _sample;
|
||||
return Rcpp::List::create(
|
||||
Rcpp::Named("fens") = _fens,
|
||||
Rcpp::Named("scores") = _scores
|
||||
);
|
||||
}
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
library(tensorPredictors)
|
||||
library(Rchess)
|
||||
library(mgcv) # for `gam()` (Generalized Additive Model)
|
||||
|
||||
source("./gmlm_chess.R")
|
||||
|
||||
|
@ -16,9 +15,12 @@ data_set <- "lichess_db_standard_rated_2023-11.fen"
|
|||
# Function to draw samples `X` form the chess position `data_set` conditioned on
|
||||
# `Y` (position scores) to be in the interval `score_min` to `score_max`.
|
||||
data_gen <- function(batch_size, score_min, score_max) {
|
||||
Rchess::fen2int(Rchess::data.gen(data_set, batch_size, score_min, score_max, quiet = TRUE))
|
||||
data <- Rchess::data.gen(data_set, batch_size, score_min, score_max, quiet = TRUE)
|
||||
pos <- Rchess::fen2int(data$fens)
|
||||
structure(pos, scores = data$scores)
|
||||
}
|
||||
|
||||
|
||||
# Invoke specialized GMLM optimization routine for chess data
|
||||
fit.gmlm <- gmlm_chess(data_gen)
|
||||
|
||||
|
@ -26,6 +28,7 @@ fit.gmlm <- gmlm_chess(data_gen)
|
|||
################################################################################
|
||||
### Reduction Interpretation and Validation ###
|
||||
################################################################################
|
||||
library(mgcv) # for `gam()` (Generalized Additive Model)
|
||||
|
||||
# load last save point (includes reduction as `betas`)
|
||||
save_point <- sort(list.files(
|
||||
|
@ -43,6 +46,10 @@ sample_size <- 100000
|
|||
fens <- Rchess::data.gen(data_set, sample_size, -20, 20, quiet = TRUE)
|
||||
# extract stockfish (non-static) position evaluation
|
||||
y <- attr(fens, "scores")
|
||||
# remove poitions with exact draw evalualtion
|
||||
draws <- which(y == 0.0)
|
||||
y <- y[-draws]
|
||||
fens <- fens[-draws]
|
||||
|
||||
# Convert position into "One-Hot Encoded" / "Bit Board" tensor
|
||||
X <- Rchess::fen2int(fens)
|
||||
|
@ -54,7 +61,7 @@ reducedX <- Reduce(rbind, Map(function(piece) {
|
|||
mlm(X - as.vector(rowMeans(X, dims = 2)), betas[[piece]], transposed = TRUE)
|
||||
}, 1:12))
|
||||
# Convert memory layout to contain vectorized observations in rows
|
||||
reducedX <- t(`dim<-`(reducedX, c(48, sample_size)))
|
||||
reducedX <- t(`dim<-`(reducedX, c(48, length(y))))
|
||||
# set names for coefficient extraction from linear fit
|
||||
colnames(reducedX) <- as.vector(outer(
|
||||
unlist(strsplit("PNBRQKpnbrqk", "")), c(1, "yl", "yu", "y.2"), paste, sep = "."
|
||||
|
@ -86,15 +93,15 @@ psqt[["P"]][c(1, 8), ] <- 0
|
|||
|
||||
### Validation by GAM fitted on reduced data
|
||||
formula <- as.formula(paste("y ~ ", paste("s(", colnames(reducedX), ")", collapse = "+")))
|
||||
fit.gam <- mgcv::gam(formula, data = data.frame(y = y, reducedX), subset = y != 0.0)
|
||||
fit.gam <- mgcv::gam(formula, data = data.frame(y = y, reducedX))
|
||||
summary(fit.gam)
|
||||
|
||||
# compair estimates with mean as baseline and static human crafted evaluation (HCE)
|
||||
rmse.base <- sqrt(mean((mean(y) - y)^2))
|
||||
(rmse.base <- sqrt(mean((mean(y) - y)^2)))
|
||||
y.hce <- Rchess::HCE(fens)
|
||||
rmse.hce <- sqrt(mean((y.hce - y)^2))
|
||||
(rmse.hce <- sqrt(mean((y.hce - y)^2)))
|
||||
y.hat <- predict(fit.gam, newdata = data.frame(reducedX))
|
||||
rmse.hat <- sqrt(mean((y.hat - y)^2))
|
||||
(rmse.hat <- sqrt(mean((y.hat - y)^2)))
|
||||
|
||||
# Also extract R^2 (eval by hand or get from models)
|
||||
(r.sq.lm <- summary(fit)$r.squared)
|
||||
|
|
|
@ -10,24 +10,55 @@
|
|||
#include "search.h"
|
||||
#include "uci.h"
|
||||
|
||||
static const std::string usage{"usage: pgn2fen [--scored] [<input>]"};
|
||||
static const std::string usage{"usage: pgn2fen [--scored] [--rating <rating>] [--ply <ply>] [<input>]"};
|
||||
|
||||
// Convert PGN (Portable Game Notation) input stream to single FENs
|
||||
// streamed to stdout
|
||||
void pgn2fen(std::istream& input, const bool only_scored) {
|
||||
|
||||
void pgn2fen(
|
||||
std::istream& input,
|
||||
const bool scored,
|
||||
const unsigned long rating,
|
||||
const unsigned long ply
|
||||
) {
|
||||
// Instantiate Boards, the start of every game as well as the current state
|
||||
// of the Board while processing a PGN game
|
||||
Board startpos, pos;
|
||||
|
||||
// Parse white and black ELO ratings
|
||||
unsigned long whiteElo = 0;
|
||||
unsigned long blackElo = 0;
|
||||
|
||||
// Read input line by line
|
||||
std::string line;
|
||||
while (std::getline(input, line)) {
|
||||
// Skip empty and metadata lines (every PGN game starts with "<nr>.")
|
||||
// read rating metadata lines
|
||||
if (rating != static_cast<unsigned long>(-1)) {
|
||||
// [WhiteElo "1111"]
|
||||
// [BlackElo "999"]
|
||||
try {
|
||||
if (line.rfind("[WhiteElo \"", 0) != std::string::npos) {
|
||||
whiteElo = std::stoul(line.substr(11));
|
||||
} else if (line.rfind("[BlackElo \"", 0) != std::string::npos) {
|
||||
blackElo = std::stoul(line.substr(11));
|
||||
}
|
||||
} catch (...) {
|
||||
std::cerr << "ERROR: Parsing player rating metadata '" << line
|
||||
<< "' failed." << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Skip empty and further metadata lines (every PGN game starts with "<nr>.")
|
||||
if (line.empty() || line.front() == '[') {
|
||||
continue;
|
||||
}
|
||||
|
||||
// In case of rating requested, only parse game when rating is detected
|
||||
if (rating != static_cast<unsigned long>(-1)
|
||||
&& (whiteElo < rating || blackElo < rating)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Reset position to the start position, every game starts here!
|
||||
pos = startpos;
|
||||
|
||||
|
@ -36,7 +67,7 @@ void pgn2fen(std::istream& input, const bool only_scored) {
|
|||
std::string count, san, token, eval;
|
||||
while (game >> count >> san >> token) {
|
||||
// Consume/Parse PGN comments
|
||||
if (only_scored) {
|
||||
if (scored) {
|
||||
// consume the comment and search for an evaluation
|
||||
bool has_score = false;
|
||||
while (game >> token) {
|
||||
|
@ -65,21 +96,40 @@ void pgn2fen(std::istream& input, const bool only_scored) {
|
|||
bool parseError = false;
|
||||
Move move = UCI::parseSAN(san, pos, parseError);
|
||||
if (parseError) {
|
||||
std::cerr << "[ Error ] Parsing '" << san << "' at position '"
|
||||
std::cerr << "ERROR: Parsing '" << san << "' at position '"
|
||||
<< pos.fen() << "' failed." << std::endl;
|
||||
break;
|
||||
}
|
||||
move = pos.isLegal(move); // validate legality and extend move info
|
||||
if (move) {
|
||||
pos.make(move);
|
||||
} else {
|
||||
std::cerr << "[ Error ] Encountered illegal move '" << san
|
||||
std::cerr << "ERROR: Encountered illegal move '" << san
|
||||
<< " (" << move
|
||||
<< ") ' at position '" << pos.fen() << "'." << std::endl;
|
||||
break;
|
||||
}
|
||||
|
||||
// Skip positions with too small ply count
|
||||
if (pos.plyCount() < ply) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Write positions
|
||||
if (only_scored) {
|
||||
if (scored && rating != static_cast<unsigned long>(-1)) {
|
||||
// Ingore "check mate in" scores (not relevant for eval training)
|
||||
// Do this after "make move" in situations where the check mate
|
||||
// was overlooked, leading to new positions
|
||||
if (eval.length() && eval[0] == '#') {
|
||||
continue;
|
||||
}
|
||||
// Otherwise, classic eval score to be parsed in centipawns
|
||||
std::cout << pos.fen() << "; " << eval << "; "
|
||||
<< whiteElo << "; " << blackElo << '\n';
|
||||
} else if (rating != static_cast<unsigned long>(-1)) {
|
||||
// Otherwise, classic eval score to be parsed in centipawns
|
||||
std::cout << pos.fen() << "; " << whiteElo << "; " << blackElo << '\n';
|
||||
} else if (scored) {
|
||||
// Ingore "check mate in" scores (not relevant for eval training)
|
||||
// Do this after "make move" in situations where the check mate
|
||||
// was overlooked, leading to new positions
|
||||
|
@ -92,52 +142,75 @@ void pgn2fen(std::istream& input, const bool only_scored) {
|
|||
// Write only the position FEN
|
||||
std::cout << pos.fen() << '\n';
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Reset ELO after every game to ensure that games without an elo
|
||||
// metadata tag don't get the wrong rating from a previous game
|
||||
whiteElo = 0;
|
||||
blackElo = 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int main(int argn, char* argv[]) {
|
||||
// Setup control variables
|
||||
bool only_scored = false;
|
||||
bool scored = false;
|
||||
unsigned long rating = -1;
|
||||
unsigned long ply = 0;
|
||||
// unsigned min_rating = 0;
|
||||
std::string file = "";
|
||||
|
||||
// Parse command arguments
|
||||
switch (argn) {
|
||||
case 1:
|
||||
break;
|
||||
case 2:
|
||||
if (std::string("--scored") == argv[1]) {
|
||||
only_scored = true;
|
||||
for (int i = 1; i < argn; ++i) {
|
||||
if (std::string("--scored") == argv[i]) {
|
||||
scored = true;
|
||||
} else if (std::string("--rating") == argv[i]) {
|
||||
if (i + 1 < argn) {
|
||||
try {
|
||||
rating = std::stoul(argv[++i]);
|
||||
} catch (...) {
|
||||
std::cerr << "ERROR: illegal --rating argument " << argv[i] << std::endl;
|
||||
std::cout << usage << std::endl;
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
file = argv[1];
|
||||
}
|
||||
break;
|
||||
case 3:
|
||||
if (std::string("--scored") != argv[1]) {
|
||||
std::cout << usage << std::endl;
|
||||
return 1;
|
||||
}
|
||||
only_scored = true;
|
||||
file = argv[2];
|
||||
break;
|
||||
default:
|
||||
} else if (std::string("--ply") == argv[i]) {
|
||||
if (i + 1 < argn) {
|
||||
try {
|
||||
ply = std::stoul(argv[++i]);
|
||||
} catch (...) {
|
||||
std::cerr << "ERROR: illegal --ply argument " << argv[i] << std::endl;
|
||||
std::cout << usage << std::endl;
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
std::cout << usage << std::endl;
|
||||
return 1;
|
||||
}
|
||||
} else if (file != "") {
|
||||
file = argv[i];
|
||||
} else {
|
||||
std::cout << usage << std::endl;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Invoke converter ether with file input or stdin
|
||||
if (file == "") {
|
||||
pgn2fen(std::cin, only_scored);
|
||||
pgn2fen(std::cin, scored, rating, ply);
|
||||
} else {
|
||||
// Open input file
|
||||
std::ifstream input(file);
|
||||
if (!input) {
|
||||
std::cerr << "Error opening '" << file << "'" << std::endl;
|
||||
std::cerr << "ERROR: opening '" << file << "' failed" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
pgn2fen(input, only_scored);
|
||||
pgn2fen(input, scored, rating, ply);
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
|
|
@ -4,6 +4,10 @@
|
|||
# in November 2023
|
||||
data=lichess_db_standard_rated_2023-11
|
||||
|
||||
# Minimum "ELO" rating of black and white players
|
||||
min_rating=2000
|
||||
min_ply=20
|
||||
|
||||
# Check if file exists and download iff not
|
||||
if [ -f "${data}.fen" ]; then
|
||||
echo "File '${data}.fen' already exists, assuming job already done."
|
||||
|
@ -21,11 +25,16 @@ else
|
|||
# the PGN data base into a list of FEN strings while filtering only
|
||||
# positions with evaluation. The `--scored` parameter specifies to extract
|
||||
# a position evaluation from the PGN and ONLY write positions with scores.
|
||||
# That is, positions without a score are removed!
|
||||
# That is, positions without a score are removed! Parameter `--rating`
|
||||
# filters games where both players have at least the minimum rating, and
|
||||
# `--ply` specifies to only consider positions from `$min_ply` onward.
|
||||
if [ -f "${data}.pgn.zst" ]; then
|
||||
zstdcat ${data}.pgn.zst | ./pgn2fen --scored > ${data}.fen
|
||||
zstdcat ${data}.pgn.zst \
|
||||
| ./pgn2fen --scored --rating $min_rating --ply $min_ply \
|
||||
> ${data}.fen
|
||||
else
|
||||
wget -qO- https://database.lichess.org/standard/${data}.pgn.zst \
|
||||
| zstdcat | ./pgn2fen --scored > ${data}.fen
|
||||
wget -qO- https://database.lichess.org/standard/${data}.pgn.zst | zstdcat \
|
||||
| ./pgn2fen --scored --rating $min_rating --ply $min_ply \
|
||||
> ${data}.fen
|
||||
fi
|
||||
fi
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Generated by roxygen2: do not edit by hand
|
||||
|
||||
S3method(merge,matmul)
|
||||
export("%<-%")
|
||||
export("%x_1%")
|
||||
export("%x_2%")
|
||||
|
|
Loading…
Reference in New Issue