161 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
			
		
		
	
	
			161 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
library(tensorPredictors)
 | 
						|
suppressPackageStartupMessages({
 | 
						|
    library(parallel)
 | 
						|
    library(pROC)
 | 
						|
})
 | 
						|
 | 
						|
#' Mode-wise PCA preprocessing (generalized (2D)^2 PCA)
#'
#' Reduces modes 1, 2 and 3 of `X` (time, sensor, stimulus condition) to their
#' leading principal components, computed from the mode-wise covariances
#' returned by `mcov(X)`. Modes beyond the third (e.g. a trailing sample mode)
#' are not reduced.
#'
#' @param X             data tensor with time, sensor and condition as modes 1:3
#' @param npc_time      Number of Principal Components for time axis
#' @param npc_sensor    Number of Principal Components for sensor axis
#' @param npc_condition Number of Principal Components for stimulus condition axis
#'
#' @return `X` multiplied along modes 1:3 by the transposed PC bases
preprocess <- function(X, npc_time, npc_sensor, npc_condition) {
    # Leading `npc` left singular vectors of a mode covariance matrix
    leading_pcs <- function(Sigma, npc) svd(Sigma, nu = npc, nv = 0L)$u

    # One covariance matrix per reduced mode: time, sensor, condition
    c(Sigma_t, Sigma_s, Sigma_c) %<-% mcov(X)

    bases <- list(
        leading_pcs(Sigma_t, npc_time),
        leading_pcs(Sigma_s, npc_sensor),
        leading_pcs(Sigma_c, npc_condition)
    )

    # Project each mode onto its PC basis
    mlm(X, bases, modes = 1:3, transposed = TRUE)
}
 | 
						|
 | 
						|
 | 
						|
### Classification performance measures
#
# Every measure takes the true labels `y.true` (coded 0/1, 1 = positive class)
# and the predicted probabilities `y.pred` of the positive class. The
# threshold based measures predict class 1 iff `y.pred > threshold`; for
# probabilities in [0, 1] the default `threshold = 0.5` reproduces
# `round(y.pred)` (R rounds 0.5 to even, i.e. to 0).

# Hard 0/1 class predictions from positive-class probabilities.
binarize <- function(y.pred, threshold = 0.5) as.integer(y.pred > threshold)

# acc: Accuracy. P(Yhat = Y). Estimated as: (TP+TN)/(P+N).
acc <- function(y.true, y.pred, threshold = 0.5) {
    mean(binarize(y.pred, threshold) == y.true)
}
# err: Error rate. P(Yhat != Y). Estimated as: (FP+FN)/(P+N).
err <- function(y.true, y.pred, threshold = 0.5) {
    mean(binarize(y.pred, threshold) != y.true)
}
# fpr: False positive rate. P(Yhat = + | Y = -). aliases: Fallout.
fpr <- function(y.true, y.pred, threshold = 0.5) {
    mean(binarize(y.pred, threshold)[y.true == 0] == 1L)
}
# tpr: True positive rate.  P(Yhat = + | Y = +). aliases: Sensitivity, Recall.
tpr <- function(y.true, y.pred, threshold = 0.5) {
    mean(binarize(y.pred, threshold)[y.true == 1] == 1L)
}
# fnr: False negative rate. P(Yhat = - | Y = +). aliases: Miss.
fnr <- function(y.true, y.pred, threshold = 0.5) {
    mean(binarize(y.pred, threshold)[y.true == 1] == 0L)
}
# tnr: True negative rate.  P(Yhat = - | Y = -). aliases: Specificity.
tnr <- function(y.true, y.pred, threshold = 0.5) {
    mean(binarize(y.pred, threshold)[y.true == 0] == 0L)
}
# auc: Area Under the (ROC) Curve.
# `direction = "<"` fixes that controls (label 0) are expected to have lower
# predictions than cases (label 1), instead of pROC's default "auto" which
# picks whichever direction gives the larger AUC on the given data.
auc <- function(y.true, y.pred) {
    as.numeric(pROC::roc(y.true, y.pred, quiet = TRUE, direction = "<")$auc)
}
# auc.sd: Standard deviation of the AUC estimate, i.e. the square root of
# pROC's AUC variance estimate for the same ROC curve as in `auc`.
auc.sd <- function(y.true, y.pred) {
    sqrt(pROC::var(pROC::roc(y.true, y.pred, quiet = TRUE, direction = "<")))
}
 | 
						|
 | 
						|
 | 
						|
