add: wip tensorflow implementation
This commit is contained in:
parent
70ceccb599
commit
4a950d6df2
|
@ -0,0 +1,169 @@
|
||||||
|
library(CVE)
|
||||||
|
library(reticulate)
|
||||||
|
library(tensorflow)
|
||||||
|
|
||||||
|
#' Null space basis of given matrix `V`
|
||||||
|
#'
|
||||||
|
#' @param V `(p, q)` matrix
|
||||||
|
#' @return Semi-orthogonal `(p, p - q)` matrix spaning the null space of `V`.
|
||||||
|
#' @keywords internal
|
||||||
|
#' @export
|
||||||
|
null <- function(V) {
|
||||||
|
tmp <- qr(V)
|
||||||
|
set <- if(tmp$rank == 0L) seq_len(ncol(V)) else -seq_len(tmp$rank)
|
||||||
|
return(qr.Q(tmp, complete = TRUE)[, set, drop = FALSE])
|
||||||
|
}
|
||||||
|
|
||||||
|
subspace_dist <- function(A, B) {
|
||||||
|
P <- A %*% solve(t(A) %*% A, t(A))
|
||||||
|
Q <- B %*% solve(t(B) %*% B, t(B))
|
||||||
|
norm(P - Q, 'F') / sqrt(ncol(A) + ncol(B))
|
||||||
|
}
|
||||||
|
|
||||||
|
estimate.bandwidth <- function (X, k, nObs = sqrt(nrow(X)), version = 1L) {
|
||||||
|
n <- nrow(X)
|
||||||
|
p <- ncol(X)
|
||||||
|
X_c <- scale(X, center = TRUE, scale = FALSE)
|
||||||
|
|
||||||
|
if (version == 1) {
|
||||||
|
(2 * sum(X_c^2) / (n * p)) * (1.2 * n^(-1 / (4 + k)))^2
|
||||||
|
} else if (version == 2) {
|
||||||
|
2 * qchisq((nObs - 1) / (n - 1), k) * sum(X_c^2) / (n * p)
|
||||||
|
} else {
|
||||||
|
stop("Unknown version.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tf_Variable <- function(obj, dtype = "float32", ...) {
|
||||||
|
tf$Variable(obj, dtype = dtype, ...)
|
||||||
|
}
|
||||||
|
tf_constant <- function(obj, dtype = "float32", ...) {
|
||||||
|
tf$constant(obj, dtype = dtype, ...)
|
||||||
|
}
|
||||||
|
|
||||||
|
cve.tf <- function(X, Y, k, h = estimate.bandwidth(X, k, sqrt(nrow(X))),
|
||||||
|
V.init = NULL, optimizer_initialier = tf$optimizers$RMSprop, attempts = 10L,
|
||||||
|
sd_noise = 0, method = c("simple", "weighted")
|
||||||
|
) {
|
||||||
|
method <- match.arg(method)
|
||||||
|
|
||||||
|
`-0.5` <- tf_constant(-0.5)
|
||||||
|
`1` <- tf_constant(1)
|
||||||
|
`2` <- tf_constant(2)
|
||||||
|
n <- nrow(X)
|
||||||
|
p <- ncol(X)
|
||||||
|
k <- as.integer(k)
|
||||||
|
q <- p - k
|
||||||
|
|
||||||
|
X <- tf_constant(scale(X))
|
||||||
|
Y <- tf_constant(scale(as.matrix(Y)))
|
||||||
|
I <- tf_constant(diag(1, p))
|
||||||
|
h <- tf_Variable(h)
|
||||||
|
|
||||||
|
loss <- tf_function(function(V) {
|
||||||
|
Q <- I - tf$matmul(V, V, transpose_b = TRUE)
|
||||||
|
if (sd_noise > 0)
|
||||||
|
XQ <- tf$matmul(X + tf$random$normal(list(n, p), stddev = 0.05), Q)
|
||||||
|
else
|
||||||
|
XQ <- tf$matmul(X, Q)
|
||||||
|
S <- tf$matmul(XQ, XQ, transpose_b = TRUE)
|
||||||
|
d <- tf$linalg$diag_part(S)
|
||||||
|
D <- tf$reshape(d, list(n, 1L)) + tf$reshape(d, list(1L, n)) - `2` * S
|
||||||
|
K <- tf$exp((`-0.5` / h) * tf$pow(D, 2L))
|
||||||
|
w <- tf$reduce_sum(K, 1L, keepdims = TRUE)
|
||||||
|
y1 <- tf$divide(tf$matmul(K, Y), w)
|
||||||
|
y2 <- tf$divide(tf$matmul(K, tf$pow(Y, 2L)), w)
|
||||||
|
if (method == "simple") {
|
||||||
|
l <- tf$reduce_mean(y2 - tf$pow(y1, 2L))
|
||||||
|
} else {# weighted
|
||||||
|
w <- tf$reduce_sum(K, 1L, keepdims = TRUE) - `1` # TODO: check/fix
|
||||||
|
w <- w / tf$reduce_sum(w)
|
||||||
|
l <- tf$reduce_sum(w * (y2 - tf$pow(y1, 2L)))
|
||||||
|
}
|
||||||
|
l
|
||||||
|
})
|
||||||
|
|
||||||
|
if (is.null(V.init))
|
||||||
|
V.init <- qr.Q(qr(matrix(rnorm(p * q), p, q)))
|
||||||
|
else
|
||||||
|
attempts <- 1L
|
||||||
|
V <- tf_Variable(V.init, constraint = function(w) { tf$linalg$qr(w)$q })
|
||||||
|
|
||||||
|
min.loss <- Inf
|
||||||
|
for (attempt in seq_len(attempts)) {
|
||||||
|
optimizer = optimizer_initialier()
|
||||||
|
|
||||||
|
out <- tf$while_loop(
|
||||||
|
cond = tf_function(function(i, L) i < 400L),
|
||||||
|
body = tf_function(function(i, L) {
|
||||||
|
with(tf$GradientTape() %as% tape, {
|
||||||
|
tape$watch(V)
|
||||||
|
L <- loss(V)
|
||||||
|
})
|
||||||
|
grad <- tape$gradient(L, V)
|
||||||
|
optimizer$apply_gradients(list(list(grad, V)))
|
||||||
|
|
||||||
|
list(i + 1L, L)
|
||||||
|
}),
|
||||||
|
loop_vars = list(tf_constant(0L, "int32"), tf_constant(Inf))
|
||||||
|
)
|
||||||
|
|
||||||
|
if (as.numeric(out[[2]]) < min.loss) {
|
||||||
|
min.loss <- as.numeric(out[[2]])
|
||||||
|
min.V <- as.matrix(V)
|
||||||
|
}
|
||||||
|
V$assign(qr.Q(qr(matrix(rnorm(p * q), p, q))))
|
||||||
|
}
|
||||||
|
|
||||||
|
list(B = null(min.V), V = min.V, loss = min.loss)
|
||||||
|
}
|
||||||
|
# ds <- dataset(1)
|
||||||
|
# out <- cve.call2(ds$X, ds$Y, ncol(ds$B))
|
||||||
|
|
||||||
|
plot.sim <- function(sim) {
|
||||||
|
name <- deparse(substitute(sim))
|
||||||
|
ssd <- sapply(sim, function(s) subspace_dist(s$B.true, s$B.est))
|
||||||
|
print(summary(ssd))
|
||||||
|
h <- hist(ssd, freq = FALSE, breaks = seq(0, 1, 0.1), main = name,
|
||||||
|
xlab = "Subspace Distance")
|
||||||
|
lines(density(ssd, from = 0, to = 1))
|
||||||
|
stat <- c(Median = median(ssd), Mean = mean(ssd))
|
||||||
|
abline(v = stat, lty = 2)
|
||||||
|
text(stat, 1.02 * max(h$density), names(stat),
|
||||||
|
pos = if(diff(stat) > 0) c("2", "4") else c("4", "2"))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
set.seed(42)
|
||||||
|
sim.cve <- vector("list", 100)
|
||||||
|
sim.tf1 <- vector("list", 100)
|
||||||
|
sim.tf2 <- vector("list", 100)
|
||||||
|
|
||||||
|
start <- Sys.time()
|
||||||
|
for (i in 1:100) {
|
||||||
|
ds <- dataset(1)
|
||||||
|
|
||||||
|
sim.cve[[i]] <- list(
|
||||||
|
B.est = coef(CVE::cve.call(ds$X, ds$Y, k = ncol(ds$B)), ncol(ds$B)),
|
||||||
|
B.true = ds$B
|
||||||
|
)
|
||||||
|
|
||||||
|
sim.tf1[[i]] <- list(
|
||||||
|
B.est = cve.tf(ds$X, ds$Y, ncol(ds$B))$B,
|
||||||
|
B.true = ds$B
|
||||||
|
)
|
||||||
|
|
||||||
|
sim.tf2[[i]] <- list(
|
||||||
|
B.est = cve.tf(ds$X, ds$Y, ncol(ds$B), sd_noise = 0.05)$B,
|
||||||
|
B.true = ds$B
|
||||||
|
)
|
||||||
|
|
||||||
|
cat(sprintf("\r%4d/100 -", i), format(Sys.time() - start), '\n')
|
||||||
|
}
|
||||||
|
|
||||||
|
# pdf('subspace_comp.pdf')
|
||||||
|
par(mfrow = c(3, 1))
|
||||||
|
plot.sim(sim.cve)
|
||||||
|
plot.sim(sim.tf1)
|
||||||
|
plot.sim(sim.tf2)
|
||||||
|
# dev.off()
|
Loading…
Reference in New Issue