280 lines
9.6 KiB
R
280 lines
9.6 KiB
R
|
# Set the `TF_CPP_MIN_LOG_LEVEL` environment variable for this R session.
# NOTE(review): presumably this quiets TensorFlow's C++ log output -- confirm
# the meaning of level "3" against the TensorFlow documentation.
Sys.setenv(TF_CPP_MIN_LOG_LEVEL = "3")
|
||
|
|
||
|
|
||
|
|
||
|
#' Build MLP
#'
#' Builds and compiles a Keras multi-layer perceptron with one `Input` layer
#' per entry of `input_shapes`. Optionally the first input is first passed
#' through a bias-free `Dense` layer named `'reduction'`. All (possibly
#' reduced) inputs are concatenated and fed through the hidden layers into a
#' final `Dense` layer named `'output'`.
#'
#' @param input_shapes vector or list of input sizes, one entry per model
#'  input (coerced with `as.integer`). If named, the names are used as the
#'  names of the input layers.
#' @param d number of units of the `'reduction'` layer. Only used if
#'  `add_reduction` is `TRUE`.
#' @param name name given to the Keras model.
#' @param add_reduction if `TRUE`, the first input is passed through the
#'  `'reduction'` layer whose kernel `w` is constrained by the polar
#'  projection `w (w' w)^(-1/2)`; the remaining inputs are used unchanged.
#' @param output_shape number of units of the `'output'` layer (which is
#'  created without specifying an activation).
#' @param hidden_units vector with the number of units of each hidden layer;
#'  one layer named `hidden<i>` is created per entry.
#' @param activation activation passed to every hidden layer.
#' @param dropout dropout rate. If greater than zero, a `Dropout` layer named
#'  `dropout<i>` is added after each hidden layer.
#' @param loss passed as `loss` to the model's `compile` method.
#' @param optimizer passed as `optimizer` to the model's `compile` method.
#' @param metrics passed as `metrics` to the model's `compile` method.
#' @param trainable_reduction value of the `trainable` argument of the
#'  `'reduction'` layer.
#'
#' @return The compiled Keras model.
#'
#' @import tensorflow
#' @keywords internal
build.MLP <- function(input_shapes, d, name, add_reduction,
    output_shape = 1L,
    hidden_units = 512L,
    activation = 'relu',
    dropout = 0.4,
    loss = 'MSE',
    optimizer = 'RMSProp',
    metrics = NULL,
    trainable_reduction = TRUE
) {
    K <- tf$keras

    # One `Input` layer per entry of `input_shapes`.
    inputs <- Map(K$layers$Input,
        shape = as.integer(input_shapes), # drops names (concatenate key error)
        name = if (is.null(names(input_shapes))) "" else names(input_shapes)
    )

    mlp_inputs <- if (add_reduction) {
        reduction <- K$layers$Dense(
            units = as.integer(d),
            use_bias = FALSE,
            kernel_constraint = function(w) { # polar projection
                # Computes `w (w' w)^(-1/2)` by solving
                # `(w' w)^(1/2) x = w'` and transposing the solution.
                lhs <- tf$linalg$sqrtm(tf$matmul(w, w, transpose_a = TRUE))
                tf$transpose(tf$linalg$solve(lhs, tf$transpose(w)))
            },
            trainable = trainable_reduction,
            name = 'reduction'
        )(inputs[[1]])

        # Reduced first input followed by all remaining (unreduced) inputs.
        c(reduction, inputs[-1])
    } else {
        inputs
    }

    # A single input needs no concatenation layer.
    out <- if (length(mlp_inputs) == 1) {
        mlp_inputs[[1]]
    } else {
        K$layers$concatenate(mlp_inputs, axis = 1L, name = 'input_mlp')
    }

    # Stack the hidden layers, each optionally followed by dropout.
    for (i in seq_along(hidden_units)) {
        out <- K$layers$Dense(
            units = as.integer(hidden_units[i]),
            activation = activation,
            name = paste0('hidden', i)
        )(out)
        if (dropout > 0) {
            out <- K$layers$Dropout(
                rate = dropout, name = paste0('dropout', i)
            )(out)
        }
    }
    out <- K$layers$Dense(
        units = as.integer(output_shape), name = 'output'
    )(out)

    mlp <- K$models$Model(inputs = inputs, outputs = out, name = name)
    mlp$compile(loss = loss, optimizer = optimizer, metrics = metrics)

    mlp
}
|
||
|
|
||
|
#' Base Neural Network model class
#'
#' Reference class bundling two Keras networks built by `build.MLP`: the
#' `OPG` network (Step 1, without a reduction layer) and the `Refinement`
#' network (Step 2, with a `'reduction'` layer).
#'
#' @field config list with `input_shapes`, `d`, `output_shape` and all
#'  additional arguments passed to the constructor; it is forwarded to
#'  `build.MLP` for both networks.
#' @field nn.opg the OPG network (Step 1).
#' @field nn.ref the Refinement network (Step 2).
#' @field history.opg training history of the OPG network as a data frame,
#'  `NULL` if the network is not trained.
#' @field history.ref training history of the Refinement network as a data
#'  frame, `NULL` if the network is not trained.
#' @field B.opg OPG estimate of the reduction matrix, `NULL` if not yet
#'  available.
#' @field B.ref refined estimate of the reduction matrix (the weights of the
#'  `'reduction'` layer), `NULL` if not yet available.
#' @field history computed on access: both training histories stacked into
#'  one data frame with additional `model` and `epoch` columns, or `NULL` if
#'  there is no OPG history.
#'
#' @examples
#' model <- nnsdr$new(
#'     input_shapes = list(x = 7L),
#'     d = 2L, hidden_units = 128L
#' )
#'
#' @import methods tensorflow
#' @export nnsdr
#' @exportClass nnsdr
nnsdr <- setRefClass('nnsdr',
    fields = list(
        config = 'list',
        nn.opg = 'ANY',
        nn.ref = 'ANY',
        history.opg = 'ANY',
        history.ref = 'ANY',
        B.opg = 'ANY',
        B.ref = 'ANY',
        history = function() {
            # No OPG history means nothing has been trained yet.
            if (is.null(.self$history.opg)) {
                return(NULL)
            }

            combined <- data.frame(
                .self$history.opg,
                model = factor('OPG'),
                epoch = seq_len(nrow(.self$history.opg))
            )

            # Append the Refinement history (if any) below the OPG history.
            if (!is.null(.self$history.ref)) {
                combined <- rbind(combined, data.frame(
                    .self$history.ref,
                    model = factor('Refinement'),
                    epoch = seq_len(nrow(.self$history.ref))
                ))
            }

            combined
        }
    ),

    methods = list(
        # Constructor: stores the configuration and builds both networks.
        # Arguments in `...` are stored in `config` and thereby forwarded to
        # `build.MLP`.
        initialize = function(input_shapes, d, output_shape = 1L, ...) {
            # Set configuration.
            .self$config <- c(list(
                input_shapes = input_shapes,
                d = as.integer(d),
                output_shape = output_shape
            ), list(...))

            # Build OPG (Step 1) and Refinement (Step 2) Neural Networks
            .self$nn.opg <- do.call(build.MLP, c(.self$config, list(
                name = 'OPG', add_reduction = FALSE
            )))
            .self$nn.ref <- do.call(build.MLP, c(.self$config, list(
                name = 'Refinement', add_reduction = TRUE
            )))

            # Set initial history field values. If and only if the `history.*`
            # fields are `NULL`, then the Nets are NOT trained.
            .self$history.opg <- NULL
            .self$history.ref <- NULL

            # Set (not yet available) reduction estimates
            .self$B.opg <- NULL
            .self$B.ref <- NULL
        },

        # Fits the OPG network (unless already trained), computes the OPG
        # estimate of the reduction and fits (or continues fitting) the
        # Refinement network.
        #
        # `epochs` and `batch_size` may have two elements: the first is used
        # for the OPG network, the last for the Refinement network.
        # `initializer = 'fromOPG'` additionally initializes the hidden and
        # output layers of an untrained Refinement network from the OPG
        # network. `...` is forwarded to both Keras `fit` calls.
        # Returns `NULL` invisibly.
        fit = function(inputs, output, epochs = 1L, batch_size = 32L,
            initializer = c('random', 'fromOPG'), ..., verbose = 0L
        ) {
            # Always work with a list of `float32` tensors.
            if (is.list(inputs)) {
                inputs <- Map(tf$cast, as.list(inputs), dtype = 'float32')
            } else {
                inputs <- list(tf$cast(inputs, dtype = 'float32'))
            }
            initializer <- match.arg(initializer)

            # Check for OPG history (Step 1), if available skip it.
            if (is.null(.self$history.opg)) {
                # Fit OPG Net and store training history.
                fit.opg <- .self$nn.opg$fit(inputs, output, ...,
                    epochs = as.integer(head(epochs, 1)),
                    batch_size = as.integer(head(batch_size, 1)),
                    verbose = as.integer(verbose)
                )
                .self$history.opg <- as.data.frame(fit.opg$history)
            } else if (verbose > 0) {
                cat("OPG already trained -> skip OPG training.\n")
            }

            # Compute OPG estimate of the Reduction matrix 'B'.
            # Always compute, different inputs change the estimate.
            with(tf$GradientTape() %as% tape, {
                tape$watch(inputs[[1]])
                out <- .self$nn.opg(inputs)
            })
            # Gradients of the OPG net output with respect to the first input;
            # the leading `d` eigenvectors of their covariance form `B`.
            G <- as.matrix(tape$gradient(out, inputs[[1]]))
            B <- eigen(var(G), symmetric = TRUE)$vectors
            B <- B[, seq_len(.self$config$d), drop = FALSE]
            .self$B.opg <- B

            # Check for need to initialize the Refinement Net.
            if (is.null(.self$history.ref)) {
                # Set Reduction layer
                .self$nn.ref$get_layer('reduction')$set_weights(list(B))

                # Check initialization (for random keep random initialization)
                if (initializer == 'fromOPG') {
                    # Initialize Refinement Net weights from the OPG Net.
                    # The rows of the first hidden kernel belonging to the
                    # first input are projected by `t(B)`, all other rows are
                    # taken over unchanged.
                    W <- as.array(.self$nn.opg$get_layer('hidden1')$kernel)
                    first.rows <- seq_len(nrow(B))
                    W <- rbind(
                        t(B) %*% W[first.rows, , drop = FALSE],
                        W[-first.rows, , drop = FALSE]
                    )
                    b <- as.array(.self$nn.opg$get_layer('hidden1')$bias)
                    .self$nn.ref$get_layer('hidden1')$set_weights(list(W, b))
                    # Get layer names with weights to be initialized from `nn.opg`
                    # These are the output layer and all hidden layers except the first
                    layer.names <- Filter(function(name) {
                        if (name == 'output') {
                            TRUE
                        } else if (name == 'hidden1') {
                            FALSE
                        } else {
                            startsWith(name, 'hidden')
                        }
                    }, lapply(.self$nn.opg$layers, `[[`, 'name'))
                    # Copy `nn.opg` weights to `nn.ref`
                    for (name in layer.names) {
                        .self$nn.ref$get_layer(name)$set_weights(lapply(
                            .self$nn.opg$get_layer(name)$weights, as.array
                        ))
                    }
                }
            } else if (verbose > 0) {
                cat("Refinement Net already trained -> continue training.\n")
            }

            # Fit (or continue fitting) the Refinement Net.
            fit.ref <- .self$nn.ref$fit(inputs, output, ...,
                epochs = as.integer(tail(epochs, 1)),
                batch_size = as.integer(tail(batch_size, 1)),
                verbose = as.integer(verbose)
            )
            # Append to a possibly existing Refinement history.
            .self$history.ref <- rbind(
                .self$history.ref,
                as.data.frame(fit.ref$history)
            )
            # Extract refined reduction estimate
            .self$B.ref <- .self$nn.ref$get_layer('reduction')$get_weights()[[1]]

            invisible(NULL)
        },

        # Predicts with the Refinement network. Returns an array, or a list
        # of arrays if the network returns more than one output.
        predict = function(inputs) {
            # Issue warning if the Refinement model (Step 2) used for prediction
            # is not trained.
            if (is.null(.self$history.ref)) {
                warning('Refinement model not trained.')
            }

            if (is.list(inputs)) {
                inputs <- Map(tf$cast, as.list(inputs), dtype = 'float32')
            } else {
                inputs <- list(tf$cast(inputs, dtype = 'float32'))
            }
            output <- .self$nn.ref(inputs)

            if (is.list(output)) {
                if (length(output) == 1L) {
                    as.array(output[[1]])
                } else {
                    Map(as.array, output)
                }
            } else {
                as.array(output)
            }
        },

        # Returns the stored reduction estimate of the requested type
        # (`B.ref` for 'Refinement', `B.opg` for 'OPG').
        coef = function(type = c('Refinement', 'OPG')) {
            type <- match.arg(type)
            if (type == 'Refinement') {
                .self$B.ref
            } else {
                .self$B.opg
            }
        },

        # Resets the Refinement network, and with `reset = 'both'` also the
        # OPG network: weights are reinitialized, the optimizer is reset and
        # the stored history and reduction estimate are cleared.
        reset = function(reset = c('both', 'Refinement')) {
            reset <- match.arg(reset)
            if (reset == 'both') {
                reinitialize_weights(.self$nn.opg)
                reset_optimizer(.self$nn.opg$optimizer)
                .self$history.opg <- NULL
                .self$B.opg <- NULL
            }
            reinitialize_weights(.self$nn.ref)
            reset_optimizer(.self$nn.ref$optimizer)
            .self$history.ref <- NULL
            .self$B.ref <- NULL
        },

        # Calls the Keras `summary` method of the OPG and then the Refinement
        # network, separated by an empty line.
        summary = function() {
            .self$nn.opg$summary()
            cat('\n')
            .self$nn.ref$summary()
        }
    )
)
|
||
|
# Lock the `config` field of the generator. Per the R Reference Classes
# documentation a locked field may be assigned only once per object, which
# here happens in `initialize`.
nnsdr$lock('config')
|