# tensor_predictors/dataAnalysis/chess/gmlm_chess.R

#' Specialized version of `gmlm_ising()`.
#'
#' Theoretically equivalent to `gmlm_ising()`, except that it uses a stochastic
#' gradient descent version of RMSprop instead of classic gradient descent.
#' Other differences are purely technical in nature.
#'
#' @param data_gen data generator; draws samples from the data set conditioned
#'   on a slice `y.min` to `y.max`. Function signature
#'   `function(batch.size, y.min, y.max)` with return value `X`, an
#'   `8 x 8 x 12 x batch.size` 4D array.
#' @param fun_y known functions of scalar `y`, returning a 3D/4D tensor
#' @param score_breaks numeric vector of two or more unique cut points; the cut
#'   points are the interval bounds specifying the slices of `y`.
#' @param nr_threads integer, nr. of threads used by `ising_m2()`
#' @param mcmc_samples integer, nr. of Monte Carlo samples passed to `ising_m2()`
#' @param slice_size integer, size of sub-samples generated by `data_gen` for
#'   every slice. The batch size for every iteration is then
#'   `slice_size * (length(score_breaks) - 1L)`.
#' @param max_iter maximum number of iterations for gradient optimization
#' @param patience integer, break condition parameter. If the approximated loss
#'   doesn't improve over `patience` iterations, then stop.
#' @param step_size numeric, meta parameter for RMSprop for gradient scaling
#' @param eps numeric, meta parameter for RMSprop avoiding division by zero in
#'   the parameter update rule of RMSprop
#' @param save_point character, file name pattern for storing and retrieving
#'   optimization save points. These save points allow stopping the method and
#'   resuming optimization later from the last save point.
#'
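#' @examples
#' \dontrun{
#' ## A usage sketch: `data_gen` and `fun_y` are placeholders that the caller
#' ## must supply; `sample_chess_positions()` and the polynomial `fun_y` below
#' ## are hypothetical, see the parameter descriptions above.
#' # hypothetical sampler: `batch.size` positions with scores in
#' # `(y.min, y.max]`, returned as an `8 x 8 x 12 x batch.size` binary array
#' data_gen <- function(batch.size, y.min, y.max) {
#'     sample_chess_positions(batch.size, y.min, y.max)  # user supplied
#' }
#' # example score functions: degree-2 polynomial basis of `y`, returned as
#' # a `3 x 1 x 1 x length(y)` tensor
#' fun_y <- function(y) {
#'     array(rbind(1, y, y^2), dim = c(3, 1, 1, length(y)))
#' }
#' fit <- gmlm_chess(data_gen, fun_y)
#' str(fit$betas)    # per-mode parameter matrices
#' }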
gmlm_chess <- function(
    data_gen,
    fun_y,
    score_breaks = c(-5.0, -3.0, -2.0, -1.0, -0.5, -0.2, 0.2, 0.5, 1.0, 2.0, 3.0, 5.0),
    nr_threads = 8L,
    mcmc_samples = 10000L,
    slice_size = 512L,
    max_iter = 10000L,
    patience = 25L,
    step_size = 1e-3,
    eps = sqrt(.Machine$double.eps),
    save_point = "save_point_%s.Rdata"
) {
    # build intervals from score break points
    score_breaks <- sort(score_breaks)
    score_min <- head(score_breaks, -1)
    score_max <- tail(score_breaks, -1)
    score_means <- (score_min + score_max) / 2
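    # (e.g. with the default `score_breaks`, the first slice is (-5, -3] with
    # `score_means[1] == -4`)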
    # build the Omega constraint, that is, the set of impossible combinations
    # (including self interactions) due to the rules of chess
    Omega_const <- local({
        # One piece per square
        diag_offset <- abs(.row(c(768, 768)) - .col(c(768, 768)))
        Omega_const <- !diag(768) & ((diag_offset %% 64L) == 0L)
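        # (off-diagonal entries whose index difference is divisible by 64 pair
        # two piece types on the same square, which is impossible)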
        # One King per color
        Omega_const <- Omega_const | kronecker(diag(1:12 %in% c(6, 12)), !diag(64), `&`)
        # Enemy kings can _not_ be on neighbouring squares
        king_const <- mapply(function(i, j) {
            `[<-`((abs(.row(c(8, 8)) - i) <= 1L) & (abs(.col(c(8, 8)) - j) <= 1L), i, j, FALSE)
        }, .row(c(8, 8)), .col(c(8, 8)))
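        # (column `s` of `king_const` flags all squares within one king move of
        # square `s`, excluding `s` itself)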
        dim(Omega_const) <- c(64, 12, 64, 12)
        Omega_const[, 6, , 12] <- Omega_const[, 6, , 12] | king_const
        Omega_const[, 12, , 6] <- Omega_const[, 12, , 6] | king_const
        dim(Omega_const) <- c(768, 768)
        # no pawns on rank 1 or rank 8
        pawn_const <- tcrossprod(as.vector(`[<-`(matrix(0L, 8, 8), c(1, 8), , 1L)), rep(1L, 64))
        pawn_const <- kronecker(`[<-`(matrix(0, 12, 12), c(1, 7), , 1), pawn_const)
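        # (rows 1 and 8 of the 8 x 8 board are the first and last rank; piece
        # indices 1 and 7 select the white and black pawns, consistent with the
        # king indices 6 and 12 used above)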
        which(Omega_const | (pawn_const | t(pawn_const)))
    })
    # Check if there is a save point (load from save)
    load_point <- if (is.character(save_point)) {
        sort(list.files(pattern = sprintf(save_point, ".*")), decreasing = TRUE)
    } else {
        character(0)
    }
    # If a load point is found, resume from the save point, otherwise initialize
    if (length(load_point)) {
        load_point <- load_point[[1]]
        cat(sprintf("Resuming from save point '%s'\n", load_point),
            "(to restart delete/rename the save points)\n")
        load(load_point)
        # Fix `iter`, saved after increment
        iter <- iter - 1L
    } else {
        # draw initial sample to be passed to the normal GMLM estimator for initial `betas`
        X <- Reduce(c, Map(data_gen, slice_size, score_min, score_max))
        dim(X) <- c(8L, 8L, 12L, slice_size * length(score_means))
        F <- fun_y(rep(score_means, each = slice_size))
        # set object dimensions (`dimX` is constant, `dimF` depends on `fun_y` arg)
        dimX <- c(8L, 8L, 12L)
        dimF <- dim(F)[1:3]
        # Initial values for `betas` are the tensor normal GMLM estimates
        betas <- gmlm_tensor_normal(X, F)$betas
        # and initial values for `Omegas`, based on the same first "big" sample
        Omegas <- Map(function(mode) {
            n <- prod(dim(X)[-mode])
            prob2 <- mcrossprod(X, mode = mode) / n
            prob2[prob2 == 0] <- 1 / n
            prob2[prob2 == 1] <- (n - 1) / n
            prob1 <- diag(prob2)
            `prob1^2` <- outer(prob1, prob1)
            `diag<-`(log(((1 - `prob1^2`) / `prob1^2`) * prob2 / (1 - prob2)), 0)
        }, 1:3)
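        # (entry-wise this is `logit(prob2) - logit(prob1^2)`: the log odds of
        # jointly observing two entries relative to what independence of their
        # marginal frequencies would predict, with a zeroed diagonal)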
        # Initial sample `(X, F)` no longer needed, remove
        rm(X, F)
        # Initialize gradients and aggregated mean squared gradients
        grad2_betas <- Map(array, 0, Map(dim, betas))
        grad2_Omegas <- Map(array, 0, Map(dim, Omegas))
        # initialize optimization tracker for break condition
        last_loss <- Inf
        non_improving <- 0L
        iter <- 0L
    }
    # main optimization loop
    while ((iter <- iter + 1L) <= max_iter) {
        # At the beginning of every iteration, store the current state in a save
        # point. This allows resuming optimization from the last save point.
        if (is.character(save_point)) {
            suspendInterrupts(save(
                dimX, dimF,
                betas, Omegas,
                grad2_betas, grad2_Omegas,
                last_loss, non_improving, iter,
                file = sprintf(save_point, sprintf("%06d", iter - 1L))))
        }
        # start timing for this iteration (this is precise enough)
        start_time <- proc.time()[["elapsed"]]
        # full Omega (with constrained elements set to zero), needed for the
        # conditional parameters of the Ising model to compute the (approx.)
        # second moment
        Omega <- `[<-`(Reduce(kronecker, rev(Omegas)), Omega_const, 0)
        # Gradient and negative log-likelihood approximation
        loss <- 0                                   # neg. log-likelihood
        grad_betas <- Map(matrix, 0, dimX, dimF)    # grads for betas
        R2 <- array(0, dim = c(dimX, dimX))         # residuals
        # for every score slice
        for (i in seq_along(score_means)) {
            # function of `y` being the score slice mean (only 3D, same for all obs.)
            F <- `dim<-`(fun_y(score_means[i]), dimF)
            # compute parameters of the (slice) conditional Ising model
            params <- `diag<-`(Omega, as.vector(mlm(F, betas)))
            # second moment of `X | Y = score_means[i]`
            m2 <- ising_m2(params, use_MC = TRUE, nr_threads = nr_threads, nr_samples = mcmc_samples)
            # draw random sample from current slice `vec(X) | Y in (score_min, score_max]`
            # with columns being the vectorized observations `vec(X)`.
            matX <- `dim<-`(data_gen(slice_size, score_min[i], score_max[i]), c(prod(dimX), slice_size))
            # accumulate (approx) neg. log-likelihood
            loss <- loss - (sum(matX * (params %*% matX)) + slice_size * attr(m2, "log_prob_0"))
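            # (`attr(m2, "log_prob_0")` is the log-probability of the all-zero
            # configuration, i.e. minus the log partition function, so each
            # summand is an observation's log-likelihood)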
            # Slice residuals (second order `resid2` and actual residuals `resid1`)
            resid2 <- tcrossprod(matX) - slice_size * m2
            resid1 <- `dim<-`(diag(resid2), dimX)
            # accumulate residuals
            R2 <- R2 + as.vector(resid2)
            # and the beta gradients
            grad_betas <- Map(`+`, grad_betas, Map(function(mode) {
                mcrossprod(resid1, mlm(slice_size * F, betas[-mode], (1:3)[-mode]), mode)
            }, 1:3))
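            # (per mode: the matricized first-order residuals are multiplied
            # with the matricization of `F` transformed by all other `betas`,
            # accumulated over slices)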
        }
        # finally, finish the gradient computation with the gradients for `Omegas`
        grad_Omegas <- Map(function(mode) {
            grad <- mlm(kronperm(R2), Map(as.vector, Omegas[-mode]), (1:3)[-mode], transposed = TRUE)
            `dim<-`(grad, dim(Omegas[[mode]]))
        }, 1:3)
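        # (contracts the Kronecker-permuted second-order residuals `R2` with
        # the vectorized `Omegas` of all other modes)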
        # Update tracker for break condition
        non_improving <- max(0L, non_improving - 1L + 2L * (last_loss < loss))
        last_loss <- loss
        # check break condition
        if (non_improving > patience) { break }
        # accumulate (exponentially weighted) mean squared gradients
        grad2_betas <- Map(function(g2, g) 0.9 * g2 + 0.1 * (g * g), grad2_betas, grad_betas)
        grad2_Omegas <- Map(function(g2, g) 0.9 * g2 + 0.1 * (g * g), grad2_Omegas, grad_Omegas)
        # Update Parameters
        betas <- Map(function(beta, grad, m2) {
            beta + (step_size / (sqrt(m2) + eps)) * grad
        }, betas, grad_betas, grad2_betas)
        Omegas <- Map(function(Omega, grad, m2) {
            Omega + (step_size / (sqrt(m2) + eps)) * grad
        }, Omegas, grad_Omegas, grad2_Omegas)
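        # (RMSprop step: `grad2_*` is a decaying average of squared gradients,
        # decay 0.9; adding the scaled gradient performs ascent on the
        # approximated log-likelihood)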
        # Log progress
        cat(sprintf("iter: %4d, time for iter: %d [s], loss: %f\n",
            iter, round(proc.time()[["elapsed"]] - start_time), loss))
    }
    # Save a final (terminal) save point
    if (is.character(save_point)) {
        suspendInterrupts(save(
            dimX, dimF,
            betas, Omegas,
            grad2_betas, grad2_Omegas,
            last_loss, non_improving, iter,
            file = sprintf(save_point, "final")))
    }
    structure(
        list(betas = betas, Omegas = Omegas),
        iter = iter, loss = loss
    )
}
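
# Hedged usage sketch: `data_gen` and `fun_y` are the caller-supplied functions
# documented above; `fit` is hypothetical. Illustrates two behaviors grounded
# in the code: resuming from save points, and reassembling the full interaction
# matrix the same way as inside the optimization loop (note the `Omega_const`
# indices are not returned, so constrained entries are not zeroed here).
if (FALSE) {
    # first run, possibly interrupted at some iteration
    fit <- gmlm_chess(data_gen, fun_y, save_point = "save_point_%s.Rdata")
    # a second call with the same `save_point` pattern resumes from the newest
    # save point instead of re-initializing
    fit <- gmlm_chess(data_gen, fun_y, save_point = "save_point_%s.Rdata")
    # full `768 x 768` Ising interaction matrix from the per-mode `Omegas`
    Omega <- Reduce(kronecker, rev(fit$Omegas))
}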