#' Leave-one-out prediction with a tensor reduction method
#'
#' For every observation `i`: fits `method` on the data without observation
#' `i`, reduces all observations with the fitted `betas`, fits a logistic
#' regression of the remaining responses on the reduced values and predicts
#' the response probability of observation `i`. Folds run in parallel via
#' `parallel::mclapply()`.
#'
#' @param method reduction method to be applied (e.g. `TSIR` or
#'   `gmlm_tensor_normal`); called as `method(X, y, sample.axis = 4L, ...)`
#'   and expected to return an object with a `betas` component
#' @param X 4D array: 3D EEG data (preprocessed or not) per observation,
#'   observations along the 4th axis
#' @param y binary response vector, one entry per observation
#' @param ... additional arguments passed on to `method`
#'
#' @return numeric vector of leave-one-out predicted probabilities, one per
#'   observation in the order of `y`
loo.predict <- function(method, X, y, ...) {
    # Method expression as written by the caller, for progress logging;
    # deparse() keeps e.g. `pkg::fun` as one string (as.character would not)
    method.name <- paste(deparse(substitute(method)), collapse = " ")

    # detectCores() may return NA; fall back to a single core in that case
    n.cores <- getOption("mc.cores",
        max(1L, parallel::detectCores() - 1L, na.rm = TRUE))

    # Parallel Leave-One-Out prediction
    preds <- parallel::mclapply(seq_along(y), function(i) {
        # Fit with i'th observation removed; drop = FALSE keeps the sample axis
        # even if only a single observation remains
        fit <- method(X[, , , -i, drop = FALSE], y[-i], sample.axis = 4L, ...)

        # Reduce the entire data set. The logit model below uses one reduced
        # value per observation, so the betas must reduce each one to a scalar.
        r <- as.vector(mlm(X, fit$betas, modes = 1:3, transposed = TRUE))

        # Fit a logit model on reduced data with i'th observation removed
        logit <- glm(y ~ r, family = binomial(link = "logit"),
            data = data.frame(y = y[-i], r = r[-i])
        )
        # predict i'th response given i'th reduced observation
        y.hat <- predict(logit, newdata = data.frame(r = r[i]), type = "response")

        # report progress
        cat(sprintf("%s - dim: (%d, %d, %d) - %3d/%d\n",
            method.name, dim(X)[1], dim(X)[2], dim(X)[3], i, length(y)
        ))

        y.hat
    }, mc.cores = n.cores)

    # With multiple cores, mclapply() reports a failed fold as a "try-error"
    # element (or NULL for a killed child) plus only a warning; unlist() would
    # then silently yield a character or shortened vector.
    is.err <- vapply(preds, inherits, logical(1), what = "try-error")
    failed <- is.err | vapply(preds, is.null, logical(1))
    if (any(failed)) {
        detail <- if (any(is.err)) {
            paste0("\nfirst error: ", as.character(preds[[which(is.err)[1L]]]))
        } else {
            ""
        }
        stop(sprintf("%s: leave-one-out fold(s) %s failed or returned no result%s",
            method.name, paste(which(failed), collapse = ", "), detail),
            call. = FALSE)
    }

    unlist(preds)
}
 | 
						|
 | 
						|
