fix: typo in data_gen.cpp inverting quiet position sampling, add: white and black to move support, update: EEG data example, add: position analytics as interface to the schachhoernchen Board class
		
			
				
	
	
		
			120 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
			
		
		
	
	
			120 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			R
		
	
	
	
	
	
library(tensorPredictors)
 | 
						|
 | 
						|
# Load as 3D predictors `X` and flat response `y` and `F = y` with per person
# dim. 1 x 1
#
# NOTE(review): the multiple assignment operator `%<-%` must be in scope
# (presumably re-exported by `tensorPredictors`) -- confirm.
c(X, F, y) %<-% local({
    # Load from file
    ds <- readRDS("eeg_data.rds")

    # Dimension values
    n <- nrow(ds)       # sample size (nr. of people)
    p <- 64L            # nr. of predictors (count of sensors)
    t <- 256L           # nr. of time points (measurements)

    # Besides the first two (non measurement) columns, exactly `p * t`
    # measurement columns are expected. Otherwise `array()` below would
    # silently truncate or recycle the data.
    stopifnot(ncol(ds) == 2L + p * t)

    # Extract dimension names
    nNames <- ds$PersonID
    tNames <- as.character(seq_len(t))
    # Sensor names are the column name prefixes before the first "_". They are
    # taken from one column per sensor (column `2 + t * k` for sensor `k`).
    # Each name is processed on its own, so a name without (or with several)
    # "_" cannot shift the names of the other sensors.
    pNames <- sub("_.*$", "", colnames(ds)[2L + t * seq_len(p)])

    # Split into predictors (with proper dims and names) and response.
    # Filling column-major means the measurement columns are ordered with the
    # time index varying fastest within each sensor.
    X <- array(as.matrix(ds[, -(1:2)]),
        dim = c(person = n, time = t, sensor = p),
        dimnames = list(person = nNames, time = tNames, sensor = pNames)
    )
    y <- ds$Case_Control

    list(X, array(y, c(n, 1L, 1L)), y)
})

# Fit a tensor normal model to the data (sample axis 1 indexes persons)
fit.gmlm <- gmlm_tensor_normal(X, F, sample.axis = 1L)

# Plot the fitted mode wise reductions (for time and sensor axis)
with(fit.gmlm, {
    par.reset <- par(mfrow = c(2, 1))
    # One point per row of the time reduction, spread evenly over [0, 1]
    # (spelled `length.out`, not the partially matched `len`).
    # NOTE(review): the "[s]" axis label assumes the recording spans 1 second.
    time.points <- seq(0, 1, length.out = NROW(betas[[1]]))
    plot(time.points, betas[[1]], main = "Time", xlab = "Time [s]",
        ylab = expression(beta[1]))
    plot(betas[[2]], main = "Sensors", xlab = "Sensor Index",
        ylab = expression(beta[2]))
    par(par.reset)
})

#' (2D)^2 PCA preprocessing
#'
#' Reduces the time and predictor (sensor) modes of `X` to their leading
#' principal components. The components are computed from the mode wise
#' covariances of `X`.
#'
#' @param X 3D data with observations along the first axis. Axes 2 and 3 are
#'  the time and predictor (sensor) modes.
#' @param tpc Number of "t"ime "p"rincipal "c"omponents.
#' @param ppc Number of "p"redictor "p"rincipal "c"omponents.
#'
#' @return `X` multiplied along modes 2 and 3 by the transposed principal
#'  component bases (presumably of dim. `n x tpc x ppc` -- confirm with `mlm`).
preprocess <- function(X, tpc, ppc) {
    # Mode wise covariances of the time and predictor (sensor) axes
    c(cov.time, cov.pred) %<-% mcov(X, sample.axis = 1L)

    # Leading left singular vectors give the principal component bases
    pc.time <- svd(cov.time, nu = tpc, nv = 0L)$u
    pc.pred <- svd(cov.pred, nu = ppc, nv = 0L)$u

    # Project axes 2 and 3 onto their principal components
    mlm(X, list(pc.time, pc.pred), modes = 2:3, transposed = TRUE)
}

