2
0
Fork 0

fix: predict_dim_cv always failes with dim. missmatch,

add: support for central method in predcit_dim_cv
This commit is contained in:
Daniel Kapla 2021-02-10 18:54:40 +01:00
parent e93bfdda05
commit 25b20984d5
1 changed files with 12 additions and 7 deletions

View File

@ -7,8 +7,8 @@ predict_dim_cv <- function(object) {
Sigma_root <- eig$vectors %*% tcrossprod(diag(sqrt(eig$values)), eig$vectors) Sigma_root <- eig$vectors %*% tcrossprod(diag(sqrt(eig$values)), eig$vectors)
X <- X %*% solve(Sigma_root) X <- X %*% solve(Sigma_root)
pred <- matrix(0, n, length(object$res)) pred <- array(0, c(n, ncol(object$Fy), length(object$res)),
colnames(pred) <- names(object$res) dimnames = list(NULL, NULL, names(object$res)))
for (dr.k in object$res) { for (dr.k in object$res) {
# get "name" of current dimension # get "name" of current dimension
k <- as.character(dr.k$k) k <- as.character(dr.k$k)
@ -16,12 +16,11 @@ predict_dim_cv <- function(object) {
X.proj <- X %*% dr.k$B X.proj <- X %*% dr.k$B
for (i in 1:n) { for (i in 1:n) {
model <- mda::mars(X.proj[-i, ], object$Y[-i]) model <- mda::mars(X.proj[-i, ], object$Fy[-i, ])
pred[i, k] <- predict(model, X.proj[i, , drop = F]) pred[i, , k] <- predict(model, X.proj[i, , drop = FALSE])
} }
} }
MSE <- colMeans((pred - object$Y)^2) MSE <- apply((pred - as.numeric(object$Fy))^2, 3, mean)
return(list( return(list(
MSE = MSE, MSE = MSE,
@ -30,6 +29,9 @@ predict_dim_cv <- function(object) {
} }
predict_dim_elbow <- function(object) { predict_dim_elbow <- function(object) {
if (ncol(object$Fy) > 1) # TODO: Implement or find better way
stop("For multivariate or central models not supported yet.")
# extract original data from object (cve result) # extract original data from object (cve result)
X <- object$X X <- object$X
Y <- object$Y Y <- object$Y
@ -71,6 +73,9 @@ predict_dim_elbow <- function(object) {
} }
predict_dim_wilcoxon <- function(object, p.value = 0.05) { predict_dim_wilcoxon <- function(object, p.value = 0.05) {
if (ncol(object$Fy) > 1) # TODO: Implement or find better way
stop("For multivariate or central models not supported yet.")
# extract original data from object (cve result) # extract original data from object (cve result)
X <- object$X X <- object$X
Y <- object$Y Y <- object$Y