tensor_predictors/dataAnalysis/chess/gmlm_chess.R

248 lines
10 KiB
R
Raw Normal View History

2023-12-11 13:29:51 +00:00
#' Specialized version of `gmlm_ising()`.
#'
#' Theoretically equivalent to `gmlm_ising()`, except that it uses a stochastic
#' gradient descent version of RMSprop instead of classic gradient descent.
#' Other differences are purely of technical nature.
#'
#' @param data_gen data generator, samples from the data set conditioned on a
#'  slice value `y.min` to `y.max`. Function signature
#'  `function(batch.size, y.min, y.max)` with return value `X`, a
#'  `8 x 8 x 12 x batch.size` 4D array.
#' @param fun_y known functions of scalar `y`, returning a 3D/4D tensor
#' @param score_breaks numeric vector of two or more unique cut points, the cut
#'  points are the interval bounds specifying the slices of `y`.
#' @param nr_threads integer, nr. of threads used by `ising_m2()`
#' @param mcmc_samples integer, nr. of Monte-Carlo Chains passed to `ising_m2()`
#' @param slice_size integer, size of sub-samples generated by `data_gen` for
#'  every slice. The batch size for every iteration is then equal to
#'  `slice_size * (length(score_breaks) - 1L)`.
#' @param max_iter maximum number of iterations for gradient optimization
#' @param patience integer, break condition parameter. If the approximated loss
#'  doesn't improve over `patience` iterations, then stop.
#' @param step_size numeric, meta parameter for RMSprop for gradient scaling
#' @param eps numeric, meta parameter for RMSprop avoiding division by zero in
#'  the parameter update rule of RMSprop
#' @param save_point character, file name pattern (a `sprintf()` format with
#'  one `%s`) for storing and retrieving optimization save points. Those save
#'  points allow to stop the method and resume optimization later from the
#'  last save point. The pattern may contain a directory part. Any
#'  non-character value disables save points.
#'
#' @return list with the components `betas` and `Omegas` (each a list with one
#'  entry per piece) carrying the attributes `iter` (value of the iteration
#'  counter when the optimization loop was left) and `loss` (the last
#'  approximated negative log-likelihood; if no iteration was performed, the
#'  last loss of the restored state, or `Inf` for a fresh start).
#'
gmlm_chess <- function(
    data_gen,
    fun_y = function(y) { `dim<-`(t(outer(y, c(0, 1, 1, 2), `^`)), c(2, 2, length(y))) },
    score_breaks = c(-5.0, -3.0, -2.0, -1.0, -0.5, -0.2, 0.2, 0.5, 1.0, 2.0, 3.0, 5.0),
    nr_threads = 8L,
    mcmc_samples = 10000L,
    slice_size = 512L,
    max_iter = 10000L,
    patience = 10L,
    step_size = 1e-2,
    eps = sqrt(.Machine$double.eps),
    save_point = "save_point_%s.Rdata"
) {
    # build intervals from score break points
    score_breaks <- sort(score_breaks)
    score_min <- head(score_breaks, -1)
    score_max <- tail(score_breaks, -1)
    score_means <- (score_min + score_max) / 2

    # Piece index lookup "table" by piece symbol
    pieces <- `names<-`(1:12, unlist(strsplit("PNBRQKpnbrqk", "")))

    # Build constraints for every mixture component, that is, for every piece
    pawn_const <- which(as.logical(tcrossprod(.row(c(8, 8)) %in% c(1, 8))))
    # King and Queen constraints (for queens it's just an approximation)
    KQ_const <- which(!diag(64))

    # Check if there is a save point (load from save). Save points are written
    # to `sprintf(save_point, ...)`, which may contain a directory part, so
    # search that directory using the file name part of the pattern only
    # (`list.files()` matches its `pattern` against file names, not paths).
    load_point <- if (is.character(save_point)) {
        save_dir <- dirname(save_point)
        found <- sort(
            list.files(save_dir, pattern = sprintf(basename(save_point), ".*")),
            decreasing = TRUE
        )
        if (save_dir == ".") found else file.path(save_dir, found)
    } else {
        character(0)
    }

    # If a load point is found, resume from save point, otherwise initialize
    if (length(load_point)) {
        load_point <- load_point[[1]]
        cat(sprintf("Resuming from save point '%s'\n", load_point),
            "(to restart delete/rename the save points)\n")
        # restores `dimX`, `dimF`, `betas`, `Omegas`, `grad2_betas`,
        # `grad2_Omegas`, `last_loss`, `best_loss`, `non_improving` and `iter`
        # into this function's environment
        load(load_point)
        # Fix `iter`, saved after increment
        iter <- iter - 1L
    } else {
        # draw initial sample to be passed to the normal GMLM estimator for initial `betas`
        X <- Reduce(c, Map(data_gen, slice_size, score_min, score_max))
        dim(X) <- c(8L, 8L, 12L, slice_size * length(score_means))
        F <- fun_y(rep(score_means, each = slice_size))

        # set object dimensions (`dimX` is constant, `dimF` depends on `fun_y` arg)
        dimX <- c(8L, 8L)   # for every mixture component
        dimF <- dim(F)[1:2] # also per mixture component

        # Initialize `betas` for every mixture component
        betas <- Map(function(piece) {
            gmlm_tensor_normal(X[, , piece, ], F)$betas
        }, pieces)

        # and initial values for `Omegas`, based on the same first "big" sample
        Omegas <- Map(function(piece) {
            X <- X[, , piece, ]
            Map(function(mode) {
                n <- prod(dim(X)[-mode])
                prob2 <- mcrossprod(X, mode = mode) / n
                # keep the estimates away from 0 and 1 (finite logarithms below)
                prob2[prob2 == 0] <- 1 / n
                prob2[prob2 == 1] <- (n - 1) / n
                prob1 <- diag(prob2)
                `prob1^2` <- outer(prob1, prob1)
                `diag<-`(log(((1 - `prob1^2`) / `prob1^2`) * prob2 / (1 - prob2)), 0)
            }, 1:2)
        }, pieces)
        Omegas[[pieces["P"]]][[1]][c(1, 8), ] <- 0
        Omegas[[pieces["p"]]][[1]][c(1, 8), ] <- 0

        # Initial sample `(X, F)` no longer needed, remove them
        rm(X, F)

        # Initialize gradients and aggregated mean squared gradients
        grad2_betas <- Map(function(params) Map(array, 0, Map(dim, params)), betas)
        grad2_Omegas <- Map(function(params) Map(array, 0, Map(dim, params)), Omegas)

        # initialize optimization tracker for break condition
        last_loss <- best_loss <- Inf
        non_improving <- 0L
        iter <- 0L
    }

    # `loss` is reported in the return value. Seed it with the last known loss
    # such that it is defined even if the loop body is never executed (e.g.
    # when resuming from a save point which already reached `max_iter`).
    loss <- last_loss

    # main optimization loop
    while ((iter <- iter + 1L) <= max_iter) {
        # At beginning of every iteration, store current state in a save point.
        # This allows to resume optimization from the last save point.
        if (is.character(save_point)) {
            suspendInterrupts(save(
                dimX, dimF,
                betas, Omegas,
                grad2_betas, grad2_Omegas,
                last_loss, best_loss, non_improving, iter,
                file = sprintf(save_point, sprintf("%06d", iter - 1L))))
        }

        # start timing for this iteration (this is precise enough)
        start_time <- proc.time()[["elapsed"]]

        # Full Omega(s) for every piece mixture component
        Omega <- Map(function(Omegas) {
            kronecker(Omegas[[2]], Omegas[[1]])
        }, Omegas)
        Omega[[pieces["P"]]][pawn_const] <- 0
        Omega[[pieces["p"]]][pawn_const] <- 0
        Omega[[pieces["K"]]][KQ_const] <- 0
        Omega[[pieces["k"]]][KQ_const] <- 0
        Omega[[pieces["Q"]]][KQ_const] <- 0
        Omega[[pieces["q"]]][KQ_const] <- 0

        # Gradient and negative log-likelihood approximation
        loss <- 0 # neg. log-likelihood
        grad_betas <- Map(function(piece) Map(matrix, 0, dimX, dimF), pieces) # grads for betas
        R2 <- Map(function(piece) array(0, dim = c(dimX, dimX)), pieces) # residuals

        # for every score slice
        for (slice in seq_along(score_means)) {
            # function of `y` being the score slice mean (only 3D, same for all obs.)
            F <- `dim<-`(fun_y(score_means[slice]), dimF)

            # compute parameters of (slice) conditional Ising model
            params <- Map(function(betas, Omega) {
                `diag<-`(Omega, as.vector(mlm(F, betas)))
            }, betas, Omega)

            # second moment of `X_{,,piece} | Y = score_means[slice]` for every piece
            m2 <- Map(function(param) {
                ising_m2(param, use_MC = TRUE, nr_threads = nr_threads, nr_samples = mcmc_samples)
            }, params)

            # Draw a new sample
            X <- data_gen(slice_size, score_min[slice], score_max[slice])

            # Split into matricized mixture parts
            matX <- Map(function(piece) {
                `dim<-`(X[, , piece, ], c(64, slice_size))
            }, pieces)

            # accumulated loss over all piece mixtures
            loss <- loss - Reduce(`+`, Map(function(matX, param, m2) {
                sum(matX * (param %*% matX)) + slice_size * attr(m2, "log_prob_0")
            }, matX, params, m2))

            # Slice residuals (second order `resid2` and actual residuals `resid1`)
            resid2 <- Map(function(matX, m2) {
                tcrossprod(matX) - slice_size * m2
            }, matX, m2)

            # accumulate residuals
            R2 <- Map(function(R2, resid2) { R2 + as.vector(resid2) }, R2, resid2)

            # and the beta gradients
            grad_betas <- Map(function(grad_betas, resid2, betas) {
                resid1 <- `dim<-`(diag(resid2), dimX)
                Map(`+`, grad_betas, Map(function(mode) {
                    mcrossprod(resid1, mlm(slice_size * F, betas[-mode], (1:2)[-mode]), mode)
                }, 1:2))
            }, grad_betas, resid2, betas)
        }

        # finally, finish gradient computation with gradients for `Omegas`
        grad_Omegas <- Map(function(R2, Omegas) {
            Map(function(mode) {
                grad <- mlm(kronperm(R2), Map(as.vector, Omegas[-mode]), (1:2)[-mode], transposed = TRUE)
                `dim<-`(grad, dim(Omegas[[mode]]))
            }, 1:2)
        }, R2, Omegas)

        # Update tracker for break condition
        non_improving <- if (best_loss < loss) non_improving + 1L else 0L
        last_loss <- loss
        best_loss <- min(best_loss, loss)

        # check break condition
        if (non_improving > patience) { break }

        # accumulate root mean squared gradients
        grad2_betas <- Map(function(grad2_betas, grad_betas) {
            Map(function(g2, g) 0.9 * g2 + 0.1 * (g * g), grad2_betas, grad_betas)
        }, grad2_betas, grad_betas)
        grad2_Omegas <- Map(function(grad2_Omegas, grad_Omegas) {
            Map(function(g2, g) 0.9 * g2 + 0.1 * (g * g), grad2_Omegas, grad_Omegas)
        }, grad2_Omegas, grad_Omegas)

        # Update Parameters (the gradients are added, `loss` is the _negative_
        # log-likelihood)
        betas <- Map(function(betas, grad_betas, grad2_betas) {
            Map(function(beta, grad, M2) {
                beta + (step_size / (sqrt(M2) + eps)) * grad
            }, betas, grad_betas, grad2_betas)
        }, betas, grad_betas, grad2_betas)
        Omegas <- Map(function(Omegas, grad_Omegas, grad2_Omegas) {
            Map(function(Omega, grad, M2) {
                Omega + (step_size / (sqrt(M2) + eps)) * grad
            }, Omegas, grad_Omegas, grad2_Omegas)
        }, Omegas, grad_Omegas, grad2_Omegas)

        # Log progress
        cat(sprintf("iter: %4d, time for iter: %d [s], loss: %f (best: %f, non-improving: %d)\n",
            iter, round(proc.time()[["elapsed"]] - start_time), loss, best_loss, non_improving))
    }

    # Save a final (terminal) save point
    if (is.character(save_point)) {
        suspendInterrupts(save(
            dimX, dimF,
            betas, Omegas,
            grad2_betas, grad2_Omegas,
            last_loss, best_loss, non_improving, iter,
            file = sprintf(save_point, "final")))
    }

    structure(
        list(betas = betas, Omegas = Omegas),
        iter = iter, loss = loss
    )
}