NNSDR/NNSDR/R/reinitialize_weights.R

43 lines
1.4 KiB
R

#' Re-initialize model weights.
#'
#' An in-place model re-initialization. Intended for simulations to avoid
#' rebuilding the same model architecture for multiple simulation runs.
#'
#' Every weight variable of every layer in `model$layers` is overwritten with
#' a fresh value drawn from the owning layer's initializer:
#' `kernel` from `kernel_initializer`, `recurrent_kernel` from
#' `recurrent_initializer` and `bias` from `bias_initializer`. Variable names
#' are matched with or without a scope prefix (`"dense/kernel:0"`,
#' `"kernel"`). Wrapper layers are unwrapped to their inner layer; for
#' `Bidirectional` wrappers both `forward_layer` and `backward_layer` are
#' re-initialized.
#'
#' An error is signalled for any weight variable that is not a kernel,
#' recurrent kernel or bias. Variables processed before that point have
#' already been re-initialized.
#'
#' @param model A `keras` model.
#'
#' @return `model`, invisibly (the model is modified in place).
#'
#' @seealso \url{https://github.com/keras-team/keras/issues/341}
#' @examples
#' \dontrun{
#' library(tensorflow) # TensorFlow >= 2
#' K <- tf$keras
#' model <- K$models$Sequential(list(
#' K$layers$Dense(units = 7L, input_shape = list(3L)),
#' K$layers$Dense(units = 1L)
#' ))
#' model$compile(loss = 'MSE', optimizer = K$optimizers$RMSprop())
#'
#' model$weights
#' reinitialize_weights(model)
#' model$weights
#' }
#'
#' @export
reinitialize_weights <- function(model) {
  # TRUE when `var_name` denotes the variable `stem`: an optional scope
  # prefix ending in "/", the stem itself, and an optional ":<n>" suffix.
  # "recurrent_kernel" never matches the "kernel" stem because the stem must
  # follow "/" or the start of the name.
  is_var <- function(var_name, stem) {
    grepl(paste0("(^|/)", stem, "(:[0-9]+)?$"), var_name)
  }

  # Layers that directly own weight variables for one entry of model$layers.
  weight_owners <- function(layer) {
    cls <- class(layer)
    if (any(endsWith(cls, "Bidirectional"))) {
      # `Bidirectional` keeps its weights in two re-created copies of the
      # wrapped layer; `$layer` is only the template it was built from.
      # NOTE(review): verified against tf.keras; confirm for other versions.
      list(layer$forward_layer, layer$backward_layer)
    } else if (any(endsWith(cls, "Wrapper"))) {
      list(layer$layer)
    } else {
      list(layer)
    }
  }

  for (layer in model$layers) {
    for (owner in weight_owners(layer)) {
      for (var in owner$weights) {
        name <- var$name
        initializer <- if (is_var(name, "recurrent_kernel")) {
          owner$recurrent_initializer
        } else if (is_var(name, "kernel")) {
          owner$kernel_initializer
        } else if (is_var(name, "bias")) {
          owner$bias_initializer
        } else {
          stop("Unknown initialization for variable ", name)
        }
        var$assign(initializer(var$shape, var$dtype))
      }
    }
  }

  invisible(model)
}