tensor_predictors/tensorPredictors/R/dist_kron_tr.R

70 lines
2.0 KiB
R

#' Trace of the difference/sum of two Kronecker products
#'
#' Computes
#'
#' tr(A1 %x% ... %x% Ar - B1 %x% ... %x% Br)
#'
#' or for `sign == +1` it computes
#'
#' tr(A1 %x% ... %x% Ar + B1 %x% ... %x% Br)
#'
#' without forming the Kronecker products. It uses the identity
#' tr(kronecker(X1, X2)) = tr(X1) * tr(X2) for square factors, so the result
#' is prod(tr(Ai)) + sign * prod(tr(Bi)).
#'
#' @param A the factors of the left Kronecker product. Either a single
#'  square matrix or a non-empty list of square matrices `A1, ..., Ar`.
#'  A plain scalar (no `dim`) is treated as a 1 x 1 matrix.
#' @param B the factors of the right Kronecker product, in the same form as
#'  `A`. The number of factors may differ from `A`.
#' @param sign scalar multiplier of the right term. The default `-1` gives
#'  the trace of the difference and `+1` the trace of the sum.
#'
#' @return a scalar, the trace of the left product plus `sign` times the
#'  trace of the right product.
#'
#' @examples
#' A <- matrix(rnorm(5^2), 5)
#' B <- matrix(rnorm(5^2), 5)
#' stopifnot(all.equal(
#' dist.kron.tr(list(A), list(B)),
#' sum(diag(A - B))
#' ))
#' stopifnot(all.equal(
#' dist.kron.tr(list(A), list(B), +1),
#' sum(diag(A + B))
#' ))
#'
#' A1 <- matrix(rnorm(5^2), 5)
#' B1 <- matrix(rnorm(5^2), 5)
#' A2 <- matrix(rnorm(7^2), 7)
#' B2 <- matrix(rnorm(7^2), 7)
#'
#' stopifnot(all.equal(
#' dist.kron.tr(list(A1, A2), list(B1, B2), -1),
#' sum(diag(kronecker(A1, A2) - kronecker(B1, B2)))
#' ))
#'
#' stopifnot(all.equal(
#' dist.kron.tr(list(A1, A2), list(B1, B2), +1),
#' sum(diag(kronecker(A1, A2) + kronecker(B1, B2)))
#' ))
#'
#' p <- c(5, 3, 7, 2)
#' As <- Map(function(pj) matrix(rnorm(pj^2), pj), p)
#' Bs <- Map(function(pj) matrix(rnorm(pj^2), pj), p)
#' stopifnot(all.equal(
#' dist.kron.tr(As, Bs),
#' sum(diag(Reduce(kronecker, As) - Reduce(kronecker, Bs)))
#' ))
#' stopifnot(all.equal(
#' dist.kron.tr(As, Bs, +1),
#' sum(diag(Reduce(kronecker, As) + Reduce(kronecker, Bs)))
#' ))
#'
#' @export
dist.kron.tr <- function(A, B, sign = -1) {
  # Trace of a single Kronecker factor. A dimensionless scalar is a 1 x 1
  # matrix; calling `diag()` on it would instead build an identity matrix of
  # that size and give a wrong trace.
  factor.trace <- function(M) {
    if (is.null(dim(M)) && length(M) == 1L) {
      M[[1L]]
    } else {
      sum(diag(M))
    }
  }

  # Traces of all factors of one Kronecker product, given either as a single
  # factor or as a non-empty list of factors.
  factor.traces <- function(X) {
    if (!is.list(X)) {
      X <- list(X)
    }
    if (length(X) == 0L) {
      stop("empty list of Kronecker factors", call. = FALSE)
    }
    unlist(lapply(X, factor.trace))
  }

  # The trace of a Kronecker product is the product of the factor traces.
  # Evaluating it directly avoids the cancellation-prone recursive expansion.
  prod(factor.traces(A)) + sign * prod(factor.traces(B))
}