Compare commits
	
		
			No commits in common. "8da1950d02290632c1753612607df0713251f36c" and "938e6bd3bad73323561a5c9c2a6418820c600ece" have entirely different histories.
		
	
	
		
			8da1950d02
			...
			938e6bd3ba
		
	
		
| @ -25,11 +25,11 @@ | |||||||
| #' @param y univariate response vector | #' @param y univariate response vector | ||||||
| #' @param lambda penalty term, if set to \code{Inf} max lambda is computed. | #' @param lambda penalty term, if set to \code{Inf} max lambda is computed. | ||||||
| #' @param loss loss function, part of the objective function | #' @param loss loss function, part of the objective function | ||||||
|  | #' @param grad.loss gradient with respect to \eqn{B} of the loss function | ||||||
|  | #'  (required, there is no support for numerical gradients) | ||||||
| #' @param penalty penalty function with a vector of the singular values if the | #' @param penalty penalty function with a vector of the singular values if the | ||||||
| #'  current iterate as arguments. The default function | #'  current iterate as arguments. The default function | ||||||
| #'  \code{function(sigma) sum(sigma)} is the nuclear norm penalty. | #'  \code{function(sigma) sum(sigma)} is the nuclear norm penalty. | ||||||
| #' @param grad.loss gradient with respect to \eqn{B} of the loss function |  | ||||||
| #'  (required, there is no support for numerical gradients) |  | ||||||
| #' @param shape Shape of the matrix valued predictors. Required iff the | #' @param shape Shape of the matrix valued predictors. Required iff the | ||||||
| #'  predictors \code{X} are provided in vectorized form, e.g. as a 2D matrix. | #'  predictors \code{X} are provided in vectorized form, e.g. as a 2D matrix. | ||||||
| #' @param step.size max. stepsize for gradient updates | #' @param step.size max. stepsize for gradient updates | ||||||
| @ -38,23 +38,19 @@ | |||||||
| #' @param beta initial value of additional covatiates coefficient for \eqn{Z} | #' @param beta initial value of additional covatiates coefficient for \eqn{Z} | ||||||
| #' @param max.iter maximum number of gadient updates | #' @param max.iter maximum number of gadient updates | ||||||
| #' @param max.line.iter maximum number of line search iterations | #' @param max.line.iter maximum number of line search iterations | ||||||
| #' @param logger logging callback invoced after every line search before break |  | ||||||
| #'  condition checks. The expected function signature is of the form |  | ||||||
| #'  \code{logger(iter, loss, penalty, grad, B, beta, step.size)}. |  | ||||||
| #' | #' | ||||||
| #' @export | #' @export | ||||||
| RMReg <- function(X, Z, y, lambda = 0, | RMReg <- function(X, Z, y, lambda = 0, | ||||||
|     loss = function(B, beta, X, Z, y) 0.5 * sum((y - Z %*% beta - X %*% c(B))^2), |     loss = function(B, beta, X, Z, y) 0.5 * sum((y - Z %*% beta - X %*% c(B))^2), | ||||||
|     penalty = function(sigma) sum(sigma), |  | ||||||
|     grad.loss = function(B, beta, X, Z, y) crossprod(X %*% c(B) + Z %*% beta - y, X), |     grad.loss = function(B, beta, X, Z, y) crossprod(X %*% c(B) + Z %*% beta - y, X), | ||||||
|  |     penalty = function(sigma) sum(sigma), | ||||||
|     shape = dim(X)[-1], |     shape = dim(X)[-1], | ||||||
|     step.size = 1e-3, |     step.size = 1e-3, | ||||||
|     alpha = function(a, t) { (1 + sqrt(1 + (2 * a)^2)) / 2 }, |     alpha = function(a, t) { (1 + sqrt(1 + (2 * a)^2)) / 2 }, | ||||||
|     B0 = array(0, dim = shape), |     B0 = array(0, dim = shape), | ||||||
|     beta = rep(0, NCOL(Z)), |     beta = rep(0, NCOL(Z)), | ||||||
|     max.iter = 500, |     max.iter = 500, | ||||||
|     max.line.iter = ceiling(log(step.size / sqrt(.Machine$double.eps), 2)), |     max.line.iter = ceiling(log(step.size / sqrt(.Machine$double.eps), 2)) | ||||||
|     logger = NULL |  | ||||||
| ) { | ) { | ||||||
|     ### Check (prepair) params |     ### Check (prepair) params | ||||||
|     stopifnot(nrow(X) == length(y)) |     stopifnot(nrow(X) == length(y)) | ||||||
| @ -150,15 +146,8 @@ RMReg <- function(X, Z, y, lambda = 0, | |||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         # Evaluate loss at (potential) new parameters |  | ||||||
|         loss.temp <- loss(B.temp, beta, X, Z, y) |  | ||||||
| 
 |  | ||||||
|         # logging callback |  | ||||||
|         if (is.function(logger)) { |  | ||||||
|             logger(iter, loss.temp, penalty(b.temp), grad, B1, beta, delta) |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         # After gradient update enforce descent (stop if not decreasing) |         # After gradient update enforce descent (stop if not decreasing) | ||||||
|  |         loss.temp <- loss(B.temp, beta, X, Z, y) | ||||||
|         if (loss.temp + penalty(b.temp) <= loss1 + penalty(b1)) { |         if (loss.temp + penalty(b.temp) <= loss1 + penalty(b1)) { | ||||||
|             no.nesterov <- FALSE |             no.nesterov <- FALSE | ||||||
|             loss1 <- loss.temp |             loss1 <- loss.temp | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user