# "Projects" a sequence onto its `nr.freq` dominant frequency components:
# all discrete Fourier coefficients except the `nr.freq` largest in modulus
# are set to zero and the real part of the inverse transform is returned.
#
# NOTE(review): for real input the coefficients come in conjugate pairs of
# equal modulus. If the cut-off splits such a pair, only one member is kept
# and `Re()` then recovers only half of that frequency's contribution.
#
# @param sequence numeric vector (a matrix would get a multi-dimensional FFT)
# @param nr.freq  number of Fourier coefficients to keep; 0 gives the zero
#                 sequence, values >= length(sequence) return the sequence
#                 (up to rounding)
# @return numeric vector of the same length as `sequence`
proj.fft <- function(sequence, nr.freq = 5L) {
    F <- fft(sequence)
    # Zero all but the `nr.freq` largest-modulus coefficients. order() is
    # stable, so among ties the higher indices are kept. An explicit count is
    # used because `head(., -nr.freq)` keeps everything for nr.freq = 0.
    n.drop <- max(length(F) - nr.freq, 0L)
    F[head(order(abs(F)), n.drop)] <- 0+0i
    Re(fft(F, inverse = TRUE)) / length(F)
}
 | 
						|
 | 
						|
 | 
						|
# Load full EEG dataset (3D tensor for each subject)
c(X, y) %<-% readRDS("eeg_data_3d.rds")



##################################### GMLM #####################################

# Leave-one-out predictions from mode-wise PCA preprocessed (reduced) data for
# three (time, sensor) PC settings, from the raw (not reduced) data, and from
# the raw data with `proj.fft` given as the first (time) mode entry of
# `proj.betas`. `gmlm_tensor_normal` is spelled out in every call because
# `loo.predict` logs the method expression exactly as written.
y.hat.3.4   <- loo.predict(gmlm_tensor_normal,
    preprocess(X, npc_time = 3, npc_sensor = 4, npc_condition = 3), y)
y.hat.15.15 <- loo.predict(gmlm_tensor_normal,
    preprocess(X, npc_time = 15, npc_sensor = 15, npc_condition = 3), y)
y.hat.20.30 <- loo.predict(gmlm_tensor_normal,
    preprocess(X, npc_time = 20, npc_sensor = 30, npc_condition = 3), y)
y.hat       <- loo.predict(gmlm_tensor_normal, X, y)
y.hat.fft   <- loo.predict(gmlm_tensor_normal, X, y,
    proj.betas = list(proj.fft, NULL, NULL))

# Classification performance measures of the leave-one-out predictions
# (one column per prediction vector, one row per measure)
(loo.cv <- local({
    measures <- c("acc", "err", "fpr", "tpr", "fnr", "tnr", "auc", "auc.sd")
    # 0/1 labels; assumes `y` is a two-level factor (second level = class 1)
    y.true <- as.integer(y) - 1L
    predictions <- cbind(y.hat.3.4, y.hat.15.15, y.hat.20.30, y.hat, y.hat.fft)
    apply(predictions, 2, function(y.pred) {
        sapply(measures, function(FUN) match.fun(FUN)(y.true, y.pred))
    })
}))
#>         y.hat.3.4 y.hat.15.15 y.hat.20.30      y.hat  y.hat.fft
#> acc    0.83606557  0.80327869  0.80327869 0.79508197 0.79508197
#> err    0.16393443  0.19672131  0.19672131 0.20491803 0.20491803
#> fpr    0.31111111  0.33333333  0.33333333 0.35555556 0.33333333
#> tpr    0.92207792  0.88311688  0.88311688 0.88311688 0.87012987
#> fnr    0.07792208  0.11688312  0.11688312 0.11688312 0.12987013
#> tnr    0.68888889  0.66666667  0.66666667 0.64444444 0.66666667
#> auc    0.88051948  0.86984127  0.86926407 0.86810967 0.86810967
#> auc.sd 0.03118211  0.03254642  0.03259186 0.03295883 0.03354029
 | 
						|
 | 
						|
 | 
						|
################################## Tensor SIR ##################################

# Leave-one-out predictions from mode-wise PCA preprocessed (reduced) data for
# three (time, sensor) PC settings and from the raw (not reduced) data.
# `TSIR` is spelled out in every call because `loo.predict` logs the method
# expression exactly as written.
y.hat.3.4   <- loo.predict(TSIR,
    preprocess(X, npc_time = 3, npc_sensor = 4, npc_condition = 3), y)
