183 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
			
		
		
	
	
			183 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
library(tensorPredictors)
 | 
						|
library(Rchess)
 | 
						|
 | 
						|
source("./gmlm_chess.R")
 | 
						|
 | 
						|
################################################################################
 | 
						|
###                      Fitting the GMLM mixture model                      ###
 | 
						|
################################################################################
 | 
						|
 | 
						|
# File name of the chess position data set with Stockfish
# [https://stockfishchess.org] evaluation scores (downloaded and processed by
# `./preprocessing.sh` from the lichess data base
# [https://database.lichess.org/])
data_set <- "lichess_db_standard_rated_2023-11.fen"

# Draws a sample of `batch_size` positions `X` from the chess position
# `data_set` conditioned on the position scores `Y` to lie in the interval
# from `score_min` to `score_max`.
#
# Returns the integer encoded positions (as produced by `Rchess::fen2int()`)
# with the corresponding scores attached as attribute `"scores"`.
data_gen <- function(batch_size, score_min, score_max) {
    batch <- Rchess::data.gen(
        data_set, batch_size, score_min, score_max,
        quiet = TRUE
    )
    encoded <- Rchess::fen2int(batch$fens)
    attr(encoded, "scores") <- batch$scores
    encoded
}


# Run the GMLM optimization routine specialized for chess data (from
# `./gmlm_chess.R`), which pulls its training data through `data_gen`
fit.gmlm <- gmlm_chess(data_gen)
 | 
						|
 | 
						|
 | 
						|
################################################################################
 | 
						|
###                 Reduction Interpretation and Validation                  ###
 | 
						|
################################################################################
 | 
						|
library(mgcv)       # for `gam()` (Generalized Additive Model)
 | 
						|
 | 
						|
# Load the last save point (includes the reduction as `betas`). The "last" save
# point is the first file name in decreasing lexical order.
# NOTE(review): lexical order only agrees with numeric order if the counter in
# the file name is zero padded - confirm against the code writing save points.
save_point <- list.files(
    ".",
    pattern = "save_point_[0-9]*\\.Rdata",
    full.names = TRUE
)
# Fail with an explicit message instead of an opaque "subscript out of bounds"
# error when no save point has been written (yet)
if (length(save_point) == 0L) {
    stop(
        "No save point file matching 'save_point_[0-9]*\\.Rdata' found in '",
        getwd(), "'",
        call. = FALSE
    )
}
save_point <- sort(save_point, decreasing = TRUE)[[1]]
load(save_point)
 | 
						|
 | 
						|
 | 
						|
### Construct PSQT (Piece SQuare Tables) from reduction `betas`
sample_size <- 100000
# Sample a new position data set for fitting a linear model to combine
# different reduction directions into a per piece PSQT matrix
fens <- Rchess::data.gen(data_set, sample_size, -20, 20, quiet = TRUE)
# Extract the Stockfish (non-static) position evaluation
# NOTE(review): `data_gen()` above reads the result of `Rchess::data.gen()` as
# a list (`$fens`, `$scores`) while here the scores are read as an attribute -
# confirm which access matches the installed `Rchess` version.
y <- attr(fens, "scores")
# Remove positions with exact draw evaluation. The removal is guarded because
# negative indexing with an empty index (`y[-integer(0)]`) selects nothing and
# would silently drop the entire sample if there is no exact draw.
draws <- which(y == 0.0)
if (length(draws) > 0L) {
    y <- y[-draws]
    fens <- fens[-draws]
}

# Convert positions into "One-Hot Encoded" / "Bit Board" tensor
X <- Rchess::fen2int(fens)
 | 
						|
# Compute the reduction, one mixture component (piece) at a time
reducedX <- do.call(rbind, lapply(1:12, function(piece) {
    # "condition" on the piece, that is extract the current mixture component
    component <- X[, , piece, ]
    # center the component at its mean over the observations (last axis)
    centered <- component - as.vector(rowMeans(component, dims = 2))
    # reduce the centered mixture component
    mlm(centered, betas[[piece]], transposed = TRUE)
}))
# Change the memory layout to hold the vectorized observations in rows
# (48 reduction components per observation)
reducedX <- t(`dim<-`(reducedX, c(48, length(y))))
# Name the columns "<piece>.<term>" (piece index varies fastest) for the
# coefficient extraction from the linear fit
colnames(reducedX) <- paste(
    rep(strsplit("PNBRQKpnbrqk", "")[[1]], times = 4),
    rep(c("1", "yl", "yu", "y.2"), each = 12),
    sep = "."
)
 | 
						|
 | 
						|
