2023-12-13 17:46:17 +00:00
|
|
|
library(tensorPredictors)
|
|
|
|
library(Rchess)
|
2023-12-30 09:01:12 +00:00
|
|
|
library(mgcv) # for `gam()` (Generalized Additive Model)
|
2023-12-13 17:46:17 +00:00
|
|
|
|
|
|
|
source("./gmlm_chess.R")
|
|
|
|
|
2023-12-30 09:01:12 +00:00
|
|
|
################################################################################
|
|
|
|
### Fitting the GMLM mixture model ###
|
|
|
|
################################################################################
|
|
|
|
|
2023-12-13 17:46:17 +00:00
|
|
|
# Data set file name of chess positions with Stockfish [https://stockfishchess.org]
# evaluation scores (downloaded and processed by `./preprocessing.sh` from the
# lichess data base [https://database.lichess.org/])
data_set <- "lichess_db_standard_rated_2023-11.fen"

# Function to draw samples `X` from the chess position data set conditioned on
# `Y` (position scores) to be in the interval `score_min` to `score_max`.
#
# @param batch_size Number of positions to draw.
# @param score_min,score_max Score interval the sampled positions must lie in
#   (passed unchanged to `Rchess::data.gen()`).
# @param file FEN data file to sample from. Defaults to the global `data_set`,
#   looked up when the function is called (as before), but can now be
#   overridden without touching global state.
# @return The sampled positions converted by `Rchess::fen2int()`.
data_gen <- function(batch_size, score_min, score_max, file = data_set) {
  fens <- Rchess::data.gen(file, batch_size, score_min, score_max, quiet = TRUE)
  Rchess::fen2int(fens)
}

# Invoke specialized GMLM optimization routine for chess data
fit.gmlm <- gmlm_chess(data_gen)
|
2023-12-13 17:46:17 +00:00
|
|
|
|
|
|
|
|
2023-12-19 13:13:49 +00:00
|
|
|
################################################################################
|
2023-12-30 09:01:12 +00:00
|
|
|
### Reduction Interpretation and Validation ###
|
2023-12-19 13:13:49 +00:00
|
|
|
################################################################################
|
|
|
|
|
2023-12-30 09:01:12 +00:00
|
|
|
# Load the most recent save point (includes the reduction as `betas`);
# presumably written by `gmlm_chess()` while fitting -- confirm.
#
# Save points are ordered by the numeric index in their file name rather than
# lexicographically, so `save_point_10.Rdata` is correctly taken as newer than
# `save_point_9.Rdata` even without zero padding (for zero-padded names both
# orders agree). The pattern is anchored so files like `*.Rdata.bak` or
# `old_save_point_1.Rdata` are not picked up.
save_point <- local({
  pattern <- "^save_point_([0-9]+)\\.Rdata$"
  files <- list.files(".", pattern = pattern, full.names = TRUE)
  if (length(files) == 0L) {
    stop("No save point matching 'save_point_<n>.Rdata' found in '",
         getwd(), "'", call. = FALSE)
  }
  index <- as.numeric(sub(pattern, "\\1", basename(files)))
  files[[which.max(index)]]
})
# `load()` is called at top level so the saved objects land in the global env
load(save_point)
|
|
|
|
|
|
|
|
|
|
|
|
### Construct PSQT (Piece SQuare Tables) from reduction `betas`
sample_size <- 100000

# Draw a fresh sample of positions. It is used to fit a linear model that
# combines the different reduction directions into one PSQT matrix per piece.
fens <- Rchess::data.gen(data_set, sample_size, -20, 20, quiet = TRUE)

# Stockfish (non-static) evaluation of every sampled position
y <- attr(fens, "scores")

# Encode the positions as a "One-Hot Encoded" / "Bit Board" tensor
X <- Rchess::fen2int(fens)

