216 lines
9.0 KiB
R
216 lines
9.0 KiB
R
#' Specialized version of `gmlm_ising()`.
#'
#' Theoretically, equivalent to `gmlm_ising()` except that it uses a stochastic
#' gradient descent version of RMSprop instead of classic gradient descent.
#' Other differences are purely of technical nature.
#'
#' @param data_gen data generator, samples from the data set conditioned on a
#'  slice value `y.min` to `y.max`. Function signature
#'  `function(batch.size, y.min, y.max)` with return value `X`, a
#'  `8 x 8 x 12 x batch.size` 4D array.
#' @param fun_y known functions of scalar `y`, returning a 3D/4D tensor
#' @param score_breaks numeric vector of two or more unique cut points, the cut
#'  points are the interval bounds specifying the slices of `y`.
#' @param nr_threads integer, nr. of threads used by `ising_m2()`
#' @param mcmc_samples integer, nr. of Monte-Carlo Chains passed to `ising_m2()`
#' @param slice_size integer, size of sub-samples generated by `data_gen` for
#'  every slice. The batch size for every iteration is then equal to
#'  `slice_size * (length(score_breaks) - 1L)`.
#' @param max_iter maximum number of iterations for gradient optimization
#' @param patience integer, break condition parameter. If the approximated loss
#'  doesn't improve over `patience` iterations, then stop.
#' @param step_size numeric, meta parameter for RMSprop for gradient scaling
#' @param eps numeric, meta parameter for RMSprop avoiding division by zero in
#'  the parameter update rule of RMSprop
#' @param save_point character, file name pattern for storing and retrieving
#'  optimization save points. Those save points allow to stop the method and
#'  resume optimization later from the last save point. Save points are looked
#'  up in (and the pattern is resolved relative to) the working directory.
#'  Any non-character value disables saving and resuming.
#'
#' @return list with the elements `betas` and `Omegas` (the parameter values at
#'  termination) carrying the attributes `iter` (value of the iteration counter
#'  at termination) and `loss` (the last computed approximated negative
#'  log-likelihood, or the `last_loss` stored in the save point if no iteration
#'  was performed).
#'
gmlm_chess <- function(
    data_gen,
    fun_y,
    score_breaks = c(-5.0, -3.0, -2.0, -1.0, -0.5, -0.2, 0.2, 0.5, 1.0, 2.0, 3.0, 5.0),
    nr_threads = 8L,
    mcmc_samples = 10000L,
    slice_size = 512L,
    max_iter = 1000L,
    patience = 25L,
    step_size = 1e-3,
    eps = sqrt(.Machine$double.eps),
    save_point = "gmlm_chess_save_point_%s.Rdata"
) {
    # build intervals from score break points
    score_breaks <- sort(score_breaks)
    score_min <- head(score_breaks, -1)
    score_max <- tail(score_breaks, -1)
    score_means <- (score_min + score_max) / 2

    # build Omega constraint, that is the set of impossible combinations
    # (including self interactions) due to the rules of chess.
    # The result is an index vector into a `768 x 768` matrix where
    # `768 = 8 * 8 * 12` is the length of a vectorized observation.
    Omega_const <- local({
        # One piece per square
        diag_offset <- abs(.row(c(768, 768)) - .col(c(768, 768)))
        Omega_const <- !diag(768) & ((diag_offset %% 64L) == 0L)
        # One King per color
        Omega_const <- Omega_const | kronecker(diag(1:12 %in% c(6, 12)), !diag(64), `&`)
        # no pawns on rank 1 or rank 8
        pawn_const <- tcrossprod(as.vector(`[<-`(matrix(0L, 8, 8), c(1, 8), , 1L)), rep(1L, 64))
        pawn_const <- kronecker(`[<-`(matrix(0, 12, 12), c(1, 7), , 1), pawn_const)
        which(Omega_const | (pawn_const | t(pawn_const)))
    })

    # Check if there is a save point (load from save), the candidates are
    # sorted in decreasing order such that the first entry is the latest one
    load_point <- if (is.character(save_point)) {
        sort(list.files(pattern = sprintf(save_point, ".*")), decreasing = TRUE)
    } else {
        character(0)
    }
    # If a load point is found, resume from save point, otherwise initialize
    if (length(load_point) > 0L) {
        load_point <- load_point[[1]]
        cat(sprintf("Resuming from save point '%s'\n", load_point),
            "(to restart delete/rename the save points)\n")
        # restores `dimX`, `dimF`, `betas`, `Omegas`, `grad2_betas`,
        # `grad2_Omegas`, `last_loss`, `non_improving` and `iter` into the
        # function environment (see the `save()` calls below)
        load(load_point)
        # Fix `iter`, save after increment
        iter <- iter - 1L
    } else {
        # draw initial sample to be passed to the normal GMLM estimator for initial `betas`
        X <- Reduce(c, Map(data_gen, slice_size, score_min, score_max))
        dim(X) <- c(8L, 8L, 12L, slice_size * length(score_means))
        F <- fun_y(rep(score_means, each = slice_size))

        # set object dimensions (`dimX` is constant, `dimF` depends on `fun_y` arg)
        dimX <- c(8L, 8L, 12L)
        dimF <- dim(F)[1:3]

        # Initial values for `betas` are the tensor normal GMLM estimates
        betas <- gmlm_tensor_normal(X, F)$betas

        # and initial values for `Omegas`, based on the same first "big" sample
        Omegas <- Map(function(mode) {
            n <- prod(dim(X)[-mode])
            prob2 <- mcrossprod(X, mode = mode) / n
            # move estimates of exactly 0 or 1 away from the boundary to keep
            # the logarithm below finite
            prob2[prob2 == 0] <- 1 / n
            prob2[prob2 == 1] <- (n - 1) / n
            prob1 <- diag(prob2)
            `prob1^2` <- outer(prob1, prob1)

            `diag<-`(log(((1 - `prob1^2`) / `prob1^2`) * prob2 / (1 - prob2)), 0)
        }, 1:3)

        # Initial sample `(X, F)` no longer needed, remove
        rm(X, F)

        # Initialize gradients and aggregated mean squared gradients
        grad2_betas <- Map(array, 0, Map(dim, betas))
        grad2_Omegas <- Map(array, 0, Map(dim, Omegas))

        # initialize optimization tracker for break condition
        last_loss <- Inf
        non_improving <- 0L
        iter <- 0L
    }

    # Ensure `loss` exists even if the loop body is never entered (resuming
    # from a completed run or `max_iter = 0`), it is part of the return value.
    loss <- last_loss

    # main optimization loop
    while ((iter <- iter + 1L) <= max_iter) {

        # At beginning of every iteration, store current state in a save point.
        # This allows to resume optimization from the last save point.
        if (is.character(save_point)) {
            suspendInterrupts(save(
                dimX, dimF,
                betas, Omegas,
                grad2_betas, grad2_Omegas,
                last_loss, non_improving, iter,
                file = sprintf(save_point, sprintf("%06d", iter - 1L))))
        }

        # start timing for this iteration (this is precise enough)
        start_time <- proc.time()[["elapsed"]]

        # full Omega (with constraint elements set to zero) needed to compute
        # the conditional parameters of the Ising model, which in turn are
        # used to (approx) compute the second moment
        Omega <- `[<-`(Reduce(kronecker, rev(Omegas)), Omega_const, 0)

        # Gradient and negative log-likelihood approximation
        loss <- 0                                   # neg. log-likelihood
        grad_betas <- Map(matrix, 0, dimX, dimF)    # grads for betas
        R2 <- array(0, dim = c(dimX, dimX))         # residuals

        # for every score slice
        for (i in seq_along(score_means)) {
            # function of `y` being the score slice mean (only 3D, same for all obs.)
            F <- `dim<-`(fun_y(score_means[i]), dimF)

            # compute parameters of (slice) conditional Ising model
            params <- `diag<-`(Omega, as.vector(mlm(F, betas)))

            # second moment of `X | Y = score_means[i]`
            m2 <- ising_m2(params, use_MC = TRUE, nr_threads = nr_threads, nr_samples = mcmc_samples)

            # draw random sample from current slice `vec(X) | Y in (score_min, score_max]`
            # with columns being the vectorized observations `vec(X)`.
            matX <- `dim<-`(data_gen(slice_size, score_min[i], score_max[i]), c(prod(dimX), slice_size))

            # accumulate (approx) neg. log-likelihood
            loss <- loss - (sum(matX * (params %*% matX)) + slice_size * attr(m2, "log_prob_0"))

            # Slice residuals (second order `resid2` and actual residuals `resid1`)
            resid2 <- tcrossprod(matX) - slice_size * m2
            resid1 <- `dim<-`(diag(resid2), dimX)

            # accumulate residuals
            R2 <- R2 + as.vector(resid2)

            # and the beta gradients
            grad_betas <- Map(`+`, grad_betas, Map(function(mode) {
                mcrossprod(resid1, mlm(slice_size * F, betas[-mode], (1:3)[-mode]), mode)
            }, 1:3))
        }

        # finally, finish gradient computation with gradients for `Omegas`
        grad_Omegas <- Map(function(mode) {
            grad <- mlm(kronperm(R2), Map(as.vector, Omegas[-mode]), (1:3)[-mode], transposed = TRUE)
            `dim<-`(grad, dim(Omegas[[mode]]))
        }, 1:3)

        # Update tracker for break condition: the counter is incremented when
        # the loss got worse than in the previous iteration and decremented
        # (but never below zero) otherwise
        non_improving <- max(0L, non_improving - 1L + 2L * (last_loss < loss))
        # remember the current loss for the comparison in the next iteration
        # (this is also the value stored in the save points)
        last_loss <- loss

        # check break condition
        if (non_improving > patience) {
            break
        }

        # accumulate root mean squared gradients
        grad2_betas <- Map(function(g2, g) 0.9 * g2 + 0.1 * (g * g), grad2_betas, grad_betas)
        grad2_Omegas <- Map(function(g2, g) 0.9 * g2 + 0.1 * (g * g), grad2_Omegas, grad_Omegas)

        # Update Parameters, the scaled gradients are added to the parameters
        # NOTE(review): presumably the gradients are those of the
        # log-likelihood (ascent direction) -- confirm
        betas <- Map(function(beta, grad, m2) {
            beta + (step_size / (sqrt(m2) + eps)) * grad
        }, betas, grad_betas, grad2_betas)
        Omegas <- Map(function(Omega, grad, m2) {
            Omega + (step_size / (sqrt(m2) + eps)) * grad
        }, Omegas, grad_Omegas, grad2_Omegas)

        # Log progress
        cat(sprintf("iter: %4d, time for iter: %d [s], loss: %f\n",
            iter, round(proc.time()[["elapsed"]] - start_time), loss))
    }

    # Save a final (terminal) save point
    if (is.character(save_point)) {
        suspendInterrupts(save(
            dimX, dimF,
            betas, Omegas,
            grad2_betas, grad2_Omegas,
            last_loss, non_improving, iter,
            file = sprintf(save_point, "final")))
    }

    structure(
        list(betas = betas, Omegas = Omegas),
        iter = iter, loss = loss
    )
}