#' Regularized Matrix Regression
#'
#' Solves the regularized problem
#'      \deqn{min h(B) = l(B) + J(B)}
#' for a matrix \eqn{B},
#' where \eqn{l} is a loss function; for the GLM, we use the negative
#' log-likelihood as the loss. \eqn{J(B) = f(\sigma(B))}, where \eqn{f} is a
#' function of the singular values of \eqn{B}.
#'
#' Currently, only the least squares problem with nuclear norm penalty is
#' implemented.
#'
#' In case of \code{lambda = Inf}, the maximum penalty \eqn{\lambda} is
#' computed. In this case only this estimate is returned, as a single value.
#'
#' @param X the signal data, either as a 3D tensor or a 2D matrix. In case of a
#'  3D tensor the axes are assumed to be \eqn{n\times p\times q}, meaning the
#'  first dimension indexes the observations while the second and third are the
#'  `image' dimensions. When the data is provided as a matrix, it is assumed to
#'  be of shape \eqn{n\times pq} where each observation is the vectorized
#'  `image'.
#' @param Z additional covariate vector (can be \code{NULL} if not required;
#'  for regression with an intercept set \code{Z = rep(1, n)})
#' @param y univariate response vector
#' @param lambda penalty term; if set to \code{Inf}, the maximum lambda is
#'  computed.
#' @param max.iter maximum number of gradient updates
#' @param max.line.iter maximum number of line search iterations
#' @param shape shape of the matrix valued predictors. Required iff the
#'  predictors \code{X} are provided in vectorized form, e.g. as a 2D matrix.
#' @param step.size maximum step size for gradient updates
#' @param B0 initial value for optimization. Matrix of dimensions \eqn{p\times q}
#' @param beta0 initial value of the additional covariates coefficient for \eqn{Z}
#' @param alpha function computing the iterative Nesterov momentum scaling
#'  values
#' @param eps precision for the main loop break conditions
#' @param logger logging callback invoked after every line search, before the
#'  break condition checks. The expected function signature is of the form
#'  \code{function(iter, loss, penalty, B, beta, step.size)}.
#'
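#' @examples
#' # Minimal usage sketch on simulated data (illustrative only; all names and
#' # dimensions below are made up for the example):
#' \dontrun{
#' n <- 100; p <- 8; q <- 6
#' X <- array(rnorm(n * p * q), dim = c(n, p, q))
#' Z <- rep(1, n)                          # intercept only
#' B.true <- outer(rnorm(p), rnorm(q))     # rank-1 coefficient matrix
#' y <- apply(X, 1, function(x) sum(x * B.true)) + rnorm(n)
#'
#' # estimate of the maximal penalty, then a fit with a smaller lambda and a
#' # simple logging callback
#' lambda.max <- RMReg(X, Z, y, lambda = Inf)
#' fit <- RMReg(X, Z, y, lambda = 0.1 * lambda.max,
#'     logger = function(iter, loss, penalty, B, beta, step.size) {
#'         cat(sprintf("%4d: loss = %10.4f, penalty = %10.4f\n", iter, loss, penalty))
#'     })
#' fit$B                   # estimated coefficient matrix
#' fit$singular.values     # singular values of the estimate
#' }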
#' @export
RMReg <- function(X, Z, y, lambda = 0, max.iter = 500L, max.line.iter = 50L,
    shape = dim(X)[-1], step.size = 1e-3,
    B0 = array(0, dim = shape),
    beta0 = rep(0, NCOL(Z)),
    alpha = function(a, t) { (1 + sqrt(1 + (2 * a)^2)) / 2 },
    eps = .Machine$double.eps,
    logger = NULL
) {
    # Define loss (without penalty)
    loss <- function(B, beta, X, Z, y) 0.5 * sum((y - Z %*% beta - X %*% c(B))^2)
    # gradient of loss (without penalty)
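    # (with residual r = X vec(B) + Z beta - y the gradients are
    #  d l / d beta = Z' r and d l / d vec(B) = X' r)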
    grad <- function(B, beta, X, Z, y) {
        inner <- X %*% c(B) + Z %*% beta - y
        list(beta = c(crossprod(inner, Z)), B = c(crossprod(inner, X)))
    }
    # # and the penalty function (as function of singular values)
    # penalty <- function(sigma) sum(sigma)

    # Check (prepare) params
    stopifnot(nrow(X) == length(y))
    if (!missing(shape)) {
        stopifnot(ncol(X) == prod(shape))
    } else {
        stopifnot(length(dim(X)) == 3)
        dim(X) <- c(nrow(X), prod(shape))
    }
    if (missing(Z) || is.null(Z)) {
        Z <- matrix(0, nrow(X), 1)
    } else if (!is.matrix(Z)) {
        Z <- as.matrix(Z)
    }

    # Set singular values of start matrix predictor coefficients
    if (missing(B0)) {
        B1.sv <- rep(0, min(shape))
    } else {
        B1.sv <- La.svd(B0, 0, 0)$d
    }
    # initialize current and previous coefficients (start position)
    B1 <- B0
    beta1 <- beta0
    alpha0 <- 0
    alpha1 <- 1
    loss0 <- loss1 <- loss(B1, beta1, X, Z, y)

    # main descent loop
    no.nesterov <- FALSE
    for (iter in seq_len(max.iter)) {
        if (no.nesterov) {
            # classic gradient step as fallback
            S <- B1
            s <- beta1
        } else {
            # momentum step (extrapolation using previous direction)
            S <- B1 + ((alpha0 - 1) / alpha1) * (B1 - B0)
            s <- beta1 + ((alpha0 - 1) / alpha1) * (beta1 - beta0)
        }

        # compute (Nesterov) gradient
        G <- grad(S, s, X, Z, y)

        # backtracking line search (executed at least once)
        for (delta in step.size * 0.5^seq(0, max.line.iter - 1L)) {
            # Gradient step with step size delta
            A <- S - delta * G$B
            beta.temp <- s - delta * G$beta

            if (lambda == Inf) {
                # Application of Corollary 1 for estimation of max lambda
                # Return max lambda estimate
                return(max(La.svd(A, 0, 0)$d) / delta)
            } else if (lambda > 0) {
                # SVD of (potential) next step
                svdA <- La.svd(A)

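                # soft-threshold the singular values by delta * lambda,
                # i.e. the proximal step for the nuclear norm penalty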
                # Next (possible) penalized iterate
                B.temp.sv <- pmax(0, svdA$d - delta * lambda)
                B.temp <- svdA$u %*% (B.temp.sv * svdA$vt)
            } else {
                # in case of no penalization (pure least squares)
                B.temp.sv <- La.svd(A, 0, 0)$d
                B.temp <- A
            }

            # Check line search condition
            #   h(B.temp) <= g(B.temp | S, delta)
            #  \_ left _/    \_____ right _____/
            # where g(B.temp | S, delta) is the first order approx. of the loss
            #   l(S) + <grad l(S), B - S> + | B - S |_F^2 / 2 delta + J(B)
            left <- loss(B.temp, beta.temp, X, Z, y)
            right <- loss(S, s, X, Z, y) +
                sum(G$B * (B.temp - S)) + sum(G$beta * (beta.temp - s)) +
                (norm(B.temp - S, 'F')^2 + sum((beta.temp - s)^2)) / (2 * delta)
            if (left <= right) {
                break
            }
        }

        # Evaluate loss to ensure descent after line search
        loss.temp <- left # loss(B.temp, beta.temp, X, Z, y) # already computed

        # logging callback
        if (is.function(logger)) {
            logger(iter, loss.temp, lambda * sum(B.temp.sv),
                   B.temp, beta.temp, delta)
        }

        # after line search enforce descent
        if (loss.temp + lambda * sum(B.temp.sv) <= loss1 + lambda * sum(B1.sv)) {
            B0 <- B1
            B1 <- array(B.temp, shape)
            B1.sv <- B.temp.sv
            beta0 <- beta1
            beta1 <- beta.temp
            loss0 <- loss1
            loss1 <- loss.temp
            no.nesterov <- FALSE    # always reset
        } else if (!no.nesterov) {
            no.nesterov <- TRUE     # retry without momentum
            next
        } else {
            break                   # failed even without momentum -> stop
        }

        # check break conditions
        if (sum(B1.sv) < eps) {
            break   # estimate is (numerically) zero -> stop
        }
        if ((sum(G$B^2) + sum(G$beta^2)) < eps * sum(unlist(Map(length, G)))) {
            break   # mean squared gradient is smaller than epsilon -> stop
        }
        if (abs(loss0 - loss1) < eps) {
            break   # decrease is smaller than epsilon -> stop
        }

        # update momentum scaling
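        # (with the default `alpha` this is the classic Nesterov / FISTA
        #  sequence alpha_{t+1} = (1 + sqrt(1 + 4 * alpha_t^2)) / 2)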
        alpha0 <- alpha1
        alpha1 <- alpha(alpha1, iter)

        # set step size to two times current delta
        step.size <- 2 * delta
    }

    # Degrees of Freedom estimate (TODO: this is like in `matrix_sparsereg.m`)
    sigma <- c(La.svd(A, 0, 0)$d, rep(0, max(shape) - min(shape)))
    df <- length(beta1)
    for (i in seq_len(sum(B1.sv > 0))) {
        df <- df + 1 + sigma[i] * (sigma[i] - delta * lambda) * (
            sum(ifelse((1:shape[1]) != i, 1 / (sigma[i]^2 - sigma[1:shape[1]]^2), 0)) +
            sum(ifelse((1:shape[2]) != i, 1 / (sigma[i]^2 - sigma[1:shape[2]]^2), 0)))
    }

    # return estimates and some additional stats
    list(
        B = B1,
        beta = beta1,
        singular.values = B1.sv,
        iter = iter,
        df = df,
        loss = loss1,
        lambda = delta * lambda,
        AIC = loss1 + 2 * df,               # TODO: check this!
        BIC = loss1 + log(nrow(X)) * df,    # TODO: check this!
        call = match.call() # invoking function call, collects params like lambda
    )
}