# Reduce each of the 12 mixture components (indexed by the third axis of `X`)
# with its own reduction `betas[[piece]]`, after centering the component by
# its mean over all observations
reducedX <- do.call(rbind, lapply(1:12, function(piece) {
  # "condition" on the piece, i.e. select the current mixture component
  component <- X[, , piece, ]
  # the 8 x 8 mean is recycled over every observation of the component
  centered <- component - as.vector(rowMeans(component, dims = 2))
  mlm(centered, betas[[piece]], transposed = TRUE)
}))

# Re-interpret the memory layout so that each row holds one vectorized
# observation. NOTE(review): 48 = 12 pieces x 4 reduced components per piece,
# matching the 4 component names used below -- assumes every reduction yields
# exactly 4 values per observation.
dim(reducedX) <- c(48, sample_size)
reducedX <- t(reducedX)

# Name the columns "<piece>.<component>" (piece varying fastest) so the
# coefficients of the linear fit can be looked up by name
colnames(reducedX) <- as.vector(outer(
  strsplit("PNBRQKpnbrqk", "")[[1]],
  c(1, "yl", "yu", "y.2"),
  paste,
  sep = "."
))
|
|
|
|
|
|
|
|
# Estimate the PSQT linear combination weights from the reduced sample,
# excluding dead draw positions, that is "score = 0" (these are approx 5% of
# all positions)
fit <- lm(y ~ ., data = data.frame(y = y, reducedX), subset = y != 0.0)
summary(fit)

# Translate the reduction, weighted by the fitted estimates, into PSQTs
psqt <- local({
  components <- c(1, "yl", "yu", "y.2")

  # Map the reduction of a single piece (given by its FEN letter) back onto an
  # 8 x 8 board, weighting the reduced components by their fitted coefficients.
  # NOTE(review): `betas` is indexed by piece letter here but by position
  # (1:12) when reducing the data -- assumes `betas` is named by the letters of
  # "PNBRQKpnbrqk" in that order; confirm.
  board_of <- function(letter) {
    weights <- coef(fit)[paste(letter, components, sep = ".")]
    board <- do.call(kronecker, rev(betas[[letter]])) %*% weights
    dim(board) <- c(8, 8)
    board
  }

  white_pieces <- c("P", "N", "B", "R", "Q", "K")
  names(white_pieces) <- white_pieces

  # One shared PSQT per piece type from White's point of view: the table of
  # the corresponding black piece has its rows reversed and is subtracted
  lapply(white_pieces, function(letter) {
    board_of(toupper(letter)) - board_of(tolower(letter))[8:1, ]
  })
})

# Finish by enforcing the pawn constraint (irrelevant for validation, the
# corresponding values in an encoded position are always zero)
psqt[["P"]][c(1, 8), ] <- 0
|
|
|
|
|
|
|
|
### Validation by GAM fitted on reduced data
# One smooth term per reduced component: y ~ s(P.1) + s(P.yl) + ...
formula <- reformulate(paste0("s(", colnames(reducedX), ")"), response = "y")
fit.gam <- mgcv::gam(formula, data = data.frame(y = y, reducedX), subset = y != 0.0)
summary(fit.gam)

# Compare estimates with the mean as baseline and with the static human crafted
# evaluation (HCE). All RMSEs are taken over the full sample, including the
# draw positions (`y == 0`) that were excluded when fitting `fit` / `fit.gam`.
rmse.base <- sqrt(mean((mean(y) - y)^2))
y.hce <- Rchess::HCE(fens)
rmse.hce <- sqrt(mean((y.hce - y)^2))
y.hat <- predict(fit.gam, newdata = data.frame(reducedX))
rmse.hat <- sqrt(mean((y.hat - y)^2))

