add: custom metrics support,
add: MAVE subspace distance measure
commit c3757cb055
parent 69a008535b
@@ -4,6 +4,7 @@ S3method(coef,nnsdr)
 S3method(summary,nnsdr)
 export(dataset)
 export(dist.grassmann)
+export(dist.mave)
 export(dist.subspace)
 export(get.script)
 export(nnsdr)
@@ -12,6 +13,7 @@ export(reinitialize_weights)
 export(reset_optimizer)
 exportClasses(nnsdr)
 import(methods)
+import(reticulate)
 import(stats)
 import(tensorflow)
 importFrom(stats,rbinom)
@@ -0,0 +1,27 @@
+#' Subspace distance mentioned in [Xia et al., 2002] (the first MAVE paper).
+#'
+#' @param A,B Basis matrices (assumed full rank) representing elements of
+#'  the Grassmann manifold.
+#' @param is.ortho Boolean specifying whether `A` and `B` are already
+#'  semi-orthogonal. If false, a QR decomposition is used to orthogonalize both.
+#'
+#' @seealso
+#'  Y. Xia, H. Tong, W. K. Li and L. Zhu (2002) "An adaptive estimation of
+#'  dimension reduction space" <DOI:10.1111/1467-9868.03411>
+#'
+#' @export
+dist.mave <- function(A, B, is.ortho = FALSE) {
+    if (!is.matrix(A)) A <- as.matrix(A)
+    if (!is.matrix(B)) B <- as.matrix(B)
+
+    if (!is.ortho) {
+        A <- qr.Q(qr(A))
+        B <- qr.Q(qr(B))
+    }
+
+    if (ncol(A) < ncol(B)) {
+        norm((diag(nrow(A)) - tcrossprod(B, B)) %*% A, 'F')
+    } else {
+        norm((diag(nrow(A)) - tcrossprod(A, A)) %*% B, 'F')
+    }
+}
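
A quick usage sketch for the new `dist.mave()` helper (the random bases below are illustrative only, not package data):

    # Two random 5 x 2 bases, almost surely spanning different subspaces.
    A <- qr.Q(qr(matrix(rnorm(10), 5, 2)))
    B <- qr.Q(qr(matrix(rnorm(10), 5, 2)))
    dist.mave(A, B, is.ortho = TRUE)        # between 0 and at most sqrt(2) here
    dist.mave(A, A, is.ortho = TRUE)        # 0 (up to numerical error)
    dist.mave(matrix(rnorm(10), 5, 2), B)   # non-orthogonal input: QR applied internally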
@@ -10,18 +10,20 @@ Sys.setenv(TF_CPP_MIN_LOG_LEVEL = "3")
 #' @param add_reduction TODO:
 #' @param hidden_units TODO:
 #' @param activation TODO:
+#' @param output_activation TODO:
 #' @param dropout TODO:
 #' @param loss TODO:
 #' @param optimizer TODO:
 #' @param metrics TODO:
 #' @param trainable_reduction TODO:
 #'
-#' @import tensorflow
+#' @import reticulate tensorflow
 #' @keywords internal
 build.MLP <- function(input_shapes, d, name, add_reduction,
                       output_shape = 1L,
                       hidden_units = 512L,
                       activation = 'relu',
+                      output_activation = NULL,
                       dropout = 0.4,
                       loss = 'MSE',
                       optimizer = 'RMSProp',
@@ -63,9 +65,25 @@ build.MLP <- function(input_shapes, d, name, add_reduction,
         if (dropout > 0)
             out <- K$layers$Dropout(rate = dropout, name = paste0('dropout', i))(out)
     }
-    out <- K$layers$Dense(units = output_shape, name = 'output')(out)
+    out <- K$layers$Dense(units = output_shape, activation = output_activation,
+                          name = 'output')(out)
+
     mlp <- K$models$Model(inputs = inputs, outputs = out, name = name)
 
+    if (!is.null(metrics)) {
+        metrics <- as.list(metrics)
+        for (i in seq_along(metrics)) {
+            metric <- metrics[[i]]
+            if (all(c("nnsdr.metric", name) %in% class(metric))) {
+                metric_fn <- reticulate::py_func(metric(mlp))
+                reticulate::py_set_attr(metric_fn, "__name__", attr(metric, "name"))
+                metrics[[i]] <- metric_fn
+            } else if ("nnsdr.metric" %in% class(metric)) {
+                metrics[[i]] <- NULL # Drop: custom metric not tagged for this model.
+            }
+        }
+    }
+
     mlp$compile(loss = loss, optimizer = optimizer, metrics = metrics)
 
     mlp
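
Judging from the wrapping loop above, a custom metric is an R closure classed as `c('nnsdr.metric', <model name>)` with a `"name"` attribute: it receives the built model and returns a Keras-style `(y_true, y_pred)` function. A hypothetical sketch only; the metric name `dist.B`, the model tag `'OPG'`, and the constant return value are placeholders, not part of the package:

    # Placeholder metric: ignores its inputs and reports a constant tensor,
    # but has the shape build.MLP expects for per-model custom metrics.
    dist.B <- function(model) {
        function(y_true, y_pred) {
            # A real metric would inspect `model` (e.g. its reduction weights).
            tensorflow::tf$constant(0, dtype = 'float32')
        }
    }
    class(dist.B) <- c('nnsdr.metric', 'OPG')
    attr(dist.B, 'name') <- 'dist.B'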
@@ -95,18 +113,27 @@ nnsdr <- setRefClass('nnsdr',
         if (is.null(.self$history.opg))
             return(NULL)
 
-        history <- data.frame(
-            .self$history.opg,
-            model = factor('OPG'),
-            epoch = seq_len(nrow(.self$history.opg))
-        )
-
-        if (!is.null(.self$history.ref))
-            history <- rbind(history, data.frame(
-                .self$history.ref,
-                model = factor('Refinement'),
-                epoch = seq_len(nrow(.self$history.ref))
-            ))
+        if (is.null(.self$history.ref)) {
+            history <- data.frame(
+                .self$history.opg,
+                model = factor('OPG')
+            )
+        } else {
+            hist.opg <- data.frame(
+                .self$history.opg,
+                model = factor('OPG')
+            )
+            hist.ref <- data.frame(
+                .self$history.ref,
+                model = factor('Refinement')
+            )
+            # Augment mutually exclusive columns
+            hist.opg[setdiff(names(hist.ref), names(hist.opg))] <- NA
+            hist.ref[setdiff(names(hist.opg), names(hist.ref))] <- NA
+            # Combine/Bind the padded histories
+            history <- rbind(hist.opg, hist.ref)
+        }
+        history$epoch <- seq_len(nrow(history))
 
         history
     }
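
A standalone illustration of the column-padding idiom used above, with toy data frames in place of the actual Keras training histories:

    # The OPG history records a column the refinement history lacks, and vice versa.
    hist.opg <- data.frame(loss = c(1.0, 0.6), lr = c(0.1, 0.1), model = factor('OPG'))
    hist.ref <- data.frame(loss = c(0.4, 0.3), val_loss = c(0.5, 0.45), model = factor('Refinement'))
    # Pad each data frame with the columns only the other one has ...
    hist.opg[setdiff(names(hist.ref), names(hist.opg))] <- NA
    hist.ref[setdiff(names(hist.opg), names(hist.ref))] <- NA
    # ... so that rbind() succeeds despite differing metric columns,
    # then number the epochs over the combined history.
    history <- rbind(hist.opg, hist.ref)
    history$epoch <- seq_len(nrow(history))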
|
||||||
|
|
Loading…
Reference in New Issue