# tensor_predictors/dataAnalysis/chess/chess.R

library(tensorPredictors)
library(Rchess)
library(mgcv) # for `gam()` (Generalized Additive Model)
source("./gmlm_chess.R")
################################################################################
### Fitting the GMLM mixture model ###
################################################################################
# Data set file name of chess positions with Stockfish [https://stockfishchess.org]
# evaluation scores (downloaded and processed by `./preprocessing.sh` from the
# lichess data base [https://database.lichess.org/])
data_set <- "lichess_db_standard_rated_2023-11.fen"
# Function to draw samples `X` from the chess position `data_set` conditioned on
# `Y` (position scores) to be in the interval `score_min` to `score_max`.
# Note: `data_set` is looked up (in the global environment) at call time.
#
# @param batch_size number of positions to sample
# @param score_min,score_max score interval the sampled positions must lie in
# @return the sampled positions converted by `Rchess::fen2int()`
data_gen <- function(batch_size, score_min, score_max) {
  fens <- Rchess::data.gen(
    data_set, batch_size, score_min, score_max, quiet = TRUE
  )
  Rchess::fen2int(fens)
}
# Invoke specialized GMLM optimization routine for chess data
fit.gmlm <- gmlm_chess(data_gen)
################################################################################
### Reduction Interpretation and Validation ###
################################################################################
# load last save point (includes reduction as `betas`), that is the file
# `save_point_<n>.Rdata` in the working directory with the largest number `<n>`.
save_point <- local({
  files <- list.files(
    ".",
    pattern = "^save_point_[0-9]+\\.Rdata$",
    full.names = TRUE
  )
  if (length(files) == 0L) {
    stop("No save point file `save_point_<n>.Rdata` found in ", getwd(),
      call. = FALSE)
  }
  # compare the save point numbers numerically, a lexicographic sort puts
  # e.g. "save_point_9" after "save_point_10" unless numbers are zero padded
  numbers <- as.numeric(
    sub("^save_point_([0-9]+)\\.Rdata$", "\\1", basename(files))
  )
  files[[which.max(numbers)]]
})
load(save_point)
### Construct PSQT (Piece SQuare Tables) from reduction `betas`
sample_size <- 100000
# Sample a new position data set for fitting a linear model to combine different
# reduction directions into a per piece PSQT matrix
fens <- Rchess::data.gen(data_set, sample_size, -20, 20, quiet = TRUE)
# extract stockfish (non-static) position evaluation
y <- attr(fens, "scores")
# Convert position into "One-Hot Encoded" / "Bit Board" tensor
X <- Rchess::fen2int(fens)
# Compute reduction: every one of the 12 piece channels (mixture components) of
# `X` is centered by its mean over the observations and then reduced with its
# own reduction `betas[[piece]]`. The per piece results are stacked row wise.
reducedX <- do.call(rbind, lapply(seq_len(12), function(piece) {
  # "condition" on piece, that is, extract the current mixture component
  component <- X[, , piece, ]
  # subtract the mean over the trailing (observation) axis, recycled over it
  centered <- component - as.vector(rowMeans(component, dims = 2))
  # reduce mixture component
  mlm(centered, betas[[piece]], transposed = TRUE)
}))
# Convert memory layout to contain vectorized observations in rows
# (one row per observation, 48 = 12 pieces x 4 reduced components)
dim(reducedX) <- c(48, sample_size)
reducedX <- t(reducedX)
# set names for coefficient extraction from linear fit; the piece symbol varies
# fastest, followed by the reduced component suffix
colnames(reducedX) <- as.vector(outer(
  strsplit("PNBRQKpnbrqk", "")[[1]],
  c(1, "yl", "yu", "y.2"),
  FUN = paste,
  sep = "."
))
# Estimate PSQT linear combination weights from reduced sample (exclude dead
# draw positions, that is "score = 0". These are approx 5% of all positions)
fit <- lm(y ~ ., data = data.frame(y = y, reducedX), subset = y != 0.0)
summary(fit)
# Translate reduction with weighting estimate into PSQTs, giving a list of
# 8 x 8 tables named by the (upper case) piece letters "P", "N", ..., "K"
psqt <- Map(function(piece) {
  # PSQT of a single colored piece `symbol`: the Kronecker product of its
  # reduction matrices (in reversed order) weighted by the linear fit
  # coefficients of its reduced components
  colored_psqt <- function(symbol) {
    weights <- coef(fit)[paste(symbol, c(1, "yl", "yu", "y.2"), sep = ".")]
    tbl <- do.call(kronecker, rev(betas[[symbol]])) %*% weights
    dim(tbl) <- c(8, 8)
    tbl
  }
  # Whites PSQT (upper case symbol) and the same for black (lower case)
  white <- colored_psqt(toupper(piece))
  black <- colored_psqt(tolower(piece))
  # Combine into shared PSQT from whites point of view, the black table is
  # flipped along its first (row) dimension
  white - black[8:1, ]
}, c("P", "N", "B", "R", "Q", "K"))
# finish by enforcing the pawn constraint (irrelevant for validation, the
# corresponding values in an encoded position is always zero)
psqt[["P"]][c(1, 8), ] <- 0
### Validation by GAM fitted on reduced data
# One smooth term per reduced component: `y ~ s(P.1) + s(P.yl) + ...`
formula <- reformulate(paste0("s(", colnames(reducedX), ")"), response = "y")
fit.gam <- mgcv::gam(
  formula,
  data = data.frame(y = y, reducedX),
  subset = y != 0.0
)
summary(fit.gam)
# compare estimates with mean as baseline and static human crafted evaluation (HCE)
rmse.base <- sqrt(mean((mean(y) - y)^2))
y.hce <- Rchess::HCE(fens)
rmse.hce <- sqrt(mean((y.hce - y)^2))
y.hat <- predict(fit.gam, newdata = data.frame(reducedX))
rmse.hat <- sqrt(mean((y.hat - y)^2))
# Also extract R^2 (eval by hand or get from models)
# NOTE(review): the RMSEs above and `r.sq.hce` use all positions, while the
# lm/gam R^2 refer to the `y != 0` subset the models were fitted on, hence the
# numbers are not strictly comparable.
(r.sq.lm <- summary(fit)$r.squared)
(r.sq.gam <- summary(fit.gam)$r.sq)
(r.sq.hce <- 1 - (rmse.hce / rmse.base)^2)
################################################################################
### Generate LaTeX PSQT plot ###
################################################################################
# Writes the standalone LaTeX document `psqt.tex` drawing the six PSQTs of
# `psqt` (in the order pawn, knight, bishop, rook, queen, king) as chess boards
# with colored squares. All boards share one color scale, symmetric around zero.
# Disabled by default, change the condition to `TRUE` to (re)generate the file.
if (FALSE) {
  local({
    # number of palette colors, the color scale has `n_colors + 1` breaks
    n_colors <- 128L

    sink("psqt.tex")
    # restore console output even if writing the document fails midway
    tryCatch({
      cat("% Authomatically generated by `dataAnalysis/chess.R`
\\documentclass{standalone}

\\usepackage[LSB, T1]{fontenc}
\\usepackage{chessboard}
\\usepackage{skak}
\\usepackage{tikz}
\\usepackage{amsmath}
\\usepackage{xcolor}

\\setboardfontencoding{LSB}
\\setchessboard{linewidth = 0.1em, showmover = false, smallboard}
")
      # define the colors `col1`, ..., `col<n_colors>` (hex codes without `#`)
      hex_colors <- sub("^#", "", hcl.colors(n_colors, "Blue-Red 3", rev = TRUE))
      cat(
        paste0("\\definecolor{col", seq_len(n_colors), "}{HTML}{", hex_colors, "}"),
        sep = "\n"
      )

      cat("

\\begin{document}
\\begin{tikzpicture}

\\coordinate (pawn) at (0, 0);
\\coordinate (knight) at (5, 0);
\\coordinate (bishop) at (10, 0);
\\coordinate (rook) at (0, -5.2);
\\coordinate (queen) at (5, -5.2);
\\coordinate (king) at (10, -5.2);

")

      # color scale symmetric around zero, shared by all boards
      zlim <- c(-1, 1) * max(abs(unlist(psqt, use.names = FALSE)))
      breaks <- seq(zlim[1], zlim[2], length.out = n_colors + 1L)
      pieces <- c("pawn", "knight", "bishop", "rook", "queen", "king")
      # square names in the (column-major) element order of an 8 x 8 PSQT,
      # rows are ranks 8 down to 1 and columns are files a to h
      squares <- outer(8:1, letters[1:8], function(r, f) paste0(f, r))
      for (i in seq_along(psqt)) {
        # `include.lowest = TRUE` puts the minimum (which can equal `zlim[1]`)
        # into the first bin, otherwise `cut()` yields `NA` (undefined `colNA`)
        color_index <- as.integer(cut(psqt[[i]], breaks, include.lowest = TRUE))
        board_opts <- paste0(
          "color=col", color_index, ",colorbackfield=", squares,
          collapse = ","
        )
        cat(paste0(
          "\\node (", pieces[i], ") at (", pieces[i], ") {\\chessboard[",
          board_opts, "]};\n"
        ))
      }

      cat("
\\node[anchor = north, yshift = -0.4em] at (pawn.north) {Pawn};
\\node[anchor = north, yshift = -0.4em] at (knight.north) {Knight};
\\node[anchor = north, yshift = -0.4em] at (bishop.north) {Bishop};
\\node[anchor = north, yshift = -0.4em] at (rook.north) {Rook};
\\node[anchor = north, yshift = -0.4em] at (queen.north) {Queen};
\\node[anchor = north, yshift = -0.4em] at (king.north) {King};

\\end{tikzpicture}
\\end{document}
")
    }, finally = sink())
  })
}