# R^2 as reported by the models. These are in-sample on the `y != 0` subset;
# mgcv's `r.sq` is the adjusted R^2.
(r.sq.lm <- summary(fit)$r.squared)
(r.sq.gam <- summary(fit.gam)$r.sq)
# R^2 by hand, relative to the mean baseline over the full sample. This puts
# the HCE and the GAM prediction on exactly the same footing.
(r.sq.hce <- 1 - (rmse.hce / rmse.base)^2)
(r.sq.hat <- 1 - (rmse.hat / rmse.base)^2)
|
2023-12-19 13:13:49 +00:00
|
|
|
|
|
|
|
|
2023-12-30 09:01:12 +00:00
|
|
|
################################################################################
|
|
|
|
### Generate LaTeX PSQT plot ###
|
|
|
|
################################################################################
|
|
|
|
if (FALSE) {
  local({
    # Redirect all `cat()` output below into `psqt.tex`. `on.exit()` removes
    # the sink even if an error occurs, so the console is never left
    # redirected.
    sink("psqt.tex")
    on.exit(sink(), add = TRUE)

    # LaTeX preamble
    cat("% Authomatically generated by `dataAnalysis/chess.R`
\\documentclass{standalone}

\\usepackage[LSB, T1]{fontenc}
\\usepackage{chessboard}
\\usepackage{skak}
\\usepackage{tikz}
\\usepackage{amsmath}
\\usepackage{xcolor}

\\setboardfontencoding{LSB}
\\setchessboard{linewidth = 0.1em, showmover = false, smallboard}

")

    # Color palette `col1` (lowest bin) to `col<n_colors>` (highest bin),
    # given as hex codes without the leading "#"
    n_colors <- 128L
    hex_colors <- sub("^#", "", hcl.colors(n_colors, "Blue-Red 3", rev = TRUE))
    cat(paste0("\\definecolor{col", seq_len(n_colors), "}{HTML}{",
               hex_colors, "}"))

    cat("

\\begin{document}
\\begin{tikzpicture}

\\coordinate (pawn) at (0, 0);
\\coordinate (knight) at (5, 0);
\\coordinate (bishop) at (10, 0);
\\coordinate (rook) at (0, -5.2);
\\coordinate (queen) at (5, -5.2);
\\coordinate (king) at (10, -5.2);

")

    # Color scale symmetric around zero, shared by all pieces
    zlim <- c(-1, 1) * max(abs(unlist(psqt, use.names = FALSE)))
    if (!isTRUE(zlim[2] > 0)) {
      stop("PSQT values are all zero or missing, cannot build a color scale",
           call. = FALSE)
    }
    breaks <- seq(zlim[1], zlim[2], length.out = n_colors + 1L)

    # Square names in the layout of the 8 x 8 PSQT matrices (column-major like
    # the `cut()` output): row `i` is rank `9 - i`, column `j` is file
    # `letters[j]`
    squares <- outer(8:1, letters[1:8], function(r, f) paste0(f, r))

    pieces <- c("pawn", "knight", "bishop", "rook", "queen", "king")
    for (i in seq_along(psqt)) {
      # `include.lowest = TRUE` is required: the most negative value can equal
      # `breaks[1]` exactly, which the default right-closed bins `(a, b]` map
      # to `NA`, producing an undefined color `colNA`
      bins <- as.integer(cut(psqt[[i]], breaks, include.lowest = TRUE))
      cat(paste0(
        "\\node (", pieces[i], ") at (", pieces[i], ") {\\chessboard[",
        paste0("color=col", bins, ",colorbackfield=", squares, collapse = ","),
        "]};\n"
      ))
    }

    cat("
\\node[anchor = north, yshift = -0.4em] at (pawn.north) {Pawn};
\\node[anchor = north, yshift = -0.4em] at (knight.north) {Knight};
\\node[anchor = north, yshift = -0.4em] at (bishop.north) {Bishop};
\\node[anchor = north, yshift = -0.4em] at (rook.north) {Rook};
\\node[anchor = north, yshift = -0.4em] at (queen.north) {Queen};
\\node[anchor = north, yshift = -0.4em] at (king.north) {King};

\\end{tikzpicture}
\\end{document}
")
  })
}
|