70 lines
2.0 KiB
R
70 lines
2.0 KiB
R
#' Trace of the difference/sum of two Kronecker products
#'
#' Computes
#'
#' tr(A1 %x% ... %x% Ar - B1 %x% ... %x% Br)
#'
#' or, for `sign == +1`,
#'
#' tr(A1 %x% ... %x% Ar + B1 %x% ... %x% Br)
#'
#' without forming any Kronecker product. Since the trace is linear and
#' tr(X %x% Y) = tr(X) tr(Y), the result equals
#' `prod(tr(Ai)) + sign * prod(tr(Bi))`, so only the traces of the individual
#' factors are needed.
#'
#' @param A list of the left-hand Kronecker factors `A1, ..., Ar`, or a single
#'  matrix (treated as a one-factor product).
#' @param B list of the right-hand Kronecker factors `B1, ..., Bs`, or a single
#'  matrix. It need not have the same number of factors as `A`.
#' @param sign numeric scalar multiplying the trace of the `B` product; `-1`
#'  (default) gives the trace of the difference, `+1` the trace of the sum.
#'
#' @return numeric scalar, the trace described above.
#'
#' @examples
#' A <- matrix(rnorm(5^2), 5)
#' B <- matrix(rnorm(5^2), 5)
#' stopifnot(all.equal(
#'     dist.kron.tr(list(A), list(B)),
#'     sum(diag(A - B))
#' ))
#' stopifnot(all.equal(
#'     dist.kron.tr(list(A), list(B), +1),
#'     sum(diag(A + B))
#' ))
#'
#' A1 <- matrix(rnorm(5^2), 5)
#' B1 <- matrix(rnorm(5^2), 5)
#' A2 <- matrix(rnorm(7^2), 7)
#' B2 <- matrix(rnorm(7^2), 7)
#'
#' stopifnot(all.equal(
#'     dist.kron.tr(list(A1, A2), list(B1, B2), -1),
#'     sum(diag(kronecker(A1, A2) - kronecker(B1, B2)))
#' ))
#'
#' stopifnot(all.equal(
#'     dist.kron.tr(list(A1, A2), list(B1, B2), +1),
#'     sum(diag(kronecker(A1, A2) + kronecker(B1, B2)))
#' ))
#'
#' p <- c(5, 3, 7, 2)
#' As <- Map(function(pj) matrix(rnorm(pj^2), pj), p)
#' Bs <- Map(function(pj) matrix(rnorm(pj^2), pj), p)
#' stopifnot(all.equal(
#'     dist.kron.tr(As, Bs),
#'     sum(diag(Reduce(kronecker, As) - Reduce(kronecker, Bs)))
#' ))
#' stopifnot(all.equal(
#'     dist.kron.tr(As, Bs, +1),
#'     sum(diag(Reduce(kronecker, As) + Reduce(kronecker, Bs)))
#' ))
#'
#' @export
dist.kron.tr <- function(A, B, sign = -1) {
    # Trace of a single Kronecker factor. `diag()` applied to a length-one
    # non-matrix builds an identity matrix of that size (e.g. `diag(2.5)` is
    # the 2 x 2 identity), so a scalar factor is treated as a 1 x 1 matrix.
    factor.tr <- function(C) {
        if (!is.matrix(C) && length(C) == 1L) {
            return(as.numeric(C))
        }
        sum(diag(C))
    }

    # Accept a bare matrix (or scalar) as a one-factor Kronecker product.
    if (!is.list(A)) A <- list(A)
    if (!is.list(B)) B <- list(B)
    stopifnot(length(A) > 0L, length(B) > 0L)

    trA <- vapply(A, factor.tr, numeric(1))
    trB <- vapply(B, factor.tr, numeric(1))

    # tr(A1 %x% ... %x% Ar) = tr(A1) * ... * tr(Ar), and likewise for B;
    # the closed form avoids the cancellation of a recursive expansion.
    prod(trA) + sign * prod(trB)
}
|