# Estimate the PSQT linear combination weights from the reduced sample (dead
# draw positions, that is "score = 0", are excluded. These are approx. 5% of all
# positions)
fit <- lm(y ~ ., data = data.frame(y = y, reducedX), subset = y != 0.0)
summary(fit)
# Translate the reduction with the estimated weights into PSQTs, one 8 x 8 table
# per piece type
psqt <- lapply(setNames(nm = c("P", "N", "B", "R", "Q", "K")), function(piece) {
    # PSQT of a single colored piece: the Kronecker product of the (reversed)
    # reduction matrices applied to the fitted weights of the columns belonging
    # to `label`, reshaped into an 8 x 8 board
    colored_psqt <- function(label) {
        col_names <- paste(label, c(1, "yl", "yu", "y.2"), sep = ".")
        weights <- coef(fit)[col_names]
        `dim<-`(do.call(kronecker, rev(betas[[label]])) %*% weights, c(8, 8))
    }
    # white pieces are labeled in upper case, black pieces in lower case
    white <- colored_psqt(toupper(piece))
    black <- colored_psqt(tolower(piece))
    # Combine into a shared PSQT from white's point of view (black's table is
    # flipped along the first board axis)
    white - black[8:1, ]
})
# Finish by enforcing the pawn constraint (irrelevant for validation, the
# corresponding values in an encoded position are always zero)
psqt[["P"]][c(1, 8), ] <- 0
 | 
						|
 | 
						|
### Validation by a GAM fitted on the reduced data
# One smooth term per reduction component
formula <- as.formula(paste(
    "y ~ ",
    paste("s(", colnames(reducedX), ")", collapse = "+")
))
fit.gam <- mgcv::gam(formula, data = data.frame(y = y, reducedX))
summary(fit.gam)

# Estimates to compare: the static human crafted evaluation (HCE) and the GAM
# prediction, with the mean of `y` serving as baseline
y.hce <- Rchess::HCE(fens)
y.hat <- predict(fit.gam, newdata = data.frame(reducedX))

# Root mean squared errors (printed): baseline, HCE and GAM
(rmse.base <- sqrt(mean((y - mean(y))^2)))
(rmse.hce <- sqrt(mean((y - y.hce)^2)))
(rmse.hat <- sqrt(mean((y - y.hat)^2)))

# Also extract R^2 (taken from the models where available, otherwise computed
# by hand relative to the baseline)
(r.sq.lm <- summary(fit)$r.squared)
(r.sq.gam <- summary(fit.gam)$r.sq)
(r.sq.hce <- 1 - (rmse.hce / rmse.base)^2)
 | 
						|
 | 
						|
 | 
						|
################################################################################
 | 
						|
###                         Generate LaTeX PSQT plot                         ###
 | 
						|
################################################################################
 | 
						|
# Writes a standalone LaTeX/TikZ document `psqt.tex` drawing one colored chess
# board per piece type, with every square colored by its PSQT value on a
# symmetric diverging 128 color scale. Disabled by default; flip the guard to
# regenerate the figure.
if (FALSE) {

# Divert all `cat()` output into the LaTeX file; `finally` restores the
# console output even if generating the document fails half way through.
sink("psqt.tex")
tryCatch({

cat("% Authomatically generated by `dataAnalysis/chess.R`
\\documentclass{standalone}

\\usepackage[LSB, T1]{fontenc}
\\usepackage{chessboard}
\\usepackage{skak}
\\usepackage{tikz}
\\usepackage{amsmath}
\\usepackage{xcolor}

\\setboardfontencoding{LSB}

\\setchessboard{linewidth = 0.1em, showmover = false, smallboard}

")

# Define the colors `col1` to `col128` from the hex codes of the palette (with
# the leading "#" stripped)
cat(paste0("\\definecolor{col", 1:128, "}{HTML}{",
    sub("^#", "", hcl.colors(128, "Blue-Red 3", rev = TRUE)),
    "}"
))

cat("

\\begin{document}
\\begin{tikzpicture}

\\coordinate (pawn)   at (0, 0);
\\coordinate (knight) at (5, 0);
\\coordinate (bishop) at (10, 0);
\\coordinate (rook)   at (0, -5.2);
\\coordinate (queen)  at (5, -5.2);
\\coordinate (king)   at (10, -5.2);

")

local({
    # Symmetric value range around zero shared by all boards, cut into 128 bins
    zlim <- c(-1, 1) * max(abs(unlist(psqt, use.names = FALSE)))
    breaks <- seq(zlim[1], zlim[2], length.out = 129)

    pieces <- c("pawn", "knight", "bishop", "rook", "queen", "king")
    for (i in seq_along(psqt)) {
        # `include.lowest = TRUE` closes the first bin on the left. Without it
        # a value equal to `zlim[1]` (the largest magnitude entry, if negative)
        # maps to `NA` and emits the undefined color `colNA`.
        color_index <- as.integer(cut(psqt[[i]], breaks, include.lowest = TRUE))
        cat(paste0("\\node (", pieces[i], ") at (", pieces[i], ") {\\chessboard[", paste0(
            "color=col", color_index,
            ",colorbackfield=", outer(8:1, letters[1:8], function(r, f) paste0(f, r)),
            collapse=","
        ), "]};\n"))
    }
})

cat("
\\node[anchor = north, yshift = -0.4em] at (pawn.north)   {Pawn};
\\node[anchor = north, yshift = -0.4em] at (knight.north) {Knight};
\\node[anchor = north, yshift = -0.4em] at (bishop.north) {Bishop};
\\node[anchor = north, yshift = -0.4em] at (rook.north)   {Rook};
\\node[anchor = north, yshift = -0.4em] at (queen.north)  {Queen};
\\node[anchor = north, yshift = -0.4em] at (king.north)   {King};

\\end{tikzpicture}
\\end{document}
")

}, finally = sink())

}
 |