#' Leave-one-out prediction
#'
#' For every observation `i`, the function does the following:
#' - fits the tensor normal GMLM without observation `i`;
#' - reduces the whole data set with the fitted mode wise reductions;
#' - fits a logit model of the response on the reduced predictor, again
#'   without observation `i`;
#' - predicts the response of observation `i`.
#'
#' @param X 3D EEG data (preprocessed or not), observations along axis 1
#' @param F binary response `y` as a 3D tensor, every obs. is a 1 x 1 matrix
#' @param y binary (0/1) response vector used for the logit model. Defaults to
#'  the entries of `F`. Previously the global variable `y` was used silently.
#' @param verbose print a progress line after every left out observation?
#'
#' @return numeric vector of the leave-one-out predictions of the logit model
#'  (`type = "response"`), one per observation.
loo.predict <- function(X, F, y = as.vector(F), verbose = TRUE) {
    n <- dim(X)[1L]

    vapply(seq_len(n), function(i) {
        # Fit with i'th observation removed. `drop = FALSE` keeps singleton
        # modes (e.g. a single principal component) from being dropped.
        fit <- gmlm_tensor_normal(
            X[-i, , , drop = FALSE], F[-i, , , drop = FALSE],
            sample.axis = 1L
        )

        # Reduce the entire data set (argument spelled out as `transposed`
        # instead of relying on partial argument matching)
        r <- as.vector(mlm(X, fit$betas, modes = 2:3, transposed = TRUE))

        # Fit a logit model on reduced data with i'th observation removed
        logit <- glm(y ~ r, family = binomial(link = "logit"),
            data = data.frame(y = y[-i], r = r[-i])
        )

        # Predict i'th response given i'th reduced observation
        y.hat <- predict(logit, newdata = data.frame(r = r[i]), type = "response")

        # Report progress
        if (verbose) {
            cat(sprintf("dim: (%d, %d) - %3d/%d\n", dim(X)[2L], dim(X)[3L], i, n))
        }

        unname(y.hat)
    }, numeric(1))
}

### Classification performance measures
# Each measure takes the true binary (0/1) labels `y.true` and the predictions
# `y.pred`. Except for the AUC based measures, predictions are turned into
# class labels by rounding `y.pred` to the nearest integer.

# acc: Accuracy. P(Yhat = Y). Estimated as: (TP+TN)/(P+N).
acc <- function(y.true, y.pred) {
    y.class <- round(y.pred)
    mean(y.class == y.true)
}

# err: Error rate. P(Yhat != Y). Estimated as: (FP+FN)/(P+N).
err <- function(y.true, y.pred) {
    y.class <- round(y.pred)
    mean(y.class != y.true)
}

# fpr: False positive rate. P(Yhat = + | Y = -). aliases: Fallout.
fpr <- function(y.true, y.pred) {
    is.neg <- y.true == 0
    mean(round(y.pred[is.neg]) == 1)
}

# tpr: True positive rate.  P(Yhat = + | Y = +). aliases: Sensitivity, Recall.
tpr <- function(y.true, y.pred) {
    is.pos <- y.true == 1
    mean(round(y.pred[is.pos]) == 1)
}

# fnr: False negative rate. P(Yhat = - | Y = +). aliases: Miss.
fnr <- function(y.true, y.pred) {
    is.pos <- y.true == 1
    mean(round(y.pred[is.pos]) == 0)
}

# tnr: True negative rate.  P(Yhat = - | Y = -).
tnr <- function(y.true, y.pred) {
    is.neg <- y.true == 0
    mean(round(y.pred[is.neg]) == 0)
}

# auc: Area Under the (ROC) Curve
auc <- function(y.true, y.pred) {
    roc.curve <- pROC::roc(y.true, y.pred, quiet = TRUE)
    as.numeric(roc.curve$auc)
}

# auc.sd: square root of pROC's variance estimate for the ROC curve's AUC
auc.sd <- function(y.true, y.pred) {
    roc.curve <- pROC::roc(y.true, y.pred, quiet = TRUE)
    sqrt(pROC::var(roc.curve))
}

# Leave-one-out predictions on (2D)^2 PCA reduced data with (time, sensor)
# principal component counts (3, 4), (15, 15) and (20, 30), as well as on the
# raw (not reduced) data.
y.hat.3.4   <- loo.predict(preprocess(X, tpc = 3, ppc = 4), F)
y.hat.15.15 <- loo.predict(preprocess(X, tpc = 15, ppc = 15), F)
y.hat.20.30 <- loo.predict(preprocess(X, tpc = 20, ppc = 30), F)
y.hat       <- loo.predict(X, F)

# Classification performance measures table by leave-one-out cross-validation.
# Rows are the measures, columns the prediction runs. `vapply` guarantees one
# numeric value per measure (type stable, unlike `sapply`).
(loo.cv <- apply(cbind(y.hat.3.4, y.hat.15.15, y.hat.20.30, y.hat), 2, function(y.pred) {
    vapply(c("acc", "err", "fpr", "tpr", "fnr", "tnr", "auc", "auc.sd"),
        function(FUN) match.fun(FUN)(y, y.pred), numeric(1))
}))
#>         y.hat.3.4 y.hat.15.15 y.hat.20.30      y.hat
#> acc    0.79508197  0.78688525  0.78688525 0.78688525
#> err    0.20491803  0.21311475  0.21311475 0.21311475
#> fpr    0.35555556  0.40000000  0.40000000 0.40000000
#> tpr    0.88311688  0.89610390  0.89610390 0.89610390
#> fnr    0.11688312  0.10389610  0.10389610 0.10389610
#> tnr    0.64444444  0.60000000  0.60000000 0.60000000
#> auc    0.85108225  0.83838384  0.83924964 0.83896104
#> auc.sd 0.03584791  0.03760531  0.03751307 0.03754553