#' Trace of the difference/sum of two Kronecker products
#'
#' Computes
#'
#'     tr(A1 %x% ... %x% Ar - B1 %x% ... %x% Bs)
#'
#' or, for `sign == +1`,
#'
#'     tr(A1 %x% ... %x% Ar + B1 %x% ... %x% Bs)
#'
#' without forming either Kronecker product. Because the trace is linear and
#' `tr(X %x% Y) == tr(X) * tr(Y)`, the result equals
#' `prod(tr(Ai)) + sign * prod(tr(Bj))`.
#'
#' @param A A square matrix, a numeric scalar (treated as a 1x1 matrix), or a
#'   non-empty list of such factors of the left Kronecker product.
#' @param B Same as `A`, for the right Kronecker product. `A` and `B` may
#'   consist of a different number of factors.
#' @param sign Coefficient applied to the `B` term: `-1` (default) gives the
#'   trace of the difference, `+1` the trace of the sum.
#'
#' @return A scalar, the trace described above.
#'
#' @examples
#' A <- matrix(rnorm(5^2), 5)
#' B <- matrix(rnorm(5^2), 5)
#' stopifnot(all.equal(
#'   dist.kron.tr(list(A), list(B)),
#'   sum(diag(A - B))
#' ))
#' stopifnot(all.equal(
#'   dist.kron.tr(list(A), list(B), +1),
#'   sum(diag(A + B))
#' ))
#'
#' A1 <- matrix(rnorm(5^2), 5)
#' B1 <- matrix(rnorm(5^2), 5)
#' A2 <- matrix(rnorm(7^2), 7)
#' B2 <- matrix(rnorm(7^2), 7)
#'
#' stopifnot(all.equal(
#'   dist.kron.tr(list(A1, A2), list(B1, B2), -1),
#'   sum(diag(kronecker(A1, A2) - kronecker(B1, B2)))
#' ))
#'
#' stopifnot(all.equal(
#'   dist.kron.tr(list(A1, A2), list(B1, B2), +1),
#'   sum(diag(kronecker(A1, A2) + kronecker(B1, B2)))
#' ))
#'
#' # The two products may have a different number of factors
#' B12 <- matrix(rnorm(35^2), 35)
#' stopifnot(all.equal(
#'   dist.kron.tr(list(A1, A2), B12),
#'   sum(diag(kronecker(A1, A2) - B12))
#' ))
#'
#' p <- c(5, 3, 7, 2)
#' As <- Map(function(pj) matrix(rnorm(pj^2), pj), p)
#' Bs <- Map(function(pj) matrix(rnorm(pj^2), pj), p)
#' stopifnot(all.equal(
#'   dist.kron.tr(As, Bs),
#'   sum(diag(Reduce(kronecker, As) - Reduce(kronecker, Bs)))
#' ))
#' stopifnot(all.equal(
#'   dist.kron.tr(As, Bs, +1),
#'   sum(diag(Reduce(kronecker, As) + Reduce(kronecker, Bs)))
#' ))
#'
#' @export
dist.kron.tr <- function(A, B, sign = -1) {
  # Trace of a single factor. `as.matrix()` makes a bare scalar a 1x1 matrix;
  # `diag()` of a plain number would instead build an identity matrix of
  # that size and return the wrong trace.
  factor_trace <- function(C) {
    C <- as.matrix(C)
    stopifnot(nrow(C) == ncol(C))
    sum(diag(C))
  }

  # Trace of a Kronecker product given as one factor or a list of factors:
  # tr(X1 %x% ... %x% Xr) = tr(X1) * ... * tr(Xr).
  # `unlist(lapply())` rather than `vapply(numeric(1))` so complex factors
  # keep working.
  kron_trace <- function(X) {
    if (!is.list(X)) {
      X <- list(X)
    }
    stopifnot(length(X) > 0)
    prod(unlist(lapply(X, factor_trace)))
  }

  # Linearity of the trace: tr(KA + sign * KB) = tr(KA) + sign * tr(KB).
  kron_trace(A) + sign * kron_trace(B)
}