289 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
			
		
		
	
	
			289 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
#' Specialized version of the GMLM for the Ising model (inverse Ising problem)
#'
#' Fits a generalized multi-linear model for (binary) tensor valued observations
#' `X` given tensor valued covariates `F = F(y)`. The mode-wise linear effect
#' parameters `betas` are initialized from a tensor normal fit
#' (`gmlm_tensor_normal()`) and the mode-wise interaction parameters `Omegas`
#' from marginal moment based log odds. Both are then refined by RMSprop style
#' gradient steps on the (sliced) negative log-likelihood.
#'
#' @param X array of observations, the axis `sample.axis` enumerates the
#'  observations.
#' @param F either a function which is applied to `y` to compute the covariate
#'  tensor, an array sharing the sample axis with `X`, or a plain vector with
#'  one (univariate) value per observation.
#' @param y response, passed to `F` if `F` is a function. Also used as slicing
#'  variable if its length equals the sample size.
#' @param sample.axis axis of `X` (and `F`) enumerating the observations.
#' @param proj.betas,proj.Omegas optional lists with one projection function (or
#'  `NULL`) per mode, applied to the respective parameter after every update.
#' @param Omega.mask optional index into the full (Kronecker product)
#'  interaction matrix. Masked entries are set to zero and the corresponding
#'  second order residuals are ignored.
#' @param max.iter maximum number of iterations.
#' @param eps relative tolerance of the loss change used as break condition,
#'  also stabilizes the RMSprop step size denominator.
#' @param step.size RMSprop learning rate.
#' @param zig.zag.threashold break if the accumulated alternating signs of the
#'  loss differences exceed this threshold (in absolute value).
#' @param patience break after this many (net) non-improving iterations.
#' @param nr.slices number of slices for a continuous slicing variable.
#' @param slice.method slicing of a continuous slicing variable, `"cut"` uses
#'  equal width intervals of its values, `"ecdf"` equal width intervals of its
#'  empirical CDF values and `"none"` turns slicing off. Discrete slicing
#'  variables (factor, integer, ...) are sliced by their values.
#' @param use_MC,nr_threads passed to `ising_m2()`.
#' @param logger function called every iteration (before the parameter update)
#'  as `logger(iter, loss, betas, Omegas, grad_betas, grad_Omegas)`.
#'
#' @return list with components `eta1`, `betas` and `Omegas` and attributes
#'  `tensor_normal` (the initial tensor normal fit), `Omegas.init`,
#'  `degen.mask` (linear indices of degenerate interactions) and `iter`.
#'
#' @todo TODO: Add beta and Omega projections
#'
#' @export
gmlm_ising <- function(X, F, y = NULL, sample.axis = length(dim(X)),
    proj.betas = NULL, proj.Omegas = NULL, Omega.mask = NULL,
    max.iter = 1000L,
    eps = sqrt(.Machine$double.eps),
    step.size = 1e-3,
    zig.zag.threashold = 20L,
    patience = 3L,
    nr.slices = 20L,                            # only for a continuous slicing variable
    slice.method = c("cut", "ecdf", "none"),    # "cut"/"ecdf" only for a continuous slicing variable
    use_MC = 20L <= prod(dim(X)[-sample.axis]),
    nr_threads = 8L,                            # ignored if `use_MC` is `FALSE`
    logger = function(...) { }
) {
    # Get problem dimensions (`sample.size` is needed to shape a vector `F`)
    dimX <- dim(X)[-sample.axis]
    sample.size <- dim(X)[sample.axis]

    # compute `F(y)`, replace function `F` with its result
    if (is.function(F)) {
        F <- F(y)
    }
    if (is.null(dim(F))) {
        # treat univariate `F` (one value per observation) as a tensor
        dimF <- rep(1L, length(dimX))
        dim(F) <- ifelse(seq_along(dim(X)) == sample.axis, sample.size, 1L)
    } else {
        # `F` already provided as tensor
        dimF <- dim(F)[-sample.axis]
    }

    # rearrange `X`, `F` such that the last axis enumerates observations
    if (sample.axis != length(dim(X))) {
        axis.perm <- c(seq_along(dim(X))[-sample.axis], sample.axis)
        X <- aperm(X, axis.perm)
        F <- aperm(F, axis.perm)
        sample.axis <- length(dim(X))
    }
    modes <- seq_along(dimX)

    # Ensure the Omega and beta projections are lists (one entry per mode)
    # (note: `rep(NULL, n)` is `NULL`, not a list)
    if (!is.list(proj.Omegas)) {
        proj.Omegas <- vector("list", length(modes))
    }
    if (!is.list(proj.betas)) {
        proj.betas <- vector("list", length(modes))
    }

    # Special case for univariate response `y` or univariate `F = F(y)`
    # Due to high computational costs we use slicing
    slice.method <- match.arg(slice.method)
    # default: every observation is its own slice
    slices.ind <- seq_len(sample.size)
    if (slice.method != "none") {
        # get slicing variable, either by providing `y` or if `F` is univariate
        y <- if (length(y) == sample.size) {
            # keep factors, `as.vector()` would turn them into character vectors
            if (is.factor(y)) y else as.vector(y)
        } else if (length(F) == sample.size) {
            as.vector(F)
        } else {
            NULL    # couldn't find univariate variable to slice
        }

        if (!is.null(y)) {
            # continuous slicing variables are discretized, discrete ones
            # (factor, integer, ...) are sliced by their values
            if (is.double(y)) {
                y <- cut(if (slice.method == "ecdf") ecdf(y)(y) else y, nr.slices)
            }
            slices.ind <- split(seq_len(sample.size), y, drop = TRUE)
        }
    }

    # initialize betas with tensor normal estimate (ignoring data being binary)
    # (do NOT use the Omega projections, the tensor normal `Omegas` do not match
    # the interpretation of the Ising model `Omegas`)
    fit_normal <- gmlm_tensor_normal(X, F, sample.axis = length(dim(X)),
        proj.betas = proj.betas)
    betas <- fit_normal$betas

    # initialize Omegas by mode-wise log odds of the empirical second moments
    # (empirical probabilities of 0 or 1 are clamped to 1/n and (n-1)/n)
    Omegas <- Omegas.init <- Map(function(mode) {
        n <- prod(dim(X)[-mode])
        prob2 <- mcrossprod(X, mode = mode) / n
        prob2[prob2 == 0] <- 1 / n
        prob2[prob2 == 1] <- (n - 1) / n
        prob1 <- diag(prob2)
        `prob1^2` <- outer(prob1, prob1)

        `diag<-`(log(((1 - `prob1^2`) / `prob1^2`) * prob2 / (1 - prob2)), 0)
    }, modes)

    # Project `Omegas` onto their respective manifolds (`betas` already handled)
    for (j in modes) {
        if (is.function(proj_j <- proj.Omegas[[j]])) {
            Omegas[[j]] <- proj_j(Omegas[[j]])
        }
    }

    # Determine degenerate combinations, that are variables which are exclusive
    # in the data set (for binary `X`: never jointly equal to one)
    matX <- mat(X, sample.axis)
    degen <- crossprod(matX) == 0
    degen.mask <- which(degen)
    has.degen <- length(degen.mask) > 0L
    # If there are degenerate combinations, compute an (arbitrary) bound of the
    # log odds parameters of those combinations
    if (has.degen) {
        degen.ind <- arrayInd(degen.mask, dim(degen))
        meanX <- colMeans(matX)
        prodX <- meanX[degen.ind[, 1]] * meanX[degen.ind[, 2]]
        degen.bounds <- log((1 - prodX) / (prodX * sample.size))
        # Component indices in Omegas of degenerate two-way interactions
        degen.ind <- arrayInd(degen.mask, rep(dimX, 2))
        degen.ind <- lapply(modes, function(m) {
            degen.ind[, m] + dimX[m] * (degen.ind[, m + length(dimX)] - 1L)
        })

        # Shrinks the magnitudes (signs are kept) of the degenerate interaction
        # parameters such that, per degenerate combination, the sum of the
        # log-magnitudes over all modes does not exceed `log(abs(bound))`.
        enforce_degen_constraints <- function(Omegas) {
            # Extract parameters corresponding to degenerate interactions
            # (one row per mode, one column per degenerate combination)
            degen.params <- do.call(rbind, Map(`[`, Omegas, degen.ind))
            # Degeneracy Constrained Parameters (sign is dropped); `cbind`
            # guarantees a matrix, `mapply` would simplify a single mode to a
            # vector and break `row(DCP)` below
            DCP <- do.call(cbind, Map(function(vals, bound) {
                logVals <- log(abs(vals))
                err <- max(0, sum(logVals) - log(abs(bound)))
                exp(logVals - (err / length(vals)))
            }, split(degen.params, col(degen.params)), degen.bounds))
            # Update values in Omegas such that all degeneracy constraints hold
            Map(function(Omega, cp, ind) {
                # Combine multiple constraints for every element into single
                # constraint value per element (the most restrictive one)
                cp <- vapply(split(abs(cp), ind), min, numeric(1))
                ind <- as.integer(names(cp))
                `[<-`(Omega, ind, sign(Omega[ind]) * cp)
            }, Omegas, split(DCP, row(DCP)), degen.ind)
        }

        ## Enforce initial value degeneracy interaction param. constraints
        Omegas <- enforce_degen_constraints(Omegas)
    }

    # RMSprop: running mean of squared gradients and the scaled update step
    rms_accum <- function(g2, g) 0.9 * g2 + 0.1 * (g * g)
    rms_step <- function(param, grad, m2) {
        param + (step.size / (sqrt(m2) + eps)) * grad
    }

    # Initialize mean squared gradients
    grad2_betas  <- Map(array, 0, Map(dim, betas))
    grad2_Omegas <- Map(array, 0, Map(dim, Omegas))

    # Keep track of the last loss to accumulate loss difference sign changes
    # indicating optimization instabilities as a sign to stop
    last_loss <- Inf
    accum_sign <- 1

    # non improving iteration counter
    non_improving <- 0L

    # technical access points to dynamically access a multi-dimensional array
    `X[..., i]` <- slice.expr(X, sample.axis, index = i, drop = FALSE)
    `F[..., i]` <- slice.expr(F, sample.axis, index = i, drop = FALSE)

    # Iterate till a break condition triggers or till max. nr. of iterations
    for (iter in seq_len(max.iter)) {

        grad_betas <- Map(matrix, 0, dimX, dimF)
        Omega <- Reduce(kronecker, rev(Omegas))

        # Mask Omega, that is to enforce the "linear" constraint `T2`
        if (!is.null(Omega.mask)) {
            Omega[Omega.mask] <- 0
        }

        # second order residuals accumulator
        # `sum_i (X_i o X_i - E[X o X | Y = y_i])`
        R2 <- array(0, dim = c(dimX, dimX))

        # negative log-likelihood
        loss <- 0

        for (i in slices.ind) {
            # slice size (nr. of objects in the slice)
            n_i <- length(i)

            # mean covariates of the slice
            meanF_i <- `dim<-`(
                rowSums(eval(`F[..., i]`), dims = length(dimF)), dimF
            ) / n_i

            diag_params_i <- as.vector(mlm(meanF_i, betas))
            # explicit `nrow`, `diag()` of a scalar would build an identity
            params_i <- Omega + diag(diag_params_i, nrow = length(diag_params_i))
            m2_i <- ising_m2(params_i, use_MC = use_MC, nr_threads = nr_threads)

            # accumulate loss
            matX_i <- mat(eval(`X[..., i]`), modes)
            loss <- loss - (
                sum(matX_i * (params_i %*% matX_i)) + n_i * attr(m2_i, "log_prob_0")
            )

            R2_i <- tcrossprod(matX_i) - n_i * m2_i
            R1_i <- diag(R2_i)
            dim(R1_i) <- dimX

            # the diagonal parameters depend on the slice mean `meanF_i`, hence
            # the chain rule yields the mean (not the sum) of the covariates
            for (j in modes) {
                grad_betas[[j]] <- grad_betas[[j]] +
                    mcrossprod(R1_i, mlm(meanF_i, betas[-j], modes[-j]), j)
            }
            R2 <- R2 + as.vector(R2_i)
        }

        # Apply the `T2` constraint on the Residuals as well (refer to `T2`)
        # That is, we compute G2 from g2 as in Theorem 2.
        if (!is.null(Omega.mask)) {
            R2[Omega.mask] <- 0
        }

        grad_Omegas <- Map(function(j) {
            grad <- mlm(kronperm(R2), Map(as.vector, Omegas[-j]), modes[-j], transposed = TRUE)
            dim(grad) <- dim(Omegas[[j]])
            grad
        }, modes)

        # update optimization behavioral trackers
        accum_sign <- sign(last_loss - loss) - accum_sign
        non_improving <- max(0L, non_improving - 1L + 2L * (last_loss < loss))

        # check break conditions
        if (abs(accum_sign) > zig.zag.threashold) { break }
        if (non_improving > patience) { break }
        if (abs(last_loss - loss) < eps * last_loss) { break }

        # store current loss for the next iteration
        last_loss <- loss

        # Accumulate root mean squared gradients
        grad2_betas  <- Map(rms_accum, grad2_betas, grad_betas)
        grad2_Omegas <- Map(rms_accum, grad2_Omegas, grad_Omegas)

        # logging (before parameter update)
        logger(iter, loss, betas, Omegas, grad_betas, grad_Omegas)

        # Update Parameters (gradient step on the log-likelihood)
        betas  <- Map(rms_step, betas, grad_betas, grad2_betas)
        Omegas <- Map(rms_step, Omegas, grad_Omegas, grad2_Omegas)

        # Project Parameters onto their manifolds
        for (j in modes) {
            if (is.function(proj_j <- proj.betas[[j]])) {
                betas[[j]] <- proj_j(betas[[j]])
            }
            if (is.function(proj_j <- proj.Omegas[[j]])) {
                Omegas[[j]] <- proj_j(Omegas[[j]])
            }
        }

        # Enforce degeneracy parameter constraints
        if (has.degen) {
            Omegas <- enforce_degen_constraints(Omegas)
        }
    }

    structure(
        list(eta1 = array(0, dimX), betas = betas, Omegas = Omegas),
        tensor_normal = fit_normal,
        Omegas.init = Omegas.init,
        degen.mask = degen.mask,
        iter = iter
    )
}