#' Specialized version of `gmlm_ising()`.
#'
#' Theoretically equivalent to `gmlm_ising()`, except that it uses a stochastic
#' gradient descent version of RMSprop instead of classic gradient descent.
#' All other differences are purely technical.
#'
#' @param data_gen data generator, samples from the data set conditioned on a
#'  slice value `y.min` to `y.max`. Function signature
#'  `function(batch.size, y.min, y.max)` with return value `X`, an
#'  `8 x 8 x 12 x batch.size` 4D array.
#' @param fun_y known function of scalar `y`, returning a 3D/4D tensor
#' @param score_breaks numeric vector of two or more unique cut points; the cut
#'  points are the interval bounds specifying the slices of `y`.
#' @param nr_threads integer, number of threads used by `ising_m2()`
#' @param mcmc_samples integer, number of Monte Carlo samples passed to
#'  `ising_m2()`
#' @param slice_size integer, size of the sub-samples generated by `data_gen`
#'  for every slice. The batch size for every iteration is then equal to
#'  `slice_size * (length(score_breaks) - 1L)`.
#' @param max_iter maximum number of iterations for the gradient optimization
#' @param patience integer, break condition parameter. If the approximated loss
#'  does not improve over `patience` iterations, then stop.
#' @param step_size numeric, meta parameter for RMSprop gradient scaling
#' @param eps numeric, meta parameter for RMSprop, avoiding division by zero in
#'  the parameter update rule of RMSprop
#' @param save_point character, file name pattern for storing and retrieving
#'  optimization save points. Those save points allow stopping the method and
#'  resuming optimization later from the last save point.
#'
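#' @examples
#' \dontrun{
#' # Minimal usage sketch with a hypothetical data generator: `toy_gen` merely
#' # has the documented `data_gen` signature and returns random boards; it is
#' # a stand-in for illustration, not part of the package.
#' toy_gen <- function(batch.size, y.min, y.max) {
#'   array(rbinom(8 * 8 * 12 * batch.size, 1, 0.1),
#'     dim = c(8, 8, 12, batch.size))
#' }
#' fit <- gmlm_chess(toy_gen, slice_size = 32L, max_iter = 5L)
#' }
#'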
gmlm_chess <- function(
    data_gen,
    fun_y = function(y) { `dim<-`(t(outer(y, c(0, 1, 1, 2), `^`)), c(2, 2, length(y))) },
    score_breaks = c(-5.0, -3.0, -2.0, -1.0, -0.5, -0.2, 0.2, 0.5, 1.0, 2.0, 3.0, 5.0),
    nr_threads = 8L,
    mcmc_samples = 10000L,
    slice_size = 512L,
    max_iter = 10000L,
    patience = 10L,
    step_size = 1e-2,
    eps = sqrt(.Machine$double.eps),
    save_point = "save_point_%s.Rdata"
) {
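    # Note: the default `fun_y` maps every scalar y to the 2 x 2 matrix
    # matrix(c(1, y, y, y^2), 2, 2), that is, the monomials y^0, y^1, y^1, y^2.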
    # build intervals from score break points
    score_breaks <- sort(score_breaks)
    score_min <- head(score_breaks, -1)
    score_max <- tail(score_breaks, -1)
    score_means <- (score_min + score_max) / 2

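    # With the default `score_breaks` there are 11 slices, so a single
    # iteration draws `11 * slice_size` observations in total.
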
    # Piece index lookup "table" by piece symbol
    pieces <- `names<-`(1:12, unlist(strsplit("PNBRQKpnbrqk", "")))

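    # White pieces "PNBRQK" map to indices 1:6 and black pieces "pnbrqk" to
    # 7:12, e.g. `pieces["Q"] == 5L` and `pieces["q"] == 11L`.
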
    # Build constraints for every mixture component, that is, for every piece
    pawn_const <- which(as.logical(tcrossprod(.row(c(8, 8)) %in% c(1, 8))))
    # King and Queen constraints (for queens this is only an approximation)
    KQ_const <- which(!diag(64))

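    # `pawn_const` indexes every interaction term touching the first or eighth
    # board row (pawns never occupy those ranks, assuming boards are stored
    # rank-by-row), while `KQ_const` drops all off-diagonal interactions: a
    # single piece on the board has no within-component interactions, which
    # for queens only holds approximately since promotions can produce more
    # than one queen.
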
    # Check if there is a save point (load from save)
    load_point <- if (is.character(save_point)) {
        sort(list.files(pattern = sprintf(save_point, ".*")), decreasing = TRUE)
    } else {
        character(0)
    }
    # If a load point is found, resume from the save point, otherwise initialize
    if (length(load_point)) {
        load_point <- load_point[[1]]
        cat(sprintf("Resuming from save point '%s'\n", load_point),
            "(to restart, delete/rename the save points)\n")
        load(load_point)
        # Fix `iter`, saved after increment
        iter <- iter - 1L
    } else {
        # draw an initial sample to be passed to the normal GMLM estimator for initial `betas`
        X <- Reduce(c, Map(data_gen, slice_size, score_min, score_max))
        dim(X) <- c(8L, 8L, 12L, slice_size * length(score_means))
        F <- fun_y(rep(score_means, each = slice_size))

        # set object dimensions (`dimX` is constant, `dimF` depends on the `fun_y` argument)
        dimX <- c(8L, 8L)    # for every mixture component
        dimF <- dim(F)[1:2]  # also per mixture component

# Initialize `betas` for every mixture component
|
|
betas <- Map(function(piece) {
|
|
gmlm_tensor_normal(X[, , piece, ], F)$betas
|
|
}, pieces)
|
|
|
|
        # and initial values for `Omegas`, based on the same first "big" sample
        Omegas <- Map(function(piece) {
            X <- X[, , piece, ]
            Map(function(mode) {
                n <- prod(dim(X)[-mode])
                prob2 <- mcrossprod(X, mode = mode) / n
                prob2[prob2 == 0] <- 1 / n
                prob2[prob2 == 1] <- (n - 1) / n
                prob1 <- diag(prob2)
                `prob1^2` <- outer(prob1, prob1)
                `diag<-`(log(((1 - `prob1^2`) / `prob1^2`) * prob2 / (1 - prob2)), 0)
            }, 1:2)
        }, pieces)
        Omegas[[pieces["P"]]][[1]][c(1, 8), ] <- 0
        Omegas[[pieces["p"]]][[1]][c(1, 8), ] <- 0

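        # The initialization above sets the off-diagonal entries of each
        # mode-wise Omega to empirical log odds ratios,
        # log(p_ij (1 - p_i p_j) / ((1 - p_ij) p_i p_j)), estimated from the
        # matricized sample with probabilities clamped away from 0 and 1; the
        # pawn rows for the first and eighth rank are zeroed afterwards.
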
        # Initial sample `(X, F)` no longer needed, remove them
        rm(X, F)

        # Initialize gradients and aggregated mean squared gradients
        grad2_betas <- Map(function(params) Map(array, 0, Map(dim, params)), betas)
        grad2_Omegas <- Map(function(params) Map(array, 0, Map(dim, params)), Omegas)

        # initialize optimization tracker for break condition
        last_loss <- best_loss <- Inf
        non_improving <- 0L
        iter <- 0L
    }

    # main optimization loop
    while ((iter <- iter + 1L) <= max_iter) {

        # At the beginning of every iteration, store the current state in a save
        # point. This allows resuming optimization from the last save point.
        if (is.character(save_point)) {
            suspendInterrupts(save(
                dimX, dimF,
                betas, Omegas,
                grad2_betas, grad2_Omegas,
                last_loss, best_loss, non_improving, iter,
                file = sprintf(save_point, sprintf("%06d", iter - 1L))))
        }

        # start timing for this iteration (this is precise enough)
        start_time <- proc.time()[["elapsed"]]

        # Full Omega(s) for every piece mixture component
        Omega <- Map(function(Omegas) {
            kronecker(Omegas[[2]], Omegas[[1]])
        }, Omegas)
        Omega[[pieces["P"]]][pawn_const] <- 0
        Omega[[pieces["p"]]][pawn_const] <- 0
        Omega[[pieces["K"]]][KQ_const] <- 0
        Omega[[pieces["k"]]][KQ_const] <- 0
        Omega[[pieces["Q"]]][KQ_const] <- 0
        Omega[[pieces["q"]]][KQ_const] <- 0

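        # Each per-piece interaction matrix is the 64 x 64 Kronecker product
        # Omegas[[2]] %x% Omegas[[1]] of the two 8 x 8 mode matrices, with the
        # constrained entries zeroed out afterwards.
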
        # Gradient and negative log-likelihood approximation
        loss <- 0  # neg. log-likelihood
        grad_betas <- Map(function(piece) Map(matrix, 0, dimX, dimF), pieces)  # grads for betas
        R2 <- Map(function(piece) array(0, dim = c(dimX, dimX)), pieces)       # residuals

        # for every score slice
        for (slice in seq_along(score_means)) {
            # function of `y` being the score slice mean (only 3D, same for all obs.)
            F <- `dim<-`(fun_y(score_means[slice]), dimF)

            # compute parameters of the (slice) conditional Ising model
            params <- Map(function(betas, Omega) {
                `diag<-`(Omega, as.vector(mlm(F, betas)))
            }, betas, Omega)

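            # The diagonal of each parameter matrix now holds the vectorized
            # multilinear term vec(beta_1 F beta_2'), i.e. the slice-dependent
            # first-order natural parameters, while the off-diagonal
            # interactions stay constant across slices.
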
            # second moment of `X_{,,piece} | Y = score_means[slice]` for every piece
            m2 <- Map(function(param) {
                ising_m2(param, use_MC = TRUE, nr_threads = nr_threads, nr_samples = mcmc_samples)
            }, params)

            # Draw a new sample
            X <- data_gen(slice_size, score_min[slice], score_max[slice])

            # Split into matricized mixture parts
            matX <- Map(function(piece) {
                `dim<-`(X[, , piece, ], c(64, slice_size))
            }, pieces)

            # accumulated loss over all piece mixtures
            loss <- loss - Reduce(`+`, Map(function(matX, param, m2) {
                sum(matX * (param %*% matX)) + slice_size * attr(m2, "log_prob_0")
            }, matX, params, m2))

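            # `sum(matX * (param %*% matX))` evaluates the quadratic forms
            # sum_i x_i' param x_i over the slice sample; `log_prob_0` is
            # presumably the log-probability of the all-zero configuration,
            # i.e. the negative log partition function of the Ising model.
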
            # Slice residuals (second order `resid2` and actual residuals `resid1`)
            resid2 <- Map(function(matX, m2) {
                tcrossprod(matX) - slice_size * m2
            }, matX, m2)

            # accumulate residuals
            R2 <- Map(function(R2, resid2) { R2 + as.vector(resid2) }, R2, resid2)

            # and the beta gradients
            grad_betas <- Map(function(grad_betas, resid2, betas) {
                resid1 <- `dim<-`(diag(resid2), dimX)
                Map(`+`, grad_betas, Map(function(mode) {
                    mcrossprod(resid1, mlm(slice_size * F, betas[-mode], (1:2)[-mode]), mode)
                }, 1:2))
            }, grad_betas, resid2, betas)
        }

        # finally, finish the gradient computation with the gradients for `Omegas`
        grad_Omegas <- Map(function(R2, Omegas) {
            Map(function(mode) {
                grad <- mlm(kronperm(R2), Map(as.vector, Omegas[-mode]), (1:2)[-mode], transposed = TRUE)
                `dim<-`(grad, dim(Omegas[[mode]]))
            }, 1:2)
        }, R2, Omegas)

        # Update tracker for break condition
        non_improving <- if (best_loss < loss) non_improving + 1L else 0L
        last_loss <- loss
        best_loss <- min(best_loss, loss)

        # check break condition
        if (non_improving > patience) { break }

        # accumulate exponentially weighted mean squared gradients (RMSprop)
        grad2_betas <- Map(function(grad2_betas, grad_betas) {
            Map(function(g2, g) 0.9 * g2 + 0.1 * (g * g), grad2_betas, grad_betas)
        }, grad2_betas, grad_betas)
        grad2_Omegas <- Map(function(grad2_Omegas, grad_Omegas) {
            Map(function(g2, g) 0.9 * g2 + 0.1 * (g * g), grad2_Omegas, grad_Omegas)
        }, grad2_Omegas, grad_Omegas)

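        # RMSprop update rule: with v = 0.9 v + 0.1 g^2 accumulated above, each
        # parameter changes by (step_size / (sqrt(v) + eps)) * g, where `eps`
        # guards against division by zero for vanishing gradients.
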
        # Update parameters
        betas <- Map(function(betas, grad_betas, grad2_betas) {
            Map(function(beta, grad, M2) {
                beta + (step_size / (sqrt(M2) + eps)) * grad
            }, betas, grad_betas, grad2_betas)
        }, betas, grad_betas, grad2_betas)
        Omegas <- Map(function(Omegas, grad_Omegas, grad2_Omegas) {
            Map(function(Omega, grad, M2) {
                Omega + (step_size / (sqrt(M2) + eps)) * grad
            }, Omegas, grad_Omegas, grad2_Omegas)
        }, Omegas, grad_Omegas, grad2_Omegas)

        # Log progress
        cat(sprintf("iter: %4d, time for iter: %d [s], loss: %f (best: %f, non-improving: %d)\n",
            iter, round(proc.time()[["elapsed"]] - start_time), loss, best_loss, non_improving))
    }

    # Save a final (terminal) save point
    if (is.character(save_point)) {
        suspendInterrupts(save(
            dimX, dimF,
            betas, Omegas,
            grad2_betas, grad2_Omegas,
            last_loss, best_loss, non_improving, iter,
            file = sprintf(save_point, "final")))
    }

    structure(
        list(betas = betas, Omegas = Omegas),
        iter = iter, loss = loss
    )
}