#' Matricization
#'
#' Unfolds a multi-dimensional array into a matrix. The rows index the axes
#' listed in \code{modes}, in the given order. The columns index the
#' remaining axes, in increasing order. With \code{inv = TRUE} the inverse
#' folding is performed.
#'
#' @param T multi-dimensional array, or, if \code{inv} is \code{TRUE}, the
#'  matrix produced by a previous call \code{mat(A, modes, dims)}
#' @param modes vector of distinct axis indices in \code{seq_along(dims)}
#'  along which to matricize (these axes form the rows); may be empty, in
#'  which case all axes form the columns
#' @param dims dimension of \code{T} before matricization
#' @param inv logical; if \code{TRUE}, the inverse operation is performed
#'
#' @returns matrix of dimensions \code{prod(dims[modes])} by
#'  \code{prod(dims[-modes])}, or an array of dimensions \code{dims} if
#'  \code{inv} is \code{TRUE}.
#'
#' @examples
#' stopifnot(all.equal(
#'     mat(1:12, 2, dims = c(2, 3, 2)),
#'     matrix(c(
#'         1,  2,  7,  8,
#'         3,  4,  9, 10,
#'         5,  6, 11, 12
#'     ), 3, 4, byrow = TRUE)
#' ))
#'
#' A <- array(rnorm(2 * 3 * 5), dim = c(2, 3, 5))
#' stopifnot(exprs = {
#'     all.equal(A, mat(mat(A, 1), 1, dim(A), TRUE))
#'     all.equal(A, mat(mat(A, 2), 2, dim(A), TRUE))
#'     all.equal(A, mat(mat(A, 3), 3, dim(A), TRUE))
#'     all.equal(A, mat(mat(A, c(1, 2)), c(1, 2), dim(A), TRUE))
#'     all.equal(A, mat(mat(A, c(1, 3)), c(1, 3), dim(A), TRUE))
#'     all.equal(A, mat(mat(A, c(2, 3)), c(2, 3), dim(A), TRUE))
#'
#'     all.equal(t(mat(A, 1)), mat(A, c(2, 3)))
#'     all.equal(t(mat(A, 3)), mat(A, c(1, 2)))
#' })
#'
#' @export
mat <- function(T, modes, dims = dim(T), inv = FALSE) {
    modes <- as.integer(modes)

    # Reject out-of-range (including zero / negative) and repeated modes up
    # front; otherwise they only fail later inside `aperm` with an obscure
    # message.
    stopifnot(exprs = {
        length(T) == prod(dims)
        all(1L <= modes & modes <= length(dims))
        !anyDuplicated(modes)
    })

    # Axes not in `modes`, in increasing order. `setdiff` is used instead of
    # `seq_along(dims)[-modes]` because negative indexing with an empty
    # `modes` selects nothing rather than everything.
    rest <- setdiff(seq_along(dims), modes)
    perm <- c(modes, rest)

    if (inv) {
        # `T` holds the elements of the permuted array: fold it back into
        # that shape, undo the permutation, then restore the original dims.
        dim(T) <- dims[perm]
        T <- aperm(T, order(perm))
        dim(T) <- dims
    } else {
        # Bring the `modes` axes to the front, then flatten them into rows
        # and the remaining axes into columns.
        dim(T) <- dims
        T <- aperm(T, perm)
        dim(T) <- c(prod(dims[modes]), prod(dims[rest]))
    }

    T
}


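# NOTE: superseded by `mat(T, modes, dims, inv = TRUE)`; the disabled code
# below is kept for reference only.
# #' Inverse Matricization
# #'
# #' @param T matrix of dimensions \code{prod(dims[modes])} by \code{prod(dims[-modes])}
# #' @param modes axis indices used in the original matricization
# #' @param dims dimension of the original tensor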
# #'
# #' @returns multi-dimensional array of dimensions \code{dims}
# #'
# #' @examples
# #' p <- c(2, 3, 5)
# #' A <- array(rnorm(prod(p)), dim = p)
# #' stopifnot(expr = {
# #'     all.equal(A, mat.inv(mat(A, 1), 1, p))
# #'     all.equal(A, mat.inv(mat(A, 2), 2, p))
# #'     all.equal(A, mat.inv(mat(A, 3), 3, p))
# #' })
# #'
# #' @export
# mat.inv <- function(T, modes, dims) {
#     modes <- as.integer(modes)

#     stopifnot(exprs = {
#         length(T) == prod(dims)
#         any(length(dims) < modes)
#     })

#     dim(T) <- c(dims[modes], dims[-modes])
#     T <- aperm(T, order(c(modes, seq_along(dims)[-modes])))
#     dim(T) <- c(prod(dims[modes]), prod(dims[-modes]))

#     T
# }