y.hat.15.15 <- loo.predict(TSIR,
    preprocess(X, npc_time = 15, npc_sensor = 15, npc_condition = 3), y)
y.hat.20.30 <- loo.predict(TSIR,
    preprocess(X, npc_time = 20, npc_sensor = 30, npc_condition = 3), y)
y.hat       <- loo.predict(TSIR, X, y)

# Classification performance measures of the leave-one-out predictions
# (one column per prediction vector, one row per measure)
(loo.cv <- local({
    measures <- c("acc", "err", "fpr", "tpr", "fnr", "tnr", "auc", "auc.sd")
    # 0/1 labels; assumes `y` is a two-level factor (second level = class 1)
    y.true <- as.integer(y) - 1L
    predictions <- cbind(y.hat.3.4, y.hat.15.15, y.hat.20.30, y.hat)
    apply(predictions, 2, function(y.pred) {
        sapply(measures, function(FUN) match.fun(FUN)(y.true, y.pred))
    })
}))
#>         y.hat.3.4 y.hat.15.15 y.hat.20.30      y.hat
#> acc    0.81967213  0.84426230  0.81147541 0.80327869
#> err    0.18032787  0.15573770  0.18852459 0.19672131
#> fpr    0.33333333  0.24444444  0.33333333 0.33333333
#> tpr    0.90909091  0.89610390  0.89610390 0.88311688
#> fnr    0.09090909  0.10389610  0.10389610 0.11688312
#> tnr    0.66666667  0.75555556  0.66666667 0.66666667
#> auc    0.86522367  0.89379509  0.88196248 0.85974026
#> auc.sd 0.03357539  0.03055047  0.02986038 0.03367847
 | 
						|
 | 
						|
 | 
						|
# Same Tensor SIR leave-one-out predictions as above, but passing
# `cond.threshold = 25` through `loo.predict` on to every `TSIR` fit.
# `TSIR` is spelled out in every call because `loo.predict` logs the method
# expression exactly as written.
y.hat.3.4   <- loo.predict(TSIR,
    preprocess(X, npc_time = 3, npc_sensor = 4, npc_condition = 3), y,
    cond.threshold = 25)
y.hat.15.15 <- loo.predict(TSIR,
    preprocess(X, npc_time = 15, npc_sensor = 15, npc_condition = 3), y,
    cond.threshold = 25)
y.hat.20.30 <- loo.predict(TSIR,
    preprocess(X, npc_time = 20, npc_sensor = 30, npc_condition = 3), y,
    cond.threshold = 25)
y.hat       <- loo.predict(TSIR, X, y, cond.threshold = 25)

# Classification performance measures of the leave-one-out predictions
# (one column per prediction vector, one row per measure)
(loo.cv <- local({
    measures <- c("acc", "err", "fpr", "tpr", "fnr", "tnr", "auc", "auc.sd")
    # 0/1 labels; assumes `y` is a two-level factor (second level = class 1)
    y.true <- as.integer(y) - 1L
    predictions <- cbind(y.hat.3.4, y.hat.15.15, y.hat.20.30, y.hat)
    apply(predictions, 2, function(y.pred) {
        sapply(measures, function(FUN) match.fun(FUN)(y.true, y.pred))
    })
}))
#>         y.hat.3.4 y.hat.15.15 y.hat.20.30      y.hat
#> acc    0.81967213  0.77049180  0.76229508 0.77049180
#> err    0.18032787  0.22950820  0.23770492 0.22950820
#> fpr    0.33333333  0.37777778  0.40000000 0.37777778
#> tpr    0.90909091  0.85714286  0.85714286 0.85714286
#> fnr    0.09090909  0.14285714  0.14285714 0.14285714
#> tnr    0.66666667  0.62222222  0.60000000 0.62222222
#> auc    0.86522367  0.84386724  0.84415584 0.84040404
#> auc.sd 0.03357539  0.03542706  0.03519592 0.03558135
 |