43 lines
1.4 KiB
R
43 lines
1.4 KiB
R
|
|
#' Re-initialize model weights.
#'
#' An in-place model re-initialization. Intended for simulations to avoid
#' rebuilding the same model architecture for multiple simulation runs.
#'
#' Every weight variable of every layer is overwritten with a fresh draw from
#' the owning layer's initializer. Variables are identified by their leaf name:
#' `kernel` uses `kernel_initializer`, `recurrent_kernel` uses
#' `recurrent_initializer` and `bias` uses `bias_initializer`. Both scoped
#' names such as `dense/kernel:0` (tf.keras 2.x) and bare names such as
#' `kernel` (Keras 3) are recognised. Any other variable raises an error.
#'
#' @param model A `keras` model.
#' @return The `model`, invisibly. The function is called for its side effect
#'   of overwriting the model's weight variables in place.
#'
#' @seealso \url{https://github.com/keras-team/keras/issues/341}
#' @examples
#' \dontrun{
#' library(tensorflow) # v2
#' K <- tf$keras
#' model <- K$models$Sequential(list(
#'   K$layers$Dense(units = 7L, input_shape = list(3L)),
#'   K$layers$Dense(units = 1L)
#' ))
#' model$compile(loss = 'MSE', optimizer = K$optimizers$RMSprop())
#'
#' model$weights
#' reinitialize_weights(model)
#' model$weights
#' }
#'
#' @export
reinitialize_weights <- function(model) {
  for (layer in model$layers) {
    # Unwrap wrapper layers, including wrappers of wrappers, so that the
    # initializers are read from the layer that declares them.
    # NOTE(review): for wrappers such as Bidirectional the built weights may
    # live on copies (forward/backward layers) rather than on `$layer`;
    # confirm those weights are actually re-initialized.
    while (any(endsWith(class(layer), "Wrapper"))) {
      layer <- layer$layer
    }

    # NOTE(review): nested models used as layers expose no `*_initializer`
    # fields, and custom RNN bias initializers (e.g. LSTM unit_forget_bias)
    # are not reproduced by `bias_initializer` — confirm for such models.
    for (var in layer$weights) {
      var_name <- as.character(var$name)
      # Reduce the name to its leaf: drop any "scope/" prefix and any ":<n>"
      # suffix, so "dense/kernel:0" and "kernel" both map to "kernel".
      leaf <- sub(":[0-9]+$", "", sub("^.*/", "", var_name))

      initializer <- if (leaf == "recurrent_kernel") {
        layer$recurrent_initializer
      } else if (leaf == "kernel") {
        layer$kernel_initializer
      } else if (leaf == "bias") {
        layer$bias_initializer
      } else {
        stop("Unknown initialization for variable ", var_name, call. = FALSE)
      }

      var$assign(initializer(var$shape, var$dtype))
    }
  }

  invisible(model)
}
|