changed PSQTs,
fix: typo in data_gen.cpp inverting quiete position sampling, add: white and black to move support, update: eeg data example, add: position analytics as interace to the schachhoernchen Board class
This commit is contained in:
parent
61bd94bec8
commit
daefd3e7d1
|
@ -6,17 +6,53 @@ HCE <- function(positions) {
|
|||
.Call(`_Rchess_HCE`, positions)
|
||||
}
|
||||
|
||||
#' Given a FEN (position) determines if its whites turn
|
||||
isWhiteTurn <- function(positions) {
|
||||
.Call(`_Rchess_isWhiteTurn`, positions)
|
||||
}
|
||||
|
||||
#' Check if current side to move is in check
|
||||
isCheck <- function(positions) {
|
||||
.Call(`_Rchess_isCheck`, positions)
|
||||
}
|
||||
|
||||
#' Check if the current position is a quiet position (no piece is attacked)
|
||||
isQuiet <- function(positions) {
|
||||
.Call(`_Rchess_isQuiet`, positions)
|
||||
}
|
||||
|
||||
#' Check if position is terminal
|
||||
#'
|
||||
#' Checks if the position is a terminal position, meaning if the game ended
|
||||
#' by mate, stale mate or the 50 modes rule. Three-Fold repetition is NOT
|
||||
#' checked, therefore a seperate game history is required which the board
|
||||
#' does NOT track.
|
||||
#'
|
||||
isTerminal <- function(positions) {
|
||||
.Call(`_Rchess_isTerminal`, positions)
|
||||
}
|
||||
|
||||
#' Check if checkmate is possible by material on the board
|
||||
#'
|
||||
#' Checks if there is sufficient mating material on the board, meaning if it
|
||||
#' possible for any side to deliver a check mate. More specifically, it
|
||||
#' checks if the pieces on the board are KK, KNK or KBK.
|
||||
#'
|
||||
isInsufficient <- function(positions) {
|
||||
.Call(`_Rchess_isInsufficient`, positions)
|
||||
}
|
||||
|
||||
#' Specialized version of `read_cyclic.cpp` taylored to work in conjunction with
|
||||
#' `gmlm_chess()` as data generator to provide random draws from a FEN data set
|
||||
#' with scores filtered to be in in the range `score_min` to `score_max`.
|
||||
#'
|
||||
data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, min_ply_count = 10L) {
|
||||
.Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, min_ply_count)
|
||||
data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, min_ply_count = 10L, white_only = TRUE) {
|
||||
.Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, min_ply_count, white_only)
|
||||
}
|
||||
|
||||
#' Human Crafted Evaluation
|
||||
eval.psqt <- function(positions, psqt) {
|
||||
.Call(`_Rchess_eval_psqt`, positions, psqt)
|
||||
eval.psqt <- function(positions, psqt, pawn_structure = FALSE, eval_rooks = FALSE, eval_king = FALSE) {
|
||||
.Call(`_Rchess_eval_psqt`, positions, psqt, pawn_structure, eval_rooks, eval_king)
|
||||
}
|
||||
|
||||
#' Convert a legal FEN string to a 3D binary (integer with 0-1 entries) array
|
||||
|
|
|
@ -946,7 +946,7 @@ Score Board::evalPawns(enum piece color) const {
|
|||
if (color == white) {
|
||||
for (u64 sq = pawns; sq; sq &= sq - 1) {
|
||||
Index i = bitScanLS(sq);
|
||||
score += pieceSquareTables[pawn][i];
|
||||
score += PSQT[pawn][i];
|
||||
}
|
||||
// Backwards pawns (not isolated but behind all adjacent friendly pawns)
|
||||
u64 backwards = pawns & ~isolated;
|
||||
|
@ -963,7 +963,7 @@ Score Board::evalPawns(enum piece color) const {
|
|||
} else { // color == black
|
||||
for (u64 sq = pawns; sq; sq &= sq - 1) {
|
||||
Index i = bitScanLS(sq);
|
||||
score += pieceSquareTables[pawn][63 - i];
|
||||
score += PSQT[pawn][63 - i];
|
||||
}
|
||||
// Backwards pawns (not isolated but behind all adjacent friendly pawns)
|
||||
u64 backwards = pawns & ~isolated;
|
||||
|
@ -988,7 +988,7 @@ Score Board::evalKingSafety(enum piece color) const {
|
|||
? bitScanLS( _bitBoard[white] & _bitBoard[king])
|
||||
: bitScanLS(bitFlip<Rank>(_bitBoard[black] & _bitBoard[king]));
|
||||
|
||||
Score score = pieceSquareTables[king][kingSq];
|
||||
Score score = PSQT[king][kingSq];
|
||||
|
||||
if ((fileIndex(kingSq) < 3) || (4 < fileIndex(kingSq))) { // King is castled
|
||||
// Pawn shields are the least advanced pawns per file
|
||||
|
@ -1043,7 +1043,7 @@ Score Board::evalRooks(enum piece color) const {
|
|||
for (u64 sq = rooks; sq; sq &= sq - 1) {
|
||||
Index sqIndex = bitScanLS(sq);
|
||||
// Piece square table (accounts for rook on seventh bonus)
|
||||
score += pieceSquareTables[rook][sqIndex];
|
||||
score += PSQT[rook][sqIndex];
|
||||
|
||||
// Add bonuses for semi-open and open files
|
||||
if (bitMask<Square>(sqIndex) & openFiles) {
|
||||
|
@ -1063,22 +1063,10 @@ Score Board::evalRooks(enum piece color) const {
|
|||
// position fen r1bq1rk1/pp2p1bp/2np4/5B2/nP3P2/N1P2N2/6PP/R1B1QRK1 b - - 0 4
|
||||
Score Board::evaluate() const {
|
||||
|
||||
constexpr Score pstKingEndgame[64] = { // TODO: Proper parameters file, ...
|
||||
0, 10, 20, 30, 30, 20, 10, 0,
|
||||
10, 20, 30, 40, 40, 30, 20, 10,
|
||||
20, 30, 40, 50, 50, 40, 30, 20,
|
||||
30, 40, 50, 60, 60, 50, 40, 30,
|
||||
30, 40, 50, 60, 60, 50, 40, 30,
|
||||
20, 30, 40, 50, 50, 40, 30, 20,
|
||||
10, 20, 30, 40, 40, 30, 20, 10,
|
||||
0, 10, 20, 30, 30, 20, 10, 0
|
||||
};
|
||||
|
||||
constexpr Score maxMaterial = 8 * pieceValues[pawn]
|
||||
+ 2 * (pieceValues[rook] + pieceValues[knight] + pieceValues[bishop])
|
||||
+ pieceValues[queen];
|
||||
|
||||
|
||||
// Start score with material values
|
||||
Score whiteMaterial = evalMaterial(white);
|
||||
Score blackMaterial = evalMaterial(black);
|
||||
|
@ -1089,11 +1077,11 @@ Score Board::evaluate() const {
|
|||
for (enum piece type : { queen, bishop, knight }) {
|
||||
// White pieces
|
||||
for (u64 sq = _bitBoard[white] & _bitBoard[type]; sq; sq &= sq - 1) {
|
||||
score += pieceSquareTables[type][bitScanLS(sq)];
|
||||
score += PSQT[type][bitScanLS(sq)];
|
||||
}
|
||||
// and black pieces
|
||||
for (u64 sq = _bitBoard[black] & _bitBoard[type]; sq; sq &= sq - 1) {
|
||||
score -= pieceSquareTables[type][63 - bitScanLS(sq)];
|
||||
score -= PSQT[type][63 - bitScanLS(sq)];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1107,7 +1095,7 @@ Score Board::evaluate() const {
|
|||
if (blackMaterial <= 1200) {
|
||||
// Endgame:
|
||||
// No king safety, but more king mobility in the center
|
||||
score += pstKingEndgame[bitScanLS(_bitBoard[white] & _bitBoard[king])];
|
||||
score += PSQT[kingEG][bitScanLS(_bitBoard[white] & _bitBoard[king])];
|
||||
} else {
|
||||
// Middle Game:
|
||||
// King safety weighted by opponents material (The less pieces the enemy
|
||||
|
@ -1117,7 +1105,7 @@ Score Board::evaluate() const {
|
|||
}
|
||||
// and the same for the black king with opposite sign
|
||||
if (whiteMaterial <= 1200) {
|
||||
score -= pstKingEndgame[bitScanLS(_bitBoard[black] & _bitBoard[king])];
|
||||
score -= PSQT[kingEG][63 - bitScanLS(_bitBoard[black] & _bitBoard[king])];
|
||||
} else {
|
||||
score -= (5 * whiteMaterial * evalKingSafety(black)) / (4 * maxMaterial);
|
||||
}
|
||||
|
|
|
@ -94,7 +94,7 @@ public:
|
|||
// add 128 to ensure the PST values are positive
|
||||
const Index t = color() == white ? to() : 63 - to();
|
||||
const Index f = color() == white ? from() : 63 - from();
|
||||
const uint32_t pst = pieceSquareTables[piece()][t] - pieceSquareTables[piece()][f] + 128;
|
||||
const uint32_t pst = PSQT[piece()][t] - PSQT[piece()][f] + 128;
|
||||
|
||||
return ((static_cast<bool>(victim()) * mvv_lva) << (14 + winning * 6)) + pst;
|
||||
}
|
||||
|
|
|
@ -39,7 +39,8 @@ enum piece {
|
|||
bishop = 4,
|
||||
rook = 5,
|
||||
queen = 6,
|
||||
king = 7
|
||||
king = 7,
|
||||
kingEG = 8 // Lookup index for king end game PSQT
|
||||
};
|
||||
|
||||
enum square : Index {
|
||||
|
@ -69,10 +70,10 @@ enum location {
|
|||
constexpr Score pieceValues[8] = {
|
||||
0, 0, // white, black (irrelevant)
|
||||
100, // pawn
|
||||
300, // knight
|
||||
300, // bishop
|
||||
500, // rook
|
||||
900, // queen
|
||||
295, // knight
|
||||
315, // bishop
|
||||
450, // rook
|
||||
870, // queen
|
||||
0 // king (irrelevant, always 2 opposite kings)
|
||||
};
|
||||
|
||||
|
@ -132,64 +133,76 @@ constexpr u64 kingMoveLookup[64] = {
|
|||
using std::cerr;
|
||||
#endif
|
||||
|
||||
// Piece Square tables (from TSCP)
|
||||
// see: https://www.chessprogramming.org/Simplified_Evaluation_Function
|
||||
constexpr Score pieceSquareTables[8][64] = {
|
||||
// Piece SQuare Tables (partially automated tuned tables via supervised
|
||||
// optimization using stockfish [https://stockfishchess.org/] evaluated positions
|
||||
// from the lichess database [https://database.lichess.org/])
|
||||
// endgame table: https://www.chessprogramming.org/Simplified_Evaluation_Function
|
||||
// Which is addapted by adding 50. then scaled by 2 / 3 and rounded.
|
||||
constexpr Score PSQT[9][64] = {
|
||||
{ }, { }, // white, black (empty)
|
||||
{ // pawn (white)
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
5, 10, 15, 20, 20, 15, 10, 5,
|
||||
4, 8, 12, 16, 16, 12, 8, 4,
|
||||
3, 6, 9, 12, 12, 9, 6, 3,
|
||||
2, 4, 6, 8, 8, 6, 4, 2,
|
||||
1, 2, 3, -10, -10, 3, 2, 1,
|
||||
0, 0, 0, -40, -40, 0, 0, 0,
|
||||
109, 82, 89, 25, 25, 89, 82, 109,
|
||||
21, 18, -3, 18, 18, -3, 18, 21,
|
||||
-12, -1, -19, 6, 6, -19, -1, -12,
|
||||
-25, -15, -22, 9, 9, -22, -15, -25,
|
||||
-25, -11, -27, -23, -23, -27, -11, -25,
|
||||
-25, -13, -23, -29, -29, -23, -13, -25,
|
||||
0, 0, 0, 0, 0, 0, 0, 0 },
|
||||
{ // knight (white)
|
||||
-10, -10, -10, -10, -10, -10, -10, -10,
|
||||
-10, 0, 0, 0, 0, 0, 0, -10,
|
||||
-10, 0, 5, 5, 5, 5, 0, -10,
|
||||
-10, 0, 5, 10, 10, 5, 0, -10,
|
||||
-10, 0, 5, 10, 10, 5, 0, -10,
|
||||
-10, 0, 5, 5, 5, 5, 0, -10,
|
||||
-10, 0, 0, 0, 0, 0, 0, -10,
|
||||
-10, -30, -10, -10, -10, -10, -30, -10 },
|
||||
-90, -80, -18, 26, 26, -18, -80, -90,
|
||||
-40, -13, 21, -22, -22, 21, -13, -40,
|
||||
6, 2, 32, 38, 38, 32, 2, 6,
|
||||
-9, -11, 22, 20, 20, 22, -11, -9,
|
||||
-13, -11, 14, 2, 2, 14, -11, -13,
|
||||
-25, -10, 2, 3, 3, 2, -10, -25,
|
||||
-21, -54, -12, -8, -8, -12, -54, -21,
|
||||
-76, -21, -38, -34, -34, -38, -21, -76 },
|
||||
{ // bishop (white)
|
||||
-10, -10, -10, -10, -10, -10, -10, -10,
|
||||
-10, 0, 0, 0, 0, 0, 0, -10,
|
||||
-10, 0, 5, 5, 5, 5, 0, -10,
|
||||
-10, 0, 5, 10, 10, 5, 0, -10,
|
||||
-10, 0, 5, 10, 10, 5, 0, -10,
|
||||
-10, 0, 5, 5, 5, 5, 0, -10,
|
||||
-10, 0, 0, 0, 0, 0, 0, -10,
|
||||
-10, -10, -20, -10, -10, -20, -10, -10 },
|
||||
-7, 19, 3, -21, -21, 3, 19, -7,
|
||||
-15, -5, 6, 40, 40, 6, -5, -15,
|
||||
12, 14, 18, 32, 32, 18, 14, 12,
|
||||
-5, -2, 17, 26, 26, 17, -2, -5,
|
||||
-19, -2, 2, 8, 8, 2, -2, -19,
|
||||
2, 4, 2, 8, 8, 2, 4, 2,
|
||||
-4, 8, 3, 1, 1, 3, 8, -4,
|
||||
-31, -13, -7, -20, -20, -7, -13, -31 },
|
||||
{ // rook (white)
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
20, 20, 20, 20, 20, 20, 20, 20, // rook on seventh bonus
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0 },
|
||||
-5, -2, 23, 40, 40, 23, -2, -5,
|
||||
18, 17, 42, 25, 25, 42, 17, 18,
|
||||
22, 14, 33, 40, 40, 33, 14, 22,
|
||||
21, 16, 20, 28, 28, 20, 16, 21,
|
||||
-4, -13, -5, 3, 3, -5, -13, -4,
|
||||
-20, -2, -3, -2, -2, -3, -2, -20,
|
||||
-11, -13, 0, -6, -6, 0, -13, -11,
|
||||
-17, -4, 0, 7, 7, 0, -4, -17 },
|
||||
{ // queen (white)
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0 },
|
||||
-55, -29, 59, 19, 19, 59, -29, -55,
|
||||
12, -18, 34, 85, 85, 34, -18, 12,
|
||||
33, 17, 31, 34, 34, 31, 17, 33,
|
||||
51, 16, 21, 18, 18, 21, 16, 51,
|
||||
-3, 24, 18, 26, 26, 18, 24, -3,
|
||||
11, 14, 24, 2, 2, 24, 14, 11,
|
||||
28, 5, 17, 15, 15, 17, 5, 28,
|
||||
1, -10, -14, 18, 18, -14, -10, 1 },
|
||||
{ // king middle game (white)
|
||||
-40, -40, -40, -40, -40, -40, -40, -40,
|
||||
-40, -40, -40, -40, -40, -40, -40, -40,
|
||||
-40, -40, -40, -40, -40, -40, -40, -40,
|
||||
-40, -40, -40, -40, -40, -40, -40, -40,
|
||||
-40, -40, -40, -40, -40, -40, -40, -40,
|
||||
-40, -40, -40, -40, -40, -40, -40, -40,
|
||||
-20, -20, -20, -20, -20, -20, -20, -20,
|
||||
0, 20, 40, -20, 0, -20, 40, 20 }
|
||||
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||
-4, -4, -4, -4, -4, -4, -4, -4,
|
||||
24, 13, 3, -28, 2, -14, 15, 1 },
|
||||
{ // king end game (white) // TODO: self/supervised tuning
|
||||
0, 7, 13, 20, 20, 13, 7, 0,
|
||||
13, 20, 27, 33, 33, 27, 20, 13,
|
||||
13, 27, 47, 53, 53, 47, 27, 13,
|
||||
13, 27, 53, 60, 60, 53, 27, 13,
|
||||
13, 27, 53, 60, 60, 53, 27, 13,
|
||||
13, 27, 47, 53, 53, 47, 27, 13,
|
||||
13, 13, 33, 33, 33, 33, 13, 13,
|
||||
0, 13, 13, 13, 13, 13, 13, 0 }
|
||||
};
|
||||
|
||||
#endif /* INCLUDE_GUARD_TYPES_H */
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
#include <vector>
|
||||
#include <Rcpp.h>
|
||||
|
||||
#include "SchachHoernchen/Move.h"
|
||||
#include "SchachHoernchen/Board.h"
|
||||
|
||||
//' Human Crafted Evaluation
|
||||
// [[Rcpp::export(rng = false)]]
|
||||
Rcpp::NumericVector HCE(const std::vector<Board>& positions) {
|
||||
// Iterate all positions and call the static board evaluation
|
||||
return Rcpp::NumericVector(positions.begin(), positions.end(),
|
||||
[](const Board& pos) {
|
||||
return (double)pos.evaluate() / 100.0;
|
||||
}
|
||||
);
|
||||
}
|
|
@ -21,9 +21,59 @@ BEGIN_RCPP
|
|||
return rcpp_result_gen;
|
||||
END_RCPP
|
||||
}
|
||||
// isWhiteTurn
|
||||
Rcpp::LogicalVector isWhiteTurn(const std::vector<Board>& positions);
|
||||
RcppExport SEXP _Rchess_isWhiteTurn(SEXP positionsSEXP) {
|
||||
BEGIN_RCPP
|
||||
Rcpp::RObject rcpp_result_gen;
|
||||
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||
rcpp_result_gen = Rcpp::wrap(isWhiteTurn(positions));
|
||||
return rcpp_result_gen;
|
||||
END_RCPP
|
||||
}
|
||||
// isCheck
|
||||
Rcpp::LogicalVector isCheck(const std::vector<Board>& positions);
|
||||
RcppExport SEXP _Rchess_isCheck(SEXP positionsSEXP) {
|
||||
BEGIN_RCPP
|
||||
Rcpp::RObject rcpp_result_gen;
|
||||
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||
rcpp_result_gen = Rcpp::wrap(isCheck(positions));
|
||||
return rcpp_result_gen;
|
||||
END_RCPP
|
||||
}
|
||||
// isQuiet
|
||||
Rcpp::LogicalVector isQuiet(const std::vector<Board>& positions);
|
||||
RcppExport SEXP _Rchess_isQuiet(SEXP positionsSEXP) {
|
||||
BEGIN_RCPP
|
||||
Rcpp::RObject rcpp_result_gen;
|
||||
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||
rcpp_result_gen = Rcpp::wrap(isQuiet(positions));
|
||||
return rcpp_result_gen;
|
||||
END_RCPP
|
||||
}
|
||||
// isTerminal
|
||||
Rcpp::LogicalVector isTerminal(const std::vector<Board>& positions);
|
||||
RcppExport SEXP _Rchess_isTerminal(SEXP positionsSEXP) {
|
||||
BEGIN_RCPP
|
||||
Rcpp::RObject rcpp_result_gen;
|
||||
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||
rcpp_result_gen = Rcpp::wrap(isTerminal(positions));
|
||||
return rcpp_result_gen;
|
||||
END_RCPP
|
||||
}
|
||||
// isInsufficient
|
||||
Rcpp::LogicalVector isInsufficient(const std::vector<Board>& positions);
|
||||
RcppExport SEXP _Rchess_isInsufficient(SEXP positionsSEXP) {
|
||||
BEGIN_RCPP
|
||||
Rcpp::RObject rcpp_result_gen;
|
||||
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||
rcpp_result_gen = Rcpp::wrap(isInsufficient(positions));
|
||||
return rcpp_result_gen;
|
||||
END_RCPP
|
||||
}
|
||||
// data_gen
|
||||
Rcpp::CharacterVector data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const int min_ply_count);
|
||||
RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP min_ply_countSEXP) {
|
||||
Rcpp::CharacterVector data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const int min_ply_count, const bool white_only);
|
||||
RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP min_ply_countSEXP, SEXP white_onlySEXP) {
|
||||
BEGIN_RCPP
|
||||
Rcpp::RObject rcpp_result_gen;
|
||||
Rcpp::RNGScope rcpp_rngScope_gen;
|
||||
|
@ -33,18 +83,22 @@ BEGIN_RCPP
|
|||
Rcpp::traits::input_parameter< const float >::type score_max(score_maxSEXP);
|
||||
Rcpp::traits::input_parameter< const bool >::type quiet(quietSEXP);
|
||||
Rcpp::traits::input_parameter< const int >::type min_ply_count(min_ply_countSEXP);
|
||||
rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, min_ply_count));
|
||||
Rcpp::traits::input_parameter< const bool >::type white_only(white_onlySEXP);
|
||||
rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, min_ply_count, white_only));
|
||||
return rcpp_result_gen;
|
||||
END_RCPP
|
||||
}
|
||||
// eval_psqt
|
||||
Rcpp::NumericVector eval_psqt(const std::vector<Board>& positions, const std::vector<Rcpp::NumericMatrix>& psqt);
|
||||
RcppExport SEXP _Rchess_eval_psqt(SEXP positionsSEXP, SEXP psqtSEXP) {
|
||||
Rcpp::NumericVector eval_psqt(const std::vector<Board>& positions, const std::vector<Rcpp::NumericMatrix>& psqt, const bool pawn_structure, const bool eval_rooks, const bool eval_king);
|
||||
RcppExport SEXP _Rchess_eval_psqt(SEXP positionsSEXP, SEXP psqtSEXP, SEXP pawn_structureSEXP, SEXP eval_rooksSEXP, SEXP eval_kingSEXP) {
|
||||
BEGIN_RCPP
|
||||
Rcpp::RObject rcpp_result_gen;
|
||||
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||
Rcpp::traits::input_parameter< const std::vector<Rcpp::NumericMatrix>& >::type psqt(psqtSEXP);
|
||||
rcpp_result_gen = Rcpp::wrap(eval_psqt(positions, psqt));
|
||||
Rcpp::traits::input_parameter< const bool >::type pawn_structure(pawn_structureSEXP);
|
||||
Rcpp::traits::input_parameter< const bool >::type eval_rooks(eval_rooksSEXP);
|
||||
Rcpp::traits::input_parameter< const bool >::type eval_king(eval_kingSEXP);
|
||||
rcpp_result_gen = Rcpp::wrap(eval_psqt(positions, psqt, pawn_structure, eval_rooks, eval_king));
|
||||
return rcpp_result_gen;
|
||||
END_RCPP
|
||||
}
|
||||
|
@ -196,8 +250,13 @@ END_RCPP
|
|||
|
||||
static const R_CallMethodDef CallEntries[] = {
|
||||
{"_Rchess_HCE", (DL_FUNC) &_Rchess_HCE, 1},
|
||||
{"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 6},
|
||||
{"_Rchess_eval_psqt", (DL_FUNC) &_Rchess_eval_psqt, 2},
|
||||
{"_Rchess_isWhiteTurn", (DL_FUNC) &_Rchess_isWhiteTurn, 1},
|
||||
{"_Rchess_isCheck", (DL_FUNC) &_Rchess_isCheck, 1},
|
||||
{"_Rchess_isQuiet", (DL_FUNC) &_Rchess_isQuiet, 1},
|
||||
{"_Rchess_isTerminal", (DL_FUNC) &_Rchess_isTerminal, 1},
|
||||
{"_Rchess_isInsufficient", (DL_FUNC) &_Rchess_isInsufficient, 1},
|
||||
{"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 7},
|
||||
{"_Rchess_eval_psqt", (DL_FUNC) &_Rchess_eval_psqt, 5},
|
||||
{"_Rchess_fen2int", (DL_FUNC) &_Rchess_fen2int, 1},
|
||||
{"_Rchess_read_cyclic", (DL_FUNC) &_Rchess_read_cyclic, 5},
|
||||
{"_Rchess_sample_move", (DL_FUNC) &_Rchess_sample_move, 1},
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
#include <vector>
|
||||
#include <Rcpp.h>
|
||||
|
||||
#include "SchachHoernchen/Board.h"
|
||||
|
||||
//' Given a FEN (position) determines if its whites turn
|
||||
// [[Rcpp::export(rng = false)]]
|
||||
Rcpp::LogicalVector isWhiteTurn(const std::vector<Board>& positions) {
|
||||
// Iterate all positions and call the static board evaluation
|
||||
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||
[](const Board& pos) { return pos.isWhiteTurn(); }
|
||||
);
|
||||
}
|
||||
|
||||
//' Check if current side to move is in check
|
||||
// [[Rcpp::export(rng = false)]]
|
||||
Rcpp::LogicalVector isCheck(const std::vector<Board>& positions) {
|
||||
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||
[](const Board& pos) { return pos.isCheck(); }
|
||||
);
|
||||
}
|
||||
|
||||
//' Check if the current position is a quiet position (no piece is attacked)
|
||||
// [[Rcpp::export(rng = false)]]
|
||||
Rcpp::LogicalVector isQuiet(const std::vector<Board>& positions) {
|
||||
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||
[](const Board& pos) { return pos.isQuiet(); }
|
||||
);
|
||||
}
|
||||
|
||||
//' Check if position is terminal
|
||||
//'
|
||||
//' Checks if the position is a terminal position, meaning if the game ended
|
||||
//' by mate, stale mate or the 50 modes rule. Three-Fold repetition is NOT
|
||||
//' checked, therefore a seperate game history is required which the board
|
||||
//' does NOT track.
|
||||
//'
|
||||
// [[Rcpp::export(rng = false)]]
|
||||
Rcpp::LogicalVector isTerminal(const std::vector<Board>& positions) {
|
||||
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||
[](const Board& pos) { return pos.isTerminal(); }
|
||||
);
|
||||
}
|
||||
|
||||
//' Check if checkmate is possible by material on the board
|
||||
//'
|
||||
//' Checks if there is sufficient mating material on the board, meaning if it
|
||||
//' possible for any side to deliver a check mate. More specifically, it
|
||||
//' checks if the pieces on the board are KK, KNK or KBK.
|
||||
//'
|
||||
// [[Rcpp::export(rng = false)]]
|
||||
Rcpp::LogicalVector isInsufficient(const std::vector<Board>& positions) {
|
||||
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||
[](const Board& pos) { return pos.isInsufficient(); }
|
||||
);
|
||||
}
|
|
@ -20,7 +20,8 @@ Rcpp::CharacterVector data_gen(
|
|||
const float score_min = -5.0,
|
||||
const float score_max = +5.0,
|
||||
const bool quiet = false,
|
||||
const int min_ply_count = 10
|
||||
const int min_ply_count = 10,
|
||||
const bool white_only = true
|
||||
) {
|
||||
// Check parames
|
||||
if (sample_size < 1) {
|
||||
|
@ -103,10 +104,10 @@ Rcpp::CharacterVector data_gen(
|
|||
}
|
||||
|
||||
// Reject / Filter samples
|
||||
if (((int)pos.plyCount() < min_ply_count) // Filter early positions
|
||||
|| (pos.sideToMove() == piece::black) // Filter white to move positions
|
||||
|| (score < score_min || score_max <= score) // filter scores out of slice
|
||||
|| (quiet && pos.isQuiet())) // filter quiet positions (iff requested)
|
||||
if (((int)pos.plyCount() < min_ply_count) // early positions
|
||||
|| (white_only && (pos.sideToMove() == piece::black)) // white to move positions
|
||||
|| (score < score_min || score_max <= score) // scores out of slice
|
||||
|| (quiet && !pos.isQuiet())) // quiet positions
|
||||
{
|
||||
reject_count++;
|
||||
continue;
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
#include <vector>
|
||||
#include <Rcpp.h>
|
||||
|
||||
#include "SchachHoernchen/Move.h"
|
||||
#include "SchachHoernchen/Board.h"
|
||||
|
||||
//' Human Crafted Evaluation
|
||||
// [[Rcpp::export(name = "eval.psqt", rng = false)]]
|
||||
Rcpp::NumericVector eval_psqt(
|
||||
const std::vector<Board>& positions,
|
||||
const std::vector<Rcpp::NumericMatrix>& psqt,
|
||||
const bool pawn_structure = false,
|
||||
const bool eval_rooks = false,
|
||||
const bool eval_king = false
|
||||
) {
|
||||
// validate Piece Square Table count and sizes
|
||||
if (psqt.size() != 6) {
|
||||
Rcpp::stop("Expected exactly 6 PSQTs");
|
||||
}
|
||||
for (const auto table : psqt) {
|
||||
if (table.nrow() != 8 || table.ncol() != 8) {
|
||||
Rcpp::stop("PSQT table missmatch, all expected to be `8 x 8`");
|
||||
}
|
||||
}
|
||||
|
||||
// create numeric vector by evaluating all positions
|
||||
return Rcpp::NumericVector(positions.begin(), positions.end(),
|
||||
[&psqt, pawn_structure, eval_rooks, eval_king](
|
||||
const Board& pos
|
||||
) {
|
||||
// Index to color/piece mapping (more robust)
|
||||
enum piece colorLoopup[2] = { white, black };
|
||||
enum piece pieceLookup[6] = { pawn, knight, bishop, rook, queen, king };
|
||||
|
||||
// Score is the "inner product" of the "one-hot encoded" position
|
||||
// and the piece square tables (PSQT)
|
||||
double whiteMaterial = 0.0, blackMaterial = 0.0;
|
||||
for (int piece = 0; piece < 6; ++piece) {
|
||||
u64 piece_bb = pos.bb(pieceLookup[piece]);
|
||||
// First the White (positive) pieces
|
||||
for (u64 bb = pos.bb(piece::white) & piece_bb; bb; bb &= bb - 1) {
|
||||
// Get piece on bitboard index (Least Significant Bit)
|
||||
int index = bitScanLS(bb);
|
||||
// Transpose to align with PSQT memory layout
|
||||
index = ((index & 7) << 3) | ((index & 56) >> 3);
|
||||
whiteMaterial += psqt[piece][index];
|
||||
}
|
||||
// Second the black (negative) pieces (with flipped Ranks)
|
||||
for (u64 bb = pos.bb(piece::black) & piece_bb; bb; bb &= bb - 1) {
|
||||
// Get fliped board index
|
||||
int index = bitScanLS(bb);
|
||||
// Transpose to align with PSQT memory layout and flip ranks
|
||||
// convert from whites perspective to blacks persepective
|
||||
index = ((index & 7) << 3) | (7 - ((index & 56) >> 3));
|
||||
blackMaterial += psqt[piece][index];
|
||||
}
|
||||
}
|
||||
return (whiteMaterial - blackMaterial) / 100.0;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
devtools::load_all()
|
||||
save_point <- sort(list.files(
|
||||
"~/Work/tensorPredictors/dataAnalysis/chess/",
|
||||
pattern = "save_point.*\\.Rdata",
|
||||
full.names = TRUE
|
||||
), decreasing = TRUE)[[1]]
|
||||
load(save_point)
|
||||
|
||||
psqt <- Map(function(parts) matrix(rowSums(kronecker(parts[[2]], parts[[1]])), 8, 8), betas)
|
||||
psqt <- Map(`-`, psqt[1:6], Map(function(table) table[8:1, ], psqt[7:12]))
|
||||
|
||||
eval.psqt("startpos", psqt)
|
||||
|
||||
*/
|
|
@ -15,8 +15,12 @@ Rcpp::IntegerVector fen2int(const std::vector<Board>& boards) {
|
|||
auto dims = Rcpp::IntegerVector({ 8, 8, 12, (int)boards.size() });
|
||||
bitboards.attr("dim") = dims;
|
||||
bitboards.attr("dimnames") = Rcpp::List::create(
|
||||
Rcpp::Named("rank") = Rcpp::CharacterVector::create("8", "7", "6", "5", "4", "3", "2", "1"),
|
||||
Rcpp::Named("file") = Rcpp::CharacterVector::create("a", "b", "c", "d", "e", "f", "g", "h"),
|
||||
Rcpp::Named("rank") = Rcpp::CharacterVector::create(
|
||||
"8", "7", "6", "5", "4", "3", "2", "1"
|
||||
),
|
||||
Rcpp::Named("file") = Rcpp::CharacterVector::create(
|
||||
"a", "b", "c", "d", "e", "f", "g", "h"
|
||||
),
|
||||
Rcpp::Named("piece") = Rcpp::CharacterVector::create(
|
||||
"P", "N", "B", "R", "Q", "K", // White Pieces (Upper Case)
|
||||
"p", "n", "b", "r", "q", "k" // Black Pieces (Lower Case)
|
||||
|
@ -42,6 +46,8 @@ Rcpp::IntegerVector fen2int(const std::vector<Board>& boards) {
|
|||
int index = bitScanLS(bb);
|
||||
// Transpose to align with printing as a Chess Board
|
||||
index = ((index & 7) << 3) | ((index & 56) >> 3);
|
||||
// Flip black to move positions to whites point of view
|
||||
index ^= pos.isWhiteTurn() ? 0 : 7;
|
||||
bitboards[768 * i + 64 * slice + index] = 1;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,71 +0,0 @@
|
|||
// #include <Rcpp.h>
|
||||
|
||||
// #include "SchachHoernchen/utils.h"
|
||||
// #include "SchachHoernchen/Board.h"
|
||||
// #include "SchachHoernchen/uci.h"
|
||||
|
||||
// // [[Rcpp::export(name = "print.board", rng = false)]]
|
||||
// void print_board(
|
||||
// const Board& board,
|
||||
// const bool check = true,
|
||||
// const bool attacked = false,
|
||||
// const bool pinned = false,
|
||||
// const bool checkers = false
|
||||
// ) {
|
||||
// using Rcpp::Rcout;
|
||||
|
||||
// // Extract some properties
|
||||
// piece color = board.sideToMove();
|
||||
// piece enemy = static_cast<piece>(!color);
|
||||
// u64 empty = ~(board.bb(piece::white) | board.bb(piece::black));
|
||||
// u64 cKing = board.bb(king) & board.bb(color);
|
||||
// // Construct highlight mask
|
||||
// u64 attackMask = board.attacks(enemy, empty);
|
||||
// u64 mask = 0;
|
||||
// mask |= check ? attackMask & board.bb(color) & board.bb(king) : 0;
|
||||
// mask |= attacked ? attackMask & ~board.bb(color) : 0;
|
||||
// mask |= pinned ? board.pinned(enemy, cKing, empty) : 0;
|
||||
// mask |= checkers ? board.checkers(enemy, cKing, empty) : 0;
|
||||
|
||||
// // print the board to console
|
||||
// Rcout << "FEN: " << board.fen() << '\n';
|
||||
// for (Index line = 0; line < 17; line++) {
|
||||
// if (line % 2) {
|
||||
// Index rankIndex = line / 2;
|
||||
// Rcout << static_cast<char>('8' - rankIndex);
|
||||
// for (Index fileIndex = 0; fileIndex < 8; fileIndex++) {
|
||||
// Index squareIndex = 8 * rankIndex + fileIndex;
|
||||
// if (bitMask<Square>(squareIndex) & mask) {
|
||||
// if (board.piece(squareIndex)) {
|
||||
// if (board.color(squareIndex) == black) {
|
||||
// // bold + italic + underline + black (blue)
|
||||
// Rcout << " | \033[1;3;4;94m";
|
||||
// } else {
|
||||
// // bold + italic + underline (+ white)
|
||||
// Rcout << " | \033[1;3;4m";
|
||||
// }
|
||||
// Rcout << UCI::formatPiece(board.piece(squareIndex))
|
||||
// << "\033[0m";
|
||||
// } else {
|
||||
// Rcout << " | .";
|
||||
// }
|
||||
// } else if (board.color(squareIndex) == black) {
|
||||
// Rcout << " | \033[1m\033[94m" // bold + blue (black)
|
||||
// << UCI::formatPiece(board.piece(squareIndex))
|
||||
// << "\033[0m";
|
||||
// } else if (board.color(squareIndex) == white) {
|
||||
// Rcout << " | \033[1m\033[97m" // bold + white
|
||||
// << UCI::formatPiece(board.piece(squareIndex))
|
||||
// << "\033[0m";
|
||||
// } else {
|
||||
// Rcout << " | ";
|
||||
// }
|
||||
// }
|
||||
// Rcout << " |";
|
||||
// } else {
|
||||
// Rcout << " +---+---+---+---+---+---+---+---+";
|
||||
// }
|
||||
// Rcout << "\033[0K\n"; // clear rest of line (remove potential leftovers)
|
||||
// }
|
||||
// Rcout << " a b c d e f g h" << std::endl;
|
||||
// }
|
|
@ -1,8 +1,6 @@
|
|||
options(keep.source = TRUE, keep.source.pkgs = TRUE)
|
||||
|
||||
library(tensorPredictors)
|
||||
|
||||
# Load as 3D predictors `X` and flat response `y`
|
||||
# Load as 3D predictors `X` and flat response `y` and `F = y` with per person dim. 1 x 1
|
||||
c(X, F, y) %<-% local({
|
||||
# Load from file
|
||||
ds <- readRDS("eeg_data.rds")
|
||||
|
@ -30,42 +28,7 @@ c(X, F, y) %<-% local({
|
|||
# fit a tensor normal model to the data sample axis 1 indexes persons)
|
||||
fit.gmlm <- gmlm_tensor_normal(X, F, sample.axis = 1L)
|
||||
|
||||
# Performa a LOO prediction
|
||||
y.hat <- sapply(seq_len(dim(X)[1L]), function(i) {
|
||||
# Fit with i'th observation removes
|
||||
fit <- gmlm_tensor_normal(X[-i, , ], F[-i, , , drop = FALSE], sample.axis = 1L)
|
||||
|
||||
# Reduce the entire data set
|
||||
r <- as.vector(mlm(X, fit$betas, modes = 2:3, transpose = TRUE))
|
||||
# Fit a logit model on reduced data with i'th observation removed
|
||||
logit <- glm(y ~ r, family = binomial(link = "logit"),
|
||||
data = data.frame(y = y[-i], r = r[-i])
|
||||
)
|
||||
# predict i'th response given i'th reduced observation
|
||||
y.hat <- predict(logit, newdata = data.frame(r = r[i]), type = "response")
|
||||
# report progress
|
||||
cat(sprintf("%3d/%d\n", i, dim(X)[1L]))
|
||||
|
||||
y.hat
|
||||
})
|
||||
|
||||
### Classification performance measures
|
||||
# acc: Accuracy. P(Yhat = Y). Estimated as: (TP+TN)/(P+N).
|
||||
(acc <- mean(round(y.hat) == y)) # 0.7868852
|
||||
# err: Error rate. P(Yhat != Y). Estimated as: (FP+FN)/(P+N).
|
||||
(err <- mean(round(y.hat) != y)) # 0.2131148
|
||||
# fpr: False positive rate. P(Yhat = + | Y = -). aliases: Fallout.
|
||||
(fpr <- mean((round(y.hat) == 1)[y == 0])) # 0.4
|
||||
# tpr: True positive rate. P(Yhat = + | Y = +). aliases: Sensitivity, Recall.
|
||||
(tpr <- mean((round(y.hat) == 1)[y == 1])) # 0.8961039
|
||||
# fnr: False negative rate. P(Yhat = - | Y = +). aliases: Miss.
|
||||
(fnr <- mean((round(y.hat) == 0)[y == 1])) # 0.1038961
|
||||
# tnr: True negative rate. P(Yhat = - | Y = -).
|
||||
(tnr <- mean((round(y.hat) == 0)[y == 0])) # 0.6
|
||||
# auc: Area Under the Curve
|
||||
(auc <- pROC::roc(y, y.hat, quiet = TRUE)$auc) # 0.838961
|
||||
|
||||
|
||||
# plot the fitted mode wise reductions (for time and sensor axis)
|
||||
with(fit.gmlm, {
|
||||
par.reset <- par(mfrow = c(2, 1))
|
||||
plot(seq(0, 1, len = 256), betas[[1]], main = "Time", xlab = "Time [s]", ylab = expression(beta[1]))
|
||||
|
@ -74,20 +37,83 @@ with(fit.gmlm, {
|
|||
})
|
||||
|
||||
|
||||
# dimX <- c(4, 3, 5)
|
||||
# Omegas <- Map(function(p) {
|
||||
# O <- matrix(rnorm(p^2), p)
|
||||
# O %*% t(O)
|
||||
# }, dimX)
|
||||
#' (2D)^2 PCA preprocessing
|
||||
#'
|
||||
#' @param tpc Number of "t"ime "p"rincipal "c"omponents.
|
||||
#' @param ppc Number of "p"redictor "p"rincipal "c"omponents.
|
||||
preprocess <- function(X, tpc, ppc) {
|
||||
# Mode covariances (for predictor and time point modes)
|
||||
c(Sigma_t, Sigma_p) %<-% mcov(X, sample.axis = 1L)
|
||||
|
||||
# # Numerically more stable version of `sum(log(mapply(det, Omegas)) / dimX)`
|
||||
# # which is itself equivalent to `log(det(Omega)) / prod(nrow(Omega))` where
|
||||
# # `Omega <- Reduce(kronecker, rev(Omegas))`.
|
||||
# "predictor" (sensor) and time point principal components
|
||||
V_t <- svd(Sigma_t, tpc, 0L)$u
|
||||
V_p <- svd(Sigma_p, ppc, 0L)$u
|
||||
|
||||
# Omega <- Reduce(kronecker, rev(Omegas))
|
||||
# log(det(Omega)) / prod(nrow(Omega))
|
||||
# reduce with mode wise PCs
|
||||
mlm(X, list(V_t, V_p), modes = 2:3, transposed = TRUE)
|
||||
}
|
||||
|
||||
# (det.Omega <- sum(log(mapply(det, Omegas)) / dimX))
|
||||
# sum(mapply(function(Omega) {
|
||||
# sum(log(eigen(Omega, TRUE, TRUE)$values))
|
||||
# }, Omegas) / dimX)
|
||||
|
||||
#' Leave-one-out prediction
|
||||
#'
|
||||
#' @param X 3D EEG data (preprocessed or not)
|
||||
#' @param F binary responce `y` as a 3D tensor, every obs. is a 1 x 1 matrix
|
||||
loo.predict <- function(X, F) {
|
||||
sapply(seq_len(dim(X)[1L]), function(i) {
|
||||
# Fit with i'th observation removes
|
||||
fit <- gmlm_tensor_normal(X[-i, , ], F[-i, , , drop = FALSE], sample.axis = 1L)
|
||||
|
||||
# Reduce the entire data set
|
||||
r <- as.vector(mlm(X, fit$betas, modes = 2:3, transpose = TRUE))
|
||||
# Fit a logit model on reduced data with i'th observation removed
|
||||
logit <- glm(y ~ r, family = binomial(link = "logit"),
|
||||
data = data.frame(y = y[-i], r = r[-i])
|
||||
)
|
||||
# predict i'th response given i'th reduced observation
|
||||
y.hat <- predict(logit, newdata = data.frame(r = r[i]), type = "response")
|
||||
# report progress
|
||||
cat(sprintf("dim: (%d, %d) - %3d/%d\n", dim(X)[2L], dim(X)[3L], i, dim(X)[1L]))
|
||||
|
||||
y.hat
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
### Classification performance measures
|
||||
# acc: Accuracy. P(Yhat = Y). Estimated as: (TP+TN)/(P+N).
|
||||
acc <- function(y.true, y.pred) mean(round(y.pred) == y.true)
|
||||
# err: Error rate. P(Yhat != Y). Estimated as: (FP+FN)/(P+N).
|
||||
err <- function(y.true, y.pred) mean(round(y.pred) != y.true)
|
||||
# fpr: False positive rate. P(Yhat = + | Y = -). aliases: Fallout.
|
||||
fpr <- function(y.true, y.pred) mean((round(y.pred) == 1)[y.true == 0])
|
||||
# tpr: True positive rate. P(Yhat = + | Y = +). aliases: Sensitivity, Recall.
|
||||
tpr <- function(y.true, y.pred) mean((round(y.pred) == 1)[y.true == 1])
|
||||
# fnr: False negative rate. P(Yhat = - | Y = +). aliases: Miss.
|
||||
fnr <- function(y.true, y.pred) mean((round(y.pred) == 0)[y.true == 1])
|
||||
# tnr: True negative rate. P(Yhat = - | Y = -).
|
||||
tnr <- function(y.true, y.pred) mean((round(y.pred) == 0)[y.true == 0])
|
||||
# auc: Area Under the Curve
|
||||
auc <- function(y.true, y.pred) as.numeric(pROC::roc(y.true, y.pred, quiet = TRUE)$auc)
|
||||
auc.sd <- function(y.true, y.pred) sqrt(pROC::var(pROC::roc(y.true, y.pred, quiet = TRUE)))
|
||||
|
||||
|
||||
# perform preprocessed (reduced) and raw (not reduced) leave-one-out prediction
|
||||
y.hat.3.4 <- loo.predict(preprocess(X, 3, 4), F)
|
||||
y.hat.15.15 <- loo.predict(preprocess(X, 15, 15), F)
|
||||
y.hat.20.30 <- loo.predict(preprocess(X, 20, 30), F)
|
||||
y.hat <- loo.predict(X, F)
|
||||
|
||||
# classification performance measures table by leave-one-out cross-validation
|
||||
(loo.cv <- apply(cbind(y.hat.3.4, y.hat.15.15, y.hat.20.30, y.hat), 2, function(y.pred) {
|
||||
sapply(c("acc", "err", "fpr", "tpr", "fnr", "tnr", "auc", "auc.sd"),
|
||||
function(FUN) { match.fun(FUN)(y, y.pred) })
|
||||
}))
|
||||
#> y.hat.3.4 y.hat.15.15 y.hat.20.30 y.hat
|
||||
#> acc 0.79508197 0.78688525 0.78688525 0.78688525
|
||||
#> err 0.20491803 0.21311475 0.21311475 0.21311475
|
||||
#> fpr 0.35555556 0.40000000 0.40000000 0.40000000
|
||||
#> tpr 0.88311688 0.89610390 0.89610390 0.89610390
|
||||
#> fnr 0.11688312 0.10389610 0.10389610 0.10389610
|
||||
#> tnr 0.64444444 0.60000000 0.60000000 0.60000000
|
||||
#> auc 0.85108225 0.83838384 0.83924964 0.83896104
|
||||
#> auc.sd 0.03584791 0.03760531 0.03751307 0.03754553
|
||||
|
|
Loading…
Reference in New Issue