Compare commits
15 Commits
70b3cdaa5e
...
7a68948d26
Author | SHA1 | Date |
---|---|---|
Daniel Kapla | 7a68948d26 | |
Daniel Kapla | 8fd80522f0 | |
Daniel Kapla | 0cb7772132 | |
Daniel Kapla | db9c3f4794 | |
Daniel Kapla | 636eebf720 | |
Daniel Kapla | daefd3e7d1 | |
Daniel Kapla | 61bd94bec8 | |
Daniel Kapla | 13d3c63575 | |
Daniel Kapla | fa2a99f3f0 | |
Daniel Kapla | 85b6a2a12a | |
Daniel Kapla | 078c406100 | |
Daniel Kapla | 6792cf93a9 | |
Daniel Kapla | 4b4b30ceb0 | |
Daniel Kapla | 2e87d14696 | |
Daniel Kapla | 4324295162 |
|
@ -43,10 +43,11 @@
|
||||||
*.idb
|
*.idb
|
||||||
*.pdb
|
*.pdb
|
||||||
|
|
||||||
## R environment, data and pacakge build intermediate files/folders
|
## R environment, data and package build intermediate files/folders
|
||||||
# R Data Object files
|
# R Data Object files
|
||||||
*.Rds
|
*.Rds
|
||||||
*.rds
|
*.rds
|
||||||
|
*.Rdata
|
||||||
|
|
||||||
# Example code in package build process
|
# Example code in package build process
|
||||||
*-Ex.R
|
*-Ex.R
|
||||||
|
@ -108,20 +109,23 @@ simulations/
|
||||||
!**/LaTeX/*.bib
|
!**/LaTeX/*.bib
|
||||||
**/LaTeX/*-blx.bib
|
**/LaTeX/*-blx.bib
|
||||||
|
|
||||||
|
mlda_analysis/
|
||||||
|
References/
|
||||||
|
dataAnalysis/chess/*.Rdata
|
||||||
|
dataAnalysis/Classification of EEG/
|
||||||
|
|
||||||
|
*.csv
|
||||||
|
*.csv.log
|
||||||
|
|
||||||
# Include subfolders for images and plots
|
# Include subfolders for images and plots
|
||||||
!**/LaTeX/plots/
|
!**/LaTeX/plots/
|
||||||
**/LaTeX/plots/*
|
**/LaTeX/plots/*
|
||||||
!**/LaTeX/plots/*.tex
|
!**/LaTeX/plots/*.tex
|
||||||
|
!**/LaTeX/plots/*.csv
|
||||||
!**/LaTeX/images/
|
!**/LaTeX/images/
|
||||||
**/LaTeX/images/*
|
**/LaTeX/images/*
|
||||||
!**/LaTeX/images/*.tex
|
!**/LaTeX/images/*.tex
|
||||||
|
|
||||||
mlda_analysis/
|
|
||||||
References/
|
|
||||||
dataAnalysis/
|
|
||||||
*.csv
|
|
||||||
*.csv.log
|
|
||||||
|
|
||||||
# Images (except images used in LaTeX)
|
# Images (except images used in LaTeX)
|
||||||
*.png
|
*.png
|
||||||
*.svg
|
*.svg
|
||||||
|
|
|
@ -0,0 +1,125 @@
|
||||||
|
\documentclass{standalone}
|
||||||
|
|
||||||
|
\usepackage{pgfplots} % TikZ (TeX ist kein Zeichenprogramm)
|
||||||
|
\usetikzlibrary{calc} % for vector arithmetics
|
||||||
|
|
||||||
|
\usepackage{amssymb, bm}
|
||||||
|
|
||||||
|
\renewcommand{\t}[1]{{#1}^{T}}
|
||||||
|
\newcommand{\mat}[1]{\boldsymbol{#1}}
|
||||||
|
\newcommand{\manifold}[1]{\mathfrak{#1}}
|
||||||
|
|
||||||
|
% PGF-Plot / TikZ config
|
||||||
|
\usetikzlibrary{%
|
||||||
|
calc, through, intersections, patterns, patterns.meta, pgfplots.colormaps
|
||||||
|
}
|
||||||
|
\pgfplotsset{
|
||||||
|
compat = newest,
|
||||||
|
colormap = {grayscale}{color=(lightgray) color=(white) color=(lightgray)},
|
||||||
|
colormap = {blackscale}{color=(black!70) color=(black!50) color=(black!70)},
|
||||||
|
colormap = {redscale}{color=(black!70!red) color=(black!50!red) color=(black!70!red)},
|
||||||
|
colormap = {bluescale}{color=(black!70!blue) color=(black!50!blue) color=(black!70!blue)},
|
||||||
|
}
|
||||||
|
|
||||||
|
\begin{document}
|
||||||
|
\begin{tikzpicture}[
|
||||||
|
>=latex,
|
||||||
|
scale = 1,
|
||||||
|
declare function = { % Note: NO spaces in function argument list!
|
||||||
|
X(\u,\v) = (2 + cos(\u)) * cos(\v);
|
||||||
|
Y(\u,\v) = (2 + cos(\u)) * sin(\v);
|
||||||
|
Z(\u,\v) = sin(\u);
|
||||||
|
tx(\u,\v,\x,\y) = - sin(\u) * cos(\u) * \x - (2 + cos(\u)) * sin(\v) * \y;
|
||||||
|
ty(\u,\v,\x,\y) = - sin(\u) * cos(\u) * \x + (2 + cos(\u)) * cos(\v) * \y;
|
||||||
|
tz(\u,\v,\x,\y) = + cos(\u) * \x;
|
||||||
|
}
|
||||||
|
]
|
||||||
|
\begin{axis}[
|
||||||
|
axis equal image,
|
||||||
|
hide axis,
|
||||||
|
view = {120}{30},
|
||||||
|
scale = 2
|
||||||
|
]
|
||||||
|
\addplot3[
|
||||||
|
surf,
|
||||||
|
shader = faceted interp,
|
||||||
|
samples = 20,
|
||||||
|
samples y = 40,
|
||||||
|
domain = 0:360,
|
||||||
|
domain y = 0:360,
|
||||||
|
z buffer = sort,
|
||||||
|
colormap name = grayscale,
|
||||||
|
thin
|
||||||
|
]
|
||||||
|
({X(\x, \y)}, {Y(\x, \y)}, {Z(\x, \y)});
|
||||||
|
|
||||||
|
% at = (1.433013, 2.482051, 0.5) // \u, \v = (30, 60)
|
||||||
|
% into = (-0.4330127, -0.4330127, 0.8660254)
|
||||||
|
% and = (-2.482051, 1.433013, 0)
|
||||||
|
|
||||||
|
% X <- c(1.433013, 2.482051, 0.5)
|
||||||
|
% dx <- c(-0.4330127, -0.4330127, 0.8660254)
|
||||||
|
% dy <- c(-2.482051, 1.433013, 0)
|
||||||
|
|
||||||
|
\addplot3[
|
||||||
|
mesh,
|
||||||
|
shader = interp,
|
||||||
|
patch type = line,
|
||||||
|
variable = t,
|
||||||
|
domain = -44:210.5,
|
||||||
|
samples = 64,
|
||||||
|
samples y = 1,
|
||||||
|
colormap name = redscale
|
||||||
|
]
|
||||||
|
({X(\t, \t + 30)}, {Y(\t, \t + 30)}, {Z(\t, \t + 30)});
|
||||||
|
|
||||||
|
\addplot3[
|
||||||
|
mesh,
|
||||||
|
shader = interp,
|
||||||
|
patch type = line,
|
||||||
|
variable = t,
|
||||||
|
domain = -50:119,
|
||||||
|
samples = 64,
|
||||||
|
samples y = 1,
|
||||||
|
colormap name = bluescale
|
||||||
|
]
|
||||||
|
({X(\t, -0.3 * \t + 69)}, {Y(\t, -0.3 * \t + 69)}, {Z(\t, -0.3 * \t + 69)});
|
||||||
|
|
||||||
|
\coordinate (x) at ({X(30, 60)}, {Y(30, 60)}, {Z(30, 60)});
|
||||||
|
|
||||||
|
\draw[dashed, fill = gray, opacity = 0.4] (
|
||||||
|
{X(30, 60) + tx(30, 60, 0.4 + 0.8, 0.4 - 0.24)},
|
||||||
|
{Y(30, 60) + ty(30, 60, 0.4 + 0.8, 0.4 - 0.24)},
|
||||||
|
{Z(30, 60) + tz(30, 60, 0.4 + 0.8, 0.4 - 0.24)}
|
||||||
|
) -- (
|
||||||
|
{X(30, 60) + tx(30, 60, -0.4 + 0.8, -0.4 - 0.24)},
|
||||||
|
{Y(30, 60) + ty(30, 60, -0.4 + 0.8, -0.4 - 0.24)},
|
||||||
|
{Z(30, 60) + tz(30, 60, -0.4 + 0.8, -0.4 - 0.24)}
|
||||||
|
) node[anchor = west, opacity = 1, outer sep=0.5em] {$T_{\mat{x}}\manifold{A}$} -- (
|
||||||
|
{X(30, 60) + tx(30, 60, -0.4 - 0.8, -0.4 + 0.24)},
|
||||||
|
{Y(30, 60) + ty(30, 60, -0.4 - 0.8, -0.4 + 0.24)},
|
||||||
|
{Z(30, 60) + tz(30, 60, -0.4 - 0.8, -0.4 + 0.24)}
|
||||||
|
) -- (
|
||||||
|
{X(30, 60) + tx(30, 60, 0.4 - 0.8, 0.4 + 0.24)},
|
||||||
|
{Y(30, 60) + ty(30, 60, 0.4 - 0.8, 0.4 + 0.24)},
|
||||||
|
{Z(30, 60) + tz(30, 60, 0.4 - 0.8, 0.4 + 0.24)}
|
||||||
|
) -- cycle;
|
||||||
|
|
||||||
|
\draw[->, black!50!red] (x) -- (
|
||||||
|
{X(30, 60) + tx(30, 60, 0.4, 0.4)},
|
||||||
|
{Y(30, 60) + ty(30, 60, 0.4, 0.4)},
|
||||||
|
{Z(30, 60) + tz(30, 60, 0.4, 0.4)}
|
||||||
|
) node[pos = 0.7, anchor = north west, inner sep = 0pt] {$\t{\nabla\gamma_1(0)}$};
|
||||||
|
|
||||||
|
\draw[->, black!50!blue] (x) -- (
|
||||||
|
{X(30, 60) + tx(30, 60, 0.8, -0.24)}, % -0.24 = 0.8 * -0.3
|
||||||
|
{Y(30, 60) + ty(30, 60, 0.8, -0.24)},
|
||||||
|
{Z(30, 60) + tz(30, 60, 0.8, -0.24)}
|
||||||
|
) node[pos = 0.7, anchor = south west, inner sep = 0pt] {$\t{\nabla\gamma_2(0)}$};
|
||||||
|
|
||||||
|
\node[anchor = north] at (x) {$\mat{x}$};
|
||||||
|
\node[circle, inner sep={1pt}, outer sep={0pt}, fill=black] at (x) {};
|
||||||
|
|
||||||
|
\end{axis}
|
||||||
|
\end{tikzpicture}
|
||||||
|
\end{document}
|
|
@ -0,0 +1,118 @@
|
||||||
|
\documentclass{standalone}
|
||||||
|
|
||||||
|
\usepackage{pgfplots} % TikZ (TeX ist kein Zeichenprogramm)
|
||||||
|
\usetikzlibrary{calc, perspective, pgfplots.colormaps} % PGF-Plot / TikZ config
|
||||||
|
|
||||||
|
\pgfplotsset{
|
||||||
|
compat = newest,
|
||||||
|
colormap = {grayscale}{color=(lightgray) color=(white) color=(lightgray)},
|
||||||
|
colormap = {blackscale}{color=(black!70) color=(black!50) color=(black!70)},
|
||||||
|
}
|
||||||
|
|
||||||
|
% Define the (component) embedding into the torus
|
||||||
|
\tikzset{declare function = { % Note: NO spaces in function argument list!
|
||||||
|
Z(\u,\v) = 0.4 * \u * \u * cos(\v * 120);
|
||||||
|
bx(\t) = -0.5 + 0.3 * cos(\t) + 0.05 * sin(3 * \t);
|
||||||
|
by(\t) = 0.2 + 0.3 * sin(\t);
|
||||||
|
}}
|
||||||
|
|
||||||
|
% Further packages and macros
|
||||||
|
\usepackage{amssymb, bm}
|
||||||
|
|
||||||
|
\renewcommand{\t}[1]{{#1}^{T}}
|
||||||
|
\newcommand{\mat}[1]{\boldsymbol{#1}}
|
||||||
|
\newcommand{\manifold}[1]{\mathfrak{#1}}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
\begin{document}
|
||||||
|
\begin{tikzpicture}[>=latex]
|
||||||
|
|
||||||
|
\begin{axis}[
|
||||||
|
axis equal image,
|
||||||
|
hide axis,
|
||||||
|
view = {120}{30},
|
||||||
|
scale = 1,
|
||||||
|
clip = false
|
||||||
|
]
|
||||||
|
\coordinate (O) at (0, 0, 0);
|
||||||
|
|
||||||
|
\draw[->] (-0.1, 0, 0) -- (0.6, 0, 0) node[pos = 1.1] {};
|
||||||
|
\draw[->] (0, -0.1, 0) -- (0, 1.2, 0) node[pos = 1.1] {};
|
||||||
|
\draw[->] (0, 0, -0.1) -- (0, 0, 1.0) node[pos = 1.1] {};
|
||||||
|
|
||||||
|
\addplot3[
|
||||||
|
surf,
|
||||||
|
shader = faceted interp,
|
||||||
|
samples = 16,
|
||||||
|
samples y = 16,
|
||||||
|
domain = -1:0.4,
|
||||||
|
domain y = -0.5:1,
|
||||||
|
z buffer = sort,
|
||||||
|
colormap name = grayscale,
|
||||||
|
thin
|
||||||
|
]
|
||||||
|
({\x}, {\y}, {Z(\x, \y)});
|
||||||
|
|
||||||
|
\addplot3[
|
||||||
|
samples = 64,
|
||||||
|
samples y = 0,
|
||||||
|
domain = 0:360,
|
||||||
|
color = black!40!gray,
|
||||||
|
fill = black,
|
||||||
|
fill opacity = 0.1,
|
||||||
|
colormap name = blackscale,
|
||||||
|
thick
|
||||||
|
]
|
||||||
|
({bx(\x)}, {by(\x)}, {Z(bx(\x), by(\x))});
|
||||||
|
|
||||||
|
\coordinate (coordU) at ({bx(150)}, {by(150)}, {Z(bx(150), by(150))});
|
||||||
|
|
||||||
|
\node[anchor = south west] (U) at (coordU) {$U$};
|
||||||
|
|
||||||
|
\node[
|
||||||
|
circle, fill=black, inner sep=0.75pt, label={$\mat{\theta}_0$}
|
||||||
|
] (theta0) at (-0.5, 0.2, {Z(-0.5, 0.2)}) {};
|
||||||
|
|
||||||
|
\node at (0, 0.5, 1.2) {$\Theta\subseteq\mathbb{R}^p$};
|
||||||
|
|
||||||
|
\node (UU) at ({bx(0)}, {by(0)}, {Z(bx(0), by(0))}) {};
|
||||||
|
|
||||||
|
\end{axis}
|
||||||
|
|
||||||
|
\begin{scope}[shift = {(11cm, 2cm)}, scale = 2.5]
|
||||||
|
|
||||||
|
\coordinate (O) at (-1.1, -0.3);
|
||||||
|
|
||||||
|
\draw[step=0.1, lightgray!80, thin] (O) grid +(1.2, 1.2);
|
||||||
|
|
||||||
|
\draw[->] ($(O) - (0.05, 0)$) -- +(1.4, 0) node[pos=1.1] {};
|
||||||
|
\draw[->] ($(O) - (0, 0.05)$) -- +(0, 1.4) node[pos=1.1] {};
|
||||||
|
|
||||||
|
\draw[domain=0:360, smooth, variable=\x, fill=black, fill opacity = 0.1, thick] plot ({bx(\x)}, {by(\x)});
|
||||||
|
|
||||||
|
\node[
|
||||||
|
circle, fill=black, inner sep=0.75pt, label={$\mat{s}_0$}
|
||||||
|
] (s0) at (-0.5, 0.2) {};
|
||||||
|
|
||||||
|
\coordinate (coordPhiU) at ({bx(90)}, {by(90)});
|
||||||
|
|
||||||
|
\node[anchor = south east, outer sep = 0pt] (phiU) at (coordPhiU) {$\varphi(U)$};
|
||||||
|
|
||||||
|
\node at (-0.5, 1.28) {$\varphi(U)\subseteq\mathbb{R}^d$};
|
||||||
|
|
||||||
|
\node (phiUU) at ({bx(270)}, {by(270)}) {};
|
||||||
|
|
||||||
|
\end{scope}
|
||||||
|
|
||||||
|
|
||||||
|
\draw[->, out = 20, in = 160] (U.north east) to node[above, pos = 0.5] {$\varphi$} (phiU.north west);
|
||||||
|
\draw[->, out = 200, in = 340] (phiU.south west) to node[above, pos = 0.5] {$\varphi^{-1}$} (U.south east);
|
||||||
|
|
||||||
|
\node (R) at (6.1, 0) {$\mathbb{R}$};
|
||||||
|
|
||||||
|
\draw[->, out = 270, in = 180] (UU) to node[below left, pos = 0.6] {$M$} (R);
|
||||||
|
\draw[->, out = 270, in = 0] (phiUU) to node[below right, pos = 0.4] {$M_{\varphi}$} (R);
|
||||||
|
|
||||||
|
\end{tikzpicture}
|
||||||
|
\end{document}
|
|
@ -0,0 +1,78 @@
|
||||||
|
\documentclass{standalone}
|
||||||
|
|
||||||
|
\usepackage[LSB, T1]{fontenc}
|
||||||
|
\usepackage{chessboard}
|
||||||
|
\usepackage{skak}
|
||||||
|
\usepackage{tikz, tikz-3dplot}
|
||||||
|
\usepackage{amsmath}
|
||||||
|
\usepackage{xcolor}
|
||||||
|
|
||||||
|
\newcommand{\z}{{\color{gray}0}}
|
||||||
|
|
||||||
|
\tdplotsetmaincoords{80}{135}
|
||||||
|
|
||||||
|
\setboardfontencoding{LSB}
|
||||||
|
|
||||||
|
\setchessboard{linewidth = 0.1em, showmover = false, smallboard}
|
||||||
|
|
||||||
|
\newcommand{\chessplane}[2]{
|
||||||
|
\begin{scope}[canvas is yz plane at x={-#1 * 0.8}, transform shape]
|
||||||
|
\node[fill = white, opacity = 0.7, outer sep=0pt, inner sep=2pt] (layer#1) at (0, 0) {
|
||||||
|
\chessboard[
|
||||||
|
margin=false,
|
||||||
|
pgfstyle=text,
|
||||||
|
text=\textbf{1},
|
||||||
|
markfields={#2},
|
||||||
|
label=false
|
||||||
|
]
|
||||||
|
};
|
||||||
|
\end{scope}
|
||||||
|
}
|
||||||
|
|
||||||
|
\begin{document}
|
||||||
|
\begin{tikzpicture}
|
||||||
|
|
||||||
|
\begin{scope}[tdplot_main_coords, scale = 1]
|
||||||
|
\chessplane{12}{e8};
|
||||||
|
\chessplane{11}{d8};
|
||||||
|
\chessplane{10}{a8, h8};
|
||||||
|
\chessplane{9}{c8, f8};
|
||||||
|
\chessplane{8}{b8, g8};
|
||||||
|
\chessplane{7}{a7, b7, c7, d7, e5, f7, g7, h7};
|
||||||
|
\chessplane{6}{e1};
|
||||||
|
\chessplane{5}{d1};
|
||||||
|
\chessplane{4}{a1, h1};
|
||||||
|
\chessplane{3}{c1, f1};
|
||||||
|
\chessplane{2}{c3, g1};
|
||||||
|
\chessplane{1}{a2, b2, c2, d2, e4, f2, g2, h2};
|
||||||
|
|
||||||
|
\begin{scope}[canvas is yz plane at x={-1}, transform shape]
|
||||||
|
\node[anchor = south, rotate = 90] at (layer1.west) {Ranks / Axis 1};
|
||||||
|
\node[anchor = north] at (layer1.south) {Files / Axis 2};
|
||||||
|
\end{scope}
|
||||||
|
|
||||||
|
\coordinate (offset) at (layer1.west);
|
||||||
|
\newdimen\xoffset
|
||||||
|
\pgfextractx{\xoffset}{\pgfpointanchor{offset}{center}}
|
||||||
|
\begin{scope}[canvas is xz plane at y=\xoffset, transform shape, xscale=-1]
|
||||||
|
\path (layer1.north west) -- (layer12.north west) node[
|
||||||
|
pos = 0.5, anchor = south
|
||||||
|
] {Pieces / Mixture Components};
|
||||||
|
\end{scope}
|
||||||
|
\end{scope}
|
||||||
|
|
||||||
|
\coordinate (tensor north) at (current bounding box.north);
|
||||||
|
|
||||||
|
\node[shift = {(0, 0)}, anchor = east] (pos) at (current bounding box.west) {{
|
||||||
|
\setchessboard{linewidth = 0.1em, showmover = false, smallboard}
|
||||||
|
\newgame
|
||||||
|
% The Vienna Game
|
||||||
|
\hidemoves{1. e4 e5 2.Nc3} % Like `\mainline` but does NOT show the PGN line
|
||||||
|
\chessboard{}
|
||||||
|
}};
|
||||||
|
|
||||||
|
\node[anchor = south] at (pos.center |- tensor north) {Position};
|
||||||
|
\node[anchor = south] at (tensor north) {Encoding};
|
||||||
|
\end{tikzpicture}
|
||||||
|
|
||||||
|
\end{document}
|
|
@ -0,0 +1,36 @@
|
||||||
|
\begin{tikzpicture}[scale = \tikzscale], line width = 1pt]
|
||||||
|
|
||||||
|
\def\rect#1#2#3{
|
||||||
|
\draw (0, 0, 0) -- (#1, 0, 0) -- (#1, #2, 0) -- (0, #2, 0) -- cycle;
|
||||||
|
|
||||||
|
\draw[ ] (#1, 0, -#3) -- (#1, #2, -#3) -- (0, #2, -#3);
|
||||||
|
\draw[dashed] (#1, 0, -#3) -- (0, 0, -#3) -- (0, #2, -#3);
|
||||||
|
|
||||||
|
\draw[dashed] (0, 0, 0) -- (0, 0, -#3);
|
||||||
|
\draw[ ] (0, #2, 0) -- (0, #2, -#3);
|
||||||
|
\draw[ ] (#1, 0, 0) -- (#1, 0, -#3);
|
||||||
|
\draw[ ] (#1, #2, 0) -- (#1, #2, -#3);
|
||||||
|
}
|
||||||
|
|
||||||
|
\begin{scope}[yshift = 1cm, line width = 1pt]
|
||||||
|
\rect{1.5}{1}{2}
|
||||||
|
\node[font = \boldmath] at (1, 0.5) {$\ten{R}(\ten{X})$};
|
||||||
|
\end{scope}
|
||||||
|
\rect{3}{2}{4}
|
||||||
|
\node at (2, 0.5) {$\ten{X} - \E\ten{X}$};
|
||||||
|
|
||||||
|
\draw[lightgray, line width = 0.7pt] (-2.1, 2) arc (180:270:2);
|
||||||
|
\draw[fill = lightgray, fill opacity = 0.7] (-2.1, 1) rectangle +(2, 1)
|
||||||
|
node [pos = 0.5] {$\t{\mat{\beta}_1}$};
|
||||||
|
|
||||||
|
\draw[lightgray, line width = 0.7pt, domain = 0:1, smooth, variable = \t]
|
||||||
|
plot ({0}, {2.1 + 4 * cos(90 * \t)}, {-4 * sin(90 * \t)});
|
||||||
|
\draw[fill = lightgray, fill opacity = 0.7]
|
||||||
|
(0, 2.1, 0) -- (0, 2.1, -2) -- (0, 6.1, -2) -- (0, 6.1, 0) -- cycle;
|
||||||
|
\node[opacity = 0.7, cm={0.66, 0.66, 0, 1, (0, 0)}]
|
||||||
|
at (0, 4.1, -1.1) {$\t{\mat{\beta}_3}$};
|
||||||
|
|
||||||
|
\draw[lightgray, line width = 0.7pt] (0, 5.1) arc (90:0:3);
|
||||||
|
\draw[fill = lightgray, fill opacity = 0.7] (0, 2.1) rectangle +(1.5, 3)
|
||||||
|
node [pos = 0.5] {$\t{\mat{\beta}_2}$};
|
||||||
|
\end{tikzpicture}
|
2573
LaTeX/main.bib
2573
LaTeX/main.bib
File diff suppressed because it is too large
Load Diff
872
LaTeX/paper.tex
872
LaTeX/paper.tex
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,6 @@
|
||||||
|
sample.size rep time.gmlm dist.subspace.gmlm dist.projection.gmlm time.pca dist.subspace.pca dist.projection.pca time.hopca dist.subspace.hopca dist.projection.hopca time.tsir dist.subspace.tsir dist.projection.tsir time.mgcca dist.subspace.mgcca dist.projection.mgcca
|
||||||
|
100 50.5 0.03385 0.2433541701 0.2433541701 0.00226 0.863261257 0.863261257 0.00132 0.961234054 0.961234054 0.01148 0.27002415 0.27002415 0.02002 0.726356935 0.726356935
|
||||||
|
200 50.5 0.05152 0.1649899379 0.1649899379 0.00217 0.852767026 0.852767026 0.00266 0.960479878 0.960479878 0.01373 0.1884941138 0.1884941138 0.01991 0.713602535 0.713602535
|
||||||
|
300 50.5 0.06495 0.1336658703 0.1336658703 0.00486 0.843714531 0.843714531 0.01064 0.958630831 0.958630831 0.01328 0.1526304709 0.1526304709 0.02 0.71557788 0.71557788
|
||||||
|
500 50.5 0.13549 0.1032317816 0.1032317816 0.00938 0.846591187 0.846591187 0.01447 0.959566069 0.959566069 0.01722 0.1208618464 0.1208618464 0.05966 0.713613799 0.713613799
|
||||||
|
750 50.5 0.19323 0.0920445927 0.0920445927 0.00751 0.843049644 0.843049644 0.02065 0.95998194 0.95998194 0.02132 0.1040318623 0.1040318623 0.06884 0.708925464 0.708925464
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
sample.size rep time.gmlm dist.subspace.gmlm dist.projection.gmlm time.pca dist.subspace.pca dist.projection.pca time.hopca dist.subspace.hopca dist.projection.hopca time.tsir dist.subspace.tsir dist.projection.tsir time.mgcca dist.subspace.mgcca dist.projection.mgcca
|
||||||
|
100 50.5 0.14271 0.1354623792 0.1798569269 0.0025 0.805020847 0.999758145 0.00145 0.761295768 0.954043725 0.0099 0.1994699895 0.265283655 0.56208 0.770351142 0.999521275
|
||||||
|
200 50.5 0.31697 0.091214259 0.1210839439 0.00241 0.802922368 0.999879665 0.00257 0.730120539 0.925368399 0.0111 0.1483718847 0.1994800746 0.59856 0.758579776 0.999522252
|
||||||
|
300 50.5 0.39396 0.0726682553 0.0971497767 0.00586 0.806481505 0.999904273 0.00971 0.749474609 0.94434003 0.01265 0.1159595746 0.1577793427 0.56191 0.761351696 0.999435422
|
||||||
|
500 50.5 0.83078 0.0580925082 0.0779289812 0.00612 0.80574294 0.999944057 0.01732 0.751571927 0.947305126 0.01541 0.0917985004 0.12450158 2.16411 0.755277454 0.999081992
|
||||||
|
750 50.5 1.34147 0.0461519857 0.0610860564 0.00908 0.805450073 0.999961492 0.02064 0.749556315 0.945706618 0.02024 0.0797286457 0.1077928847 2.6289 0.750454219 0.999378091
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
sample.size rep time.gmlm dist.subspace.gmlm time.pca dist.subspace.pca time.hopca dist.subspace.hopca time.tsir dist.subspace.tsir time.mgcca dist.subspace.mgcca
|
||||||
|
100 50.5 0.53037 0.225507724 0.00225 0.268624529 0.0012 0.268631852 0.00927 0.975552785 0.01707 0.268634801
|
||||||
|
200 50.5 0.92916 0.1462304967 0.00229 0.268618192 0.00247 0.268627434 0.01055 0.969086651 0.0149 0.268616289
|
||||||
|
300 50.5 1.28251 0.118119137 0.00404 0.268672296 0.01078 0.268679546 0.01247 0.937344184 0.0181 0.268678444
|
||||||
|
500 50.5 2.50673 0.0925056109 0.00548 0.268620612 0.01687 0.268627137 0.01547 0.920788974 0.04591 0.268609548
|
||||||
|
750 50.5 3.78846 0.0734930824 0.00654 0.268629405 0.02286 0.26863585 0.0202 0.895542158 0.05844 0.268624504
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
sample.size rep time.gmlm dist.subspace.gmlm dist.projection.gmlm time.pca dist.subspace.pca dist.projection.pca time.hopca dist.subspace.hopca dist.projection.hopca time.tsir dist.subspace.tsir dist.projection.tsir time.mgcca dist.subspace.mgcca dist.projection.mgcca
|
||||||
|
100 50.3838383838384 0.19140404040404 0.159429996868687 0.217885249393939 0.00227272727272727 0.81991723030303 0.999889914141414 0.00136363636363636 0.7542142 0.956573433333333 0.00942424242424242 0.312921257070707 0.425537564949495 0.260232323232323 0.763839525252525 0.998650488888889
|
||||||
|
200 50.5 0.29272 0.1106441325 0.1502575874 0.00215 0.817273094 0.999952668 0.00267 0.756428697 0.961631268 0.01125 0.2406254633 0.3291691439 0.25668 0.752812085 0.997789874
|
||||||
|
300 50.5 0.40969 0.0942165171 0.1284457744 0.00577 0.817365328 0.99997193 0.00976 0.755867381 0.962490478 0.01219 0.2079868565 0.287219385 0.25407 0.744109244 0.997966974
|
||||||
|
500 50.5 0.80715 0.0671783369 0.092410679 0.00584 0.815885787 0.99998542 0.01576 0.745753387 0.955364336 0.01536 0.1719722289 0.2389028516 0.99684 0.74347681 0.998687779
|
||||||
|
750 50.5 1.30069 0.0560128971 0.077140984 0.00736 0.817057557 0.999983813 0.02158 0.745917278 0.955524558 0.01902 0.1544664904 0.2144889183 1.26885 0.742632193 0.998826916
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
sample.size rep time.gmlm dist.subspace.gmlm dist.projection.gmlm time.pca dist.subspace.pca dist.projection.pca time.hopca dist.subspace.hopca dist.projection.hopca time.tsir dist.subspace.tsir dist.projection.tsir time.mgcca dist.subspace.mgcca dist.projection.mgcca
|
||||||
|
100 50.5 0.02532 0.736796092 1 0.00209 0.900231255 0.999317475 0.00057 0.844797856 1 0.01269 0.768475572 1 0.03429 0.801702356 0.996195093
|
||||||
|
200 50.5 0.02118 0.723026339 1 0.00202 0.897808634 0.999511768 0.00076 0.821791902 1 0.01143 0.729274976 1 0.03156 0.754417729 0.993998159
|
||||||
|
300 50.5 0.02073 0.719268149 1 0.00253 0.896381124 0.999790203 0.00086 0.810580445 1 0.0108 0.723298879 1 0.03314 0.730024383 0.994609905
|
||||||
|
500 50.5 0.06797 0.714404844 1 0.00556 0.894836313 0.999794819 0.00472 0.806682723 1 0.01003 0.719358531 1 0.08743 0.70385003 0.990669984
|
||||||
|
750 50.5 0.08089 0.712524515 1 0.00652 0.894548046 0.999845024 0.00646 0.800351207 1 0.011 0.716469876 1 0.10204 0.687227594 0.991626558
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
sample.size rep time.gmlm dist.subspace.gmlm time.tnormal dist.subspace.tnormal time.pca dist.subspace.pca time.hopca dist.subspace.hopca time.lpca dist.subspace.lpca time.clpca dist.subspace.clpca time.tsir dist.subspace.tsir time.mgcca dist.subspace.mgcca
|
||||||
|
100 50.5 3.45133 0.337616233 -1 0.3481195602 0.00072 0.895240418 0.00078 0.895726045 0.03836 0.940902054 0.01723 0.905975108 0.01174 0.4840146138 0.01181 0.552492798
|
||||||
|
200 50.5 3.53892 0.2469025607 -1 0.258109693 0.00072 0.899078549 0.00073 0.898430313 0.0454 0.96077596 0.0172 0.909073167 0.01033 0.3767345117 0.00636 0.529634143
|
||||||
|
300 50.5 3.32198 0.1973335545 -1 0.2202339413 7e-04 0.894730117 0.00099 0.89457278 0.04582 0.96650364 0.01672 0.906092011 0.01085 0.2945936124 0.00647 0.506940794
|
||||||
|
500 50.5 3.47776 0.1560136925 -1 0.178495357 0.00081 0.897138615 0.00142 0.896433003 0.05409 0.982097991 0.01885 0.907027647 0.01102 0.2348182831 0.00751 0.504727342
|
||||||
|
750 50.5 3.63432 0.1338283596 -1 0.1501165074 0.00091 0.896805421 0.00194 0.896259394 0.06547 0.98164125 0.0208 0.906989078 0.01159 0.2264630936 0.00888 0.525239052
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
sample.size rep time.gmlm dist.subspace.gmlm dist.projection.gmlm time.tnormal dist.subspace.tnormal dist.projection.tnormal time.pca dist.subspace.pca dist.projection.pca time.hopca dist.subspace.hopca dist.projection.hopca time.lpca dist.subspace.lpca dist.projection.lpca time.clpca dist.subspace.clpca dist.projection.clpca time.tsir dist.subspace.tsir dist.projection.tsir time.mgcca dist.subspace.mgcca dist.projection.mgcca
|
||||||
|
100 50.5 2.61295 0.3403757357 0.3403757357 -1 0.3479059356 0.3479059356 6e-04 0.825472384 0.985301393 0.00064 0.898420853 0.898420853 0.03221 0.819788638 0.984818682 0.08248 0.830116397 0.985435015 0.00525 0.5383179747 0.5383179747 0.03822 0.684322002 0.87142619
|
||||||
|
200 50.5 2.69877 0.2336318211 0.2336318211 -1 0.2242954626 0.2242954626 0.00062 0.805871266 0.986137373 0.00083 0.894115576 0.894115576 0.04415 0.805329005 0.98705536 0.09073 0.813651382 0.988806844 0.00584 0.3463104772 0.3463104772 0.03612 0.607852417 0.779137501
|
||||||
|
300 50.5 2.84797 0.17056658598 0.17056658598 -1 0.177746684 0.177746684 0.00065 0.809654889 0.993082221 0.00099 0.901032982 0.901032982 0.05754 0.809330963 0.994445446 0.09848 0.817965106 0.995003702 0.00587 0.23915106584 0.23915106584 0.03905 0.564592137 0.722673845
|
||||||
|
500 50.5 2.99338 0.1421103262 0.1421103262 -1 0.14097689578 0.14097689578 8e-04 0.801989292 0.996579785 0.00143 0.911579678 0.911579678 0.06962 0.804787754 0.997525521 0.11444 0.814022873 0.99807076 0.00645 0.1904496139 0.1904496139 0.04446 0.506345649 0.646893436
|
||||||
|
750 50.5 3.13791 0.1099093089 0.1099093089 -1 0.11200510849 0.11200510849 0.00085 0.803412662 0.997339272 0.0018 0.90974563 0.90974563 0.06278 0.805842802 0.998157979 0.12508 0.816324572 0.998309878 0.00736 0.14247728474 0.14247728474 0.04683 0.488465089 0.623423991
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
sample.size rep time.gmlm dist.subspace.gmlm time.tnormal dist.subspace.tnormal time.pca dist.subspace.pca time.hopca dist.subspace.hopca time.lpca dist.subspace.lpca time.clpca dist.subspace.clpca time.tsir dist.subspace.tsir time.mgcca dist.subspace.mgcca
|
||||||
|
100 50.5 3.09356 0.3156118185 -1 0.333911302 0.00073 0.878676674 0.00075 0.896222186 0.09113 0.806536411 0.04383 0.896834118 0.01101 0.497285466 0.01193 0.575079476
|
||||||
|
200 50.5 3.12299 0.2077510819 -1 0.2436035005 0.00066 0.874724783 0.00079 0.889384878 0.09807 0.733886862 0.04732 0.894434769 0.01046 0.324664115 0.00669 0.576863144
|
||||||
|
300 50.5 3.16421 0.1861360687 -1 0.2008576102 0.00078 0.870909384 0.00102 0.886922724 0.11782 0.7481989 0.0503 0.893348533 0.01119 0.2976653034 0.00771 0.585351528
|
||||||
|
500 50.5 3.2918 0.1462556459 -1 0.1719980649 0.00085 0.874809929 0.00141 0.888039936 0.16421 0.725505512 0.06 0.894785099 0.01279 0.2572542268 0.00873 0.585034361
|
||||||
|
750 50.5 3.34843 0.1180667497 -1 0.1269961448 0.00084 0.87183362 0.00184 0.888721897 0.22709 0.667001863 0.07148 0.891578538 0.01242 0.2446346482 0.00958 0.589975817
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
sample.size rep time.gmlm dist.subspace.gmlm dist.projection.gmlm time.tnormal dist.subspace.tnormal dist.projection.tnormal time.pca dist.subspace.pca dist.projection.pca time.hopca dist.subspace.hopca dist.projection.hopca time.lpca dist.subspace.lpca dist.projection.lpca time.clpca dist.subspace.clpca dist.projection.clpca time.tsir dist.subspace.tsir dist.projection.tsir time.mgcca dist.subspace.mgcca dist.projection.mgcca
|
||||||
|
100 50.5 3.03076 0.3036349361 0.3036349361 -1 0.2868643362 0.2868643362 0.00073 0.818177689 0.984435713 0.00063 0.755946643 0.755946643 0.05648 0.819181741 0.953681387 0.0867 0.822312566 0.986781953 0.00613 0.4019671947 0.4019671947 0.04573 0.588313518 0.75377487
|
||||||
|
200 50.5 3.10805 0.1992945381 0.1992945381 -1 0.2020230182 0.2020230182 0.00074 0.814338816 0.987890451 0.00093 0.690811228 0.690811228 0.05801 0.81457008 0.961377759 0.08991 0.825394861 0.991647771 0.00599 0.2511707246 0.2511707246 0.04656 0.476761868 0.629632653
|
||||||
|
300 50.5 3.3052 0.1546113425 0.1546113425 -1 0.1537117318 0.1537117318 0.00075 0.798497064 0.993693989 0.00099 0.729634976 0.729634976 0.05696 0.820224046 0.982252313 0.09375 0.806834524 0.992228985 0.00636 0.1790215654 0.1790215654 0.04835 0.37049092 0.4917621054
|
||||||
|
500 50.5 3.88258 0.1130383355 0.1130383355 -1 0.11118393534 0.11118393534 0.00095 0.817445534 0.996960022 0.00148 0.651972301 0.651972301 0.07627 0.816560225 0.989645322 0.11561 0.837453121 0.996154736 0.00715 0.1217933129 0.1217933129 0.06199 0.2348135641 0.3010364601
|
||||||
|
750 50.5 3.97107 0.09585699718 0.09585699718 -1 0.09463671116 0.09463671116 0.00106 0.829576522 0.996509465 0.00179 0.6203066621 0.6203066621 0.09122 0.820908322 0.988956119 0.13158 0.849250091 0.997066922 0.00765 0.10393120662 0.10393120662 0.06512 0.1950222554 0.2482127769
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
dim exact MC MCthrd
|
||||||
|
1 73.5012 2479.1000 5306.4984
|
||||||
|
2 59.9796 8664.1598 9779.8466
|
||||||
|
3 64.5056 12914.5464 16183.0764
|
||||||
|
4 53.2328 16508.2854 17870.7350
|
||||||
|
5 96.9536 24127.6058 21767.8086
|
||||||
|
6 63.4716 28597.2378 26124.3540
|
||||||
|
7 72.9282 40515.4708 29894.6794
|
||||||
|
8 104.7062 50231.4678 38400.2156
|
||||||
|
9 92.1626 63120.4664 43210.9808
|
||||||
|
10 137.0708 78229.7892 45129.4508
|
||||||
|
11 199.2572 90637.2042 52743.8126
|
||||||
|
12 416.4218 110416.5194 53946.7448
|
||||||
|
13 1300.8774 113405.5306 62789.0696
|
||||||
|
14 1613.6268 147592.3898 65697.4700
|
||||||
|
15 4308.1456 150118.8002 77611.7774
|
||||||
|
16 6218.6118 169137.1948 79600.3404
|
||||||
|
17 19130.1692 193395.2364 84645.9674
|
||||||
|
18 35013.7520 191694.0282 89612.3666
|
||||||
|
19 79783.0592 219430.7520 89025.0298
|
||||||
|
20 164029.9676 221552.7670 97877.6912
|
||||||
|
21 344260.9886 NaN NaN
|
||||||
|
22 727962.7158 NaN NaN
|
||||||
|
23 1575072.6418 NaN NaN
|
||||||
|
24 3408482.1812 NaN NaN
|
||||||
|
30 NaN 432889.7808 159195.6658
|
||||||
|
40 NaN 675539.2932 227844.4468
|
||||||
|
50 NaN 921443.9964 322103.8008
|
||||||
|
60 NaN 1249191.1684 365132.1606
|
||||||
|
70 NaN 1634723.4776 463334.2788
|
||||||
|
80 NaN 2020644.9422 517770.1918
|
||||||
|
90 NaN 2341498.2072 608483.0752
|
||||||
|
100 NaN 2912498.1526 720652.3034
|
||||||
|
110 NaN 3485288.1716 872467.7020
|
||||||
|
120 NaN 4134359.9114 985625.2118
|
||||||
|
130 NaN 4619163.4688 1064965.6310
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
rho order beta.version dist.subspace.gmlm dist.subspace.tsir dist.subspace.sir
|
||||||
|
0 2 1 0.05561496033 0.05061728586 0.0937440796
|
||||||
|
0.1 2 1 0.05695142139 0.0470966691 0.0926586359
|
||||||
|
0.2 2 1 0.06091107141 0.05239605337 0.0908641089
|
||||||
|
0.3 2 1 0.06307756487 0.05222743771 0.1017065255
|
||||||
|
0.4 2 1 0.0660642872 0.06165316957 0.101956927
|
||||||
|
0.5 2 1 0.0607144752 0.06296036226 0.1132399708
|
||||||
|
0.6 2 1 0.0680270013 0.07738945736 0.1338582214
|
||||||
|
0.7 2 1 0.08308930348 0.1022448411 0.1719732064
|
||||||
|
0.8 2 1 0.09393400477 0.1391586209 0.26938282
|
||||||
|
0 3 1 0.0417583747 0.04632848929 0.2461782677
|
||||||
|
0.1 3 1 0.0434276533 0.05218873186 0.2213153802
|
||||||
|
0.2 3 1 0.04597206968 0.0570669677 0.249537892
|
||||||
|
0.3 3 1 0.04502304399 0.0614214213 0.27102197
|
||||||
|
0.4 3 1 0.0473382351 0.0792024647 0.319204514
|
||||||
|
0.5 3 1 0.0566416444 0.1123603747 0.402911276
|
||||||
|
0.6 3 1 0.0635449054 0.1726250727 0.514738707
|
||||||
|
0.7 3 1 0.0790087287 0.3119839854 0.696578663
|
||||||
|
0.8 3 1 0.1049010818 0.625019551 0.901082248
|
||||||
|
0 4 1 0.0318544676 0.0756838662 0.763664746
|
||||||
|
0.1 4 1 0.0291616189 0.0732198203 0.724917238
|
||||||
|
0.2 4 1 0.0339593815 0.0892676958 0.727812142
|
||||||
|
0.3 4 1 0.033896654 0.1217472737 0.802679627
|
||||||
|
0.4 4 1 0.0421267215 0.1792376247 0.875051648
|
||||||
|
0.5 4 1 0.0497214363 0.2948337295 0.920789642
|
||||||
|
0.6 4 1 0.0649548512 0.516274211 0.961926272
|
||||||
|
0.7 4 1 0.0796107149 0.82163525 0.969975565
|
||||||
|
0.8 4 1 0.1319282631 0.952178592 0.969918026
|
||||||
|
0 2 2 0.04419882713 0.05007045448 0.0981350583
|
||||||
|
0.1 2 2 0.04771625635 0.06317718403 0.0994558305
|
||||||
|
0.2 2 2 0.05842513124 0.07257500657 0.1393608074
|
||||||
|
0.3 2 2 0.06074603789 0.0937379307 0.1469685419
|
||||||
|
0.4 2 2 0.06804359146 0.13265083527 0.1811307013
|
||||||
|
0.5 2 2 0.08431280029 0.15490492217 0.2644350099
|
||||||
|
0.6 2 2 0.0972256132 0.2322527248 0.3252648145
|
||||||
|
0.7 2 2 0.11758589026 0.3462559988 0.493115596
|
||||||
|
0.8 2 2 0.17756305 0.5735205103 0.688565503
|
||||||
|
0 3 2 0.03463284148 0.0894755449 0.413328838
|
||||||
|
0.1 3 2 0.04180307759 0.1390793707 0.497959545
|
||||||
|
0.2 3 2 0.0460221535 0.2086373171 0.63903354
|
||||||
|
0.3 3 2 0.0537508593 0.3124281256 0.748414045
|
||||||
|
0.4 3 2 0.060618005 0.495823454 0.890088075
|
||||||
|
0.5 3 2 0.0853542084 0.6712004401 0.956569545
|
||||||
|
0.6 3 2 0.0910737894 0.853848871 0.985803877
|
||||||
|
0.7 3 2 0.1435666309 0.965105066 0.995326644
|
||||||
|
0.8 3 2 0.1842180974 0.993108512 0.996942006
|
||||||
|
0 4 2 0.03189456039 0.260763794 0.958145847
|
||||||
|
0.1 4 2 0.03256901682 0.413864983 0.981053177
|
||||||
|
0.2 4 2 0.03944012707 0.635137383 0.99257835
|
||||||
|
0.3 4 2 0.0491580489 0.87045687 0.99829348
|
||||||
|
0.4 4 2 0.0633184796 0.961679634 0.999828802
|
||||||
|
0.5 4 2 0.0785727515 0.996049666 0.999905562
|
||||||
|
0.6 4 2 0.118468394 0.99986322 0.999134535
|
||||||
|
0.7 4 2 0.1952382107 0.999994091 0.9319280744
|
||||||
|
0.8 4 2 0.055013371 0.999999997 0.87224130919
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
% Authomatically generated by `dataAnalysis/chess.R`
|
||||||
|
|
||||||
|
\documentclass{standalone}
|
||||||
|
|
||||||
|
\usepackage[LSB, T1]{fontenc}
|
||||||
|
\usepackage{chessboard}
|
||||||
|
\usepackage{skak}
|
||||||
|
\usepackage{tikz}
|
||||||
|
\usepackage{amsmath}
|
||||||
|
\usepackage{xcolor}
|
||||||
|
|
||||||
|
\setboardfontencoding{LSB}
|
||||||
|
|
||||||
|
\setchessboard{linewidth = 0.1em, showmover = false, smallboard}
|
||||||
|
|
||||||
|
\definecolor{col1}{HTML}{5F1415} \definecolor{col2}{HTML}{641818} \definecolor{col3}{HTML}{681B1C} \definecolor{col4}{HTML}{6D1F1F} \definecolor{col5}{HTML}{722222} \definecolor{col6}{HTML}{772526} \definecolor{col7}{HTML}{7C2829} \definecolor{col8}{HTML}{812B2C} \definecolor{col9}{HTML}{862F2F} \definecolor{col10}{HTML}{8B3232} \definecolor{col11}{HTML}{903535} \definecolor{col12}{HTML}{953838} \definecolor{col13}{HTML}{9A3B3B} \definecolor{col14}{HTML}{9F3E3E} \definecolor{col15}{HTML}{A44141} \definecolor{col16}{HTML}{A94444} \definecolor{col17}{HTML}{AE4747} \definecolor{col18}{HTML}{B34A4A} \definecolor{col19}{HTML}{B74E4E} \definecolor{col20}{HTML}{BA5353} \definecolor{col21}{HTML}{BD5758} \definecolor{col22}{HTML}{C05C5C} \definecolor{col23}{HTML}{C26061} \definecolor{col24}{HTML}{C56565} \definecolor{col25}{HTML}{C7696A} \definecolor{col26}{HTML}{CA6E6E} \definecolor{col27}{HTML}{CD7272} \definecolor{col28}{HTML}{CF7677} \definecolor{col29}{HTML}{D17A7B} \definecolor{col30}{HTML}{D47F7F} \definecolor{col31}{HTML}{D68383} \definecolor{col32}{HTML}{D88787} \definecolor{col33}{HTML}{DA8B8C} \definecolor{col34}{HTML}{DC8F90} \definecolor{col35}{HTML}{DE9394} \definecolor{col36}{HTML}{E09898} \definecolor{col37}{HTML}{E29C9C} \definecolor{col38}{HTML}{E4A0A0} \definecolor{col39}{HTML}{E6A4A4} \definecolor{col40}{HTML}{E8A8A8} \definecolor{col41}{HTML}{E9ABAC} \definecolor{col42}{HTML}{EBAFAF} \definecolor{col43}{HTML}{EDB3B3} \definecolor{col44}{HTML}{EEB7B7} \definecolor{col45}{HTML}{EFBBBB} \definecolor{col46}{HTML}{F1BEBF} \definecolor{col47}{HTML}{F2C2C2} \definecolor{col48}{HTML}{F3C6C6} \definecolor{col49}{HTML}{F4C9C9} \definecolor{col50}{HTML}{F5CDCD} \definecolor{col51}{HTML}{F6D0D0} \definecolor{col52}{HTML}{F7D4D4} \definecolor{col53}{HTML}{F8D7D7} \definecolor{col54}{HTML}{F8DADA} \definecolor{col55}{HTML}{F9DEDE} \definecolor{col56}{HTML}{F9E1E1} \definecolor{col57}{HTML}{FAE4E4} \definecolor{col58}{HTML}{FAE7E7} \definecolor{col59}{HTML}{FAEAEA} \definecolor{col60}{HTML}{FAECEC} \definecolor{col61}{HTML}{F9EFEF} \definecolor{col62}{HTML}{F9F1F1} \definecolor{col63}{HTML}{F8F4F4} \definecolor{col64}{HTML}{F7F6F6} \definecolor{col65}{HTML}{F6F6F7} \definecolor{col66}{HTML}{F4F5F8} \definecolor{col67}{HTML}{F1F3F8} \definecolor{col68}{HTML}{EFF1F8} \definecolor{col69}{HTML}{ECEFF8} \definecolor{col70}{HTML}{EAEDF8} \definecolor{col71}{HTML}{E7EBF8} \definecolor{col72}{HTML}{E4E8F8} \definecolor{col73}{HTML}{E1E6F7} \definecolor{col74}{HTML}{DEE3F7} \definecolor{col75}{HTML}{DAE1F6} \definecolor{col76}{HTML}{D7DEF5} \definecolor{col77}{HTML}{D4DCF4} \definecolor{col78}{HTML}{D0D9F4} \definecolor{col79}{HTML}{CDD6F3} \definecolor{col80}{HTML}{C9D3F2} \definecolor{col81}{HTML}{C5D0F1} \definecolor{col82}{HTML}{C2CDEF} \definecolor{col83}{HTML}{BECAEE} \definecolor{col84}{HTML}{BAC7ED} \definecolor{col85}{HTML}{B6C4EC} \definecolor{col86}{HTML}{B2C1EA} \definecolor{col87}{HTML}{AEBEE9} \definecolor{col88}{HTML}{AABAE8} \definecolor{col89}{HTML}{A6B7E6} \definecolor{col90}{HTML}{A2B4E5} \definecolor{col91}{HTML}{9EB1E3} \definecolor{col92}{HTML}{9AADE2} \definecolor{col93}{HTML}{95AAE0} \definecolor{col94}{HTML}{91A6DF} \definecolor{col95}{HTML}{8DA3DD} \definecolor{col96}{HTML}{88A0DB} \definecolor{col97}{HTML}{849CDA} \definecolor{col98}{HTML}{7F99D8} \definecolor{col99}{HTML}{7A95D6} \definecolor{col100}{HTML}{7592D4} \definecolor{col101}{HTML}{708ED3} \definecolor{col102}{HTML}{6B8BD1} \definecolor{col103}{HTML}{6687CF} \definecolor{col104}{HTML}{6184CD} \definecolor{col105}{HTML}{5B80CC} \definecolor{col106}{HTML}{567DCA} \definecolor{col107}{HTML}{5079C8} \definecolor{col108}{HTML}{4975C7} \definecolor{col109}{HTML}{4372C5} \definecolor{col110}{HTML}{3B6EC3} \definecolor{col111}{HTML}{356BC1} \definecolor{col112}{HTML}{3167BC} \definecolor{col113}{HTML}{2E63B6} \definecolor{col114}{HTML}{2B60B1} \definecolor{col115}{HTML}{275CAC} \definecolor{col116}{HTML}{2459A7} \definecolor{col117}{HTML}{2055A2} \definecolor{col118}{HTML}{1C529D} \definecolor{col119}{HTML}{174E98} \definecolor{col120}{HTML}{124B93} \definecolor{col121}{HTML}{0C478E} \definecolor{col122}{HTML}{05448A} \definecolor{col123}{HTML}{004085} \definecolor{col124}{HTML}{003D80} \definecolor{col125}{HTML}{00397C} \definecolor{col126}{HTML}{003678} \definecolor{col127}{HTML}{003274} \definecolor{col128}{HTML}{002F70}
|
||||||
|
|
||||||
|
\begin{document}
|
||||||
|
\begin{tikzpicture}
|
||||||
|
|
||||||
|
\coordinate (pawn) at (0, 0);
|
||||||
|
\coordinate (knight) at (5, 0);
|
||||||
|
\coordinate (bishop) at (10, 0);
|
||||||
|
\coordinate (rook) at (0, -5.2);
|
||||||
|
\coordinate (queen) at (5, -5.2);
|
||||||
|
\coordinate (king) at (10, -5.2);
|
||||||
|
|
||||||
|
\node (pawn) at (pawn) {\chessboard[color=col64,colorbackfield=a8,color=col107,colorbackfield=a7,color=col76,colorbackfield=a6,color=col66,colorbackfield=a5,color=col67,colorbackfield=a4,color=col65,colorbackfield=a3,color=col66,colorbackfield=a2,color=col64,colorbackfield=a1,color=col64,colorbackfield=b8,color=col102,colorbackfield=b7,color=col75,colorbackfield=b6,color=col66,colorbackfield=b5,color=col66,colorbackfield=b4,color=col65,colorbackfield=b3,color=col66,colorbackfield=b2,color=col64,colorbackfield=b1,color=col64,colorbackfield=c8,color=col101,colorbackfield=c7,color=col75,colorbackfield=c6,color=col66,colorbackfield=c5,color=col66,colorbackfield=c4,color=col65,colorbackfield=c3,color=col66,colorbackfield=c2,color=col64,colorbackfield=c1,color=col64,colorbackfield=d8,color=col102,colorbackfield=d7,color=col74,colorbackfield=d6,color=col66,colorbackfield=d5,color=col67,colorbackfield=d4,color=col65,colorbackfield=d3,color=col66,colorbackfield=d2,color=col64,colorbackfield=d1,color=col64,colorbackfield=e8,color=col102,colorbackfield=e7,color=col74,colorbackfield=e6,color=col66,colorbackfield=e5,color=col67,colorbackfield=e4,color=col65,colorbackfield=e3,color=col66,colorbackfield=e2,color=col64,colorbackfield=e1,color=col64,colorbackfield=f8,color=col105,colorbackfield=f7,color=col76,colorbackfield=f6,color=col66,colorbackfield=f5,color=col67,colorbackfield=f4,color=col65,colorbackfield=f3,color=col66,colorbackfield=f2,color=col64,colorbackfield=f1,color=col64,colorbackfield=g8,color=col103,colorbackfield=g7,color=col74,colorbackfield=g6,color=col66,colorbackfield=g5,color=col67,colorbackfield=g4,color=col65,colorbackfield=g3,color=col66,colorbackfield=g2,color=col64,colorbackfield=g1,color=col64,colorbackfield=h8,color=col110,colorbackfield=h7,color=col76,colorbackfield=h6,color=col67,colorbackfield=h5,color=col67,colorbackfield=h4,color=col65,colorbackfield=h3,color=col67,colorbackfield=h2,color=col64,colorbackfield=h1]};
|
||||||
|
\node (knight) at (knight) {\chessboard[color=col88,colorbackfield=a8,color=col82,colorbackfield=a7,color=col78,colorbackfield=a6,color=col69,colorbackfield=a5,color=col70,colorbackfield=a4,color=col66,colorbackfield=a3,color=col70,colorbackfield=a2,color=col68,colorbackfield=a1,color=col105,colorbackfield=b8,color=col83,colorbackfield=b7,color=col75,colorbackfield=b6,color=col67,colorbackfield=b5,color=col68,colorbackfield=b4,color=col66,colorbackfield=b3,color=col68,colorbackfield=b2,color=col66,colorbackfield=b1,color=col88,colorbackfield=c8,color=col81,colorbackfield=c7,color=col77,colorbackfield=c6,color=col68,colorbackfield=c5,color=col69,colorbackfield=c4,color=col66,colorbackfield=c3,color=col69,colorbackfield=c2,color=col67,colorbackfield=c1,color=col104,colorbackfield=d8,color=col88,colorbackfield=d7,color=col79,colorbackfield=d6,color=col69,colorbackfield=d5,color=col70,colorbackfield=d4,color=col66,colorbackfield=d3,color=col70,colorbackfield=d2,color=col68,colorbackfield=d1,color=col88,colorbackfield=e8,color=col79,colorbackfield=e7,color=col74,colorbackfield=e6,color=col67,colorbackfield=e5,color=col68,colorbackfield=e4,color=col65,colorbackfield=e3,color=col68,colorbackfield=e2,color=col66,colorbackfield=e1,color=col92,colorbackfield=f8,color=col83,colorbackfield=f7,color=col78,colorbackfield=f6,color=col68,colorbackfield=f5,color=col70,colorbackfield=f4,color=col66,colorbackfield=f3,color=col70,colorbackfield=f2,color=col68,colorbackfield=f1,color=col125,colorbackfield=g8,color=col91,colorbackfield=g7,color=col78,colorbackfield=g6,color=col69,colorbackfield=g5,color=col70,colorbackfield=g4,color=col67,colorbackfield=g3,color=col70,colorbackfield=g2,color=col68,colorbackfield=g1,color=col93,colorbackfield=h8,color=col83,colorbackfield=h7,color=col78,colorbackfield=h6,color=col69,colorbackfield=h5,color=col70,colorbackfield=h4,color=col66,colorbackfield=h3,color=col70,colorbackfield=h2,color=col68,colorbackfield=h1]};
|
||||||
|
\node (bishop) at (bishop) {\chessboard[color=col73,colorbackfield=a8,color=col72,colorbackfield=a7,color=col71,colorbackfield=a6,color=col80,colorbackfield=a5,color=col70,colorbackfield=a4,color=col76,colorbackfield=a3,color=col67,colorbackfield=a2,color=col66,colorbackfield=a1,color=col70,colorbackfield=b8,color=col70,colorbackfield=b7,color=col69,colorbackfield=b6,color=col74,colorbackfield=b5,color=col69,colorbackfield=b4,color=col72,colorbackfield=b3,color=col66,colorbackfield=b2,color=col68,colorbackfield=b1,color=col71,colorbackfield=c8,color=col66,colorbackfield=c7,color=col66,colorbackfield=c6,color=col78,colorbackfield=c5,color=col66,colorbackfield=c4,color=col68,colorbackfield=c3,color=col66,colorbackfield=c2,color=col67,colorbackfield=c1,color=col68,colorbackfield=d8,color=col68,colorbackfield=d7,color=col67,colorbackfield=d6,color=col70,colorbackfield=d5,color=col67,colorbackfield=d4,color=col69,colorbackfield=d3,color=col66,colorbackfield=d2,color=col67,colorbackfield=d1,color=col68,colorbackfield=e8,color=col68,colorbackfield=e7,color=col67,colorbackfield=e6,color=col71,colorbackfield=e5,color=col67,colorbackfield=e4,color=col70,colorbackfield=e3,color=col66,colorbackfield=e2,color=col65,colorbackfield=e1,color=col73,colorbackfield=f8,color=col70,colorbackfield=f7,color=col68,colorbackfield=f6,color=col82,colorbackfield=f5,color=col66,colorbackfield=f4,color=col75,colorbackfield=f3,color=col67,colorbackfield=f2,color=col65,colorbackfield=f1,color=col71,colorbackfield=g8,color=col70,colorbackfield=g7,color=col69,colorbackfield=g6,color=col73,colorbackfield=g5,color=col70,colorbackfield=g4,color=col70,colorbackfield=g3,color=col66,colorbackfield=g2,color=col79,colorbackfield=g1,color=col73,colorbackfield=h8,color=col72,colorbackfield=h7,color=col70,colorbackfield=h6,color=col79,colorbackfield=h5,color=col70,colorbackfield=h4,color=col75,colorbackfield=h3,color=col67,colorbackfield=h2,color=col64,colorbackfield=h1]};
|
||||||
|
\node (rook) at (rook) {\chessboard[color=col104,colorbackfield=a8,color=col95,colorbackfield=a7,color=col105,colorbackfield=a6,color=col103,colorbackfield=a5,color=col97,colorbackfield=a4,color=col81,colorbackfield=a3,color=col84,colorbackfield=a2,color=col89,colorbackfield=a1,color=col95,colorbackfield=b8,color=col90,colorbackfield=b7,color=col97,colorbackfield=b6,color=col97,colorbackfield=b5,color=col92,colorbackfield=b4,color=col77,colorbackfield=b3,color=col81,colorbackfield=b2,color=col84,colorbackfield=b1,color=col105,colorbackfield=c8,color=col96,colorbackfield=c7,color=col106,colorbackfield=c6,color=col104,colorbackfield=c5,color=col98,colorbackfield=c4,color=col82,colorbackfield=c3,color=col85,colorbackfield=c2,color=col90,colorbackfield=c1,color=col89,colorbackfield=d8,color=col86,colorbackfield=d7,color=col91,colorbackfield=d6,color=col91,colorbackfield=d5,color=col87,colorbackfield=d4,color=col74,colorbackfield=d3,color=col79,colorbackfield=d2,color=col80,colorbackfield=d1,color=col93,colorbackfield=e8,color=col88,colorbackfield=e7,color=col95,colorbackfield=e6,color=col94,colorbackfield=e5,color=col90,colorbackfield=e4,color=col75,colorbackfield=e3,color=col80,colorbackfield=e2,color=col83,colorbackfield=e1,color=col100,colorbackfield=f8,color=col93,colorbackfield=f7,color=col102,colorbackfield=f6,color=col100,colorbackfield=f5,color=col95,colorbackfield=f4,color=col79,colorbackfield=f3,color=col83,colorbackfield=f2,color=col88,colorbackfield=f1,color=col97,colorbackfield=g8,color=col92,colorbackfield=g7,color=col99,colorbackfield=g6,color=col98,colorbackfield=g5,color=col93,colorbackfield=g4,color=col79,colorbackfield=g3,color=col83,colorbackfield=g2,color=col84,colorbackfield=g1,color=col106,colorbackfield=h8,color=col96,colorbackfield=h7,color=col108,colorbackfield=h6,color=col105,colorbackfield=h5,color=col99,colorbackfield=h4,color=col82,colorbackfield=h3,color=col85,colorbackfield=h2,color=col91,colorbackfield=h1]};
|
||||||
|
\node (queen) at (queen) {\chessboard[color=col126,colorbackfield=a8,color=col125,colorbackfield=a7,color=col124,colorbackfield=a6,color=col122,colorbackfield=a5,color=col117,colorbackfield=a4,color=col110,colorbackfield=a3,color=col108,colorbackfield=a2,color=col99,colorbackfield=a1,color=col125,colorbackfield=b8,color=col125,colorbackfield=b7,color=col124,colorbackfield=b6,color=col123,colorbackfield=b5,color=col118,colorbackfield=b4,color=col111,colorbackfield=b3,color=col110,colorbackfield=b2,color=col101,colorbackfield=b1,color=col124,colorbackfield=c8,color=col126,colorbackfield=c7,color=col126,colorbackfield=c6,color=col125,colorbackfield=c5,color=col122,colorbackfield=c4,color=col117,colorbackfield=c3,color=col115,colorbackfield=c2,color=col105,colorbackfield=c1,color=col109,colorbackfield=d8,color=col118,colorbackfield=d7,color=col120,colorbackfield=d6,color=col120,colorbackfield=d5,color=col122,colorbackfield=d4,color=col124,colorbackfield=d3,color=col122,colorbackfield=d2,color=col111,colorbackfield=d1,color=col124,colorbackfield=e8,color=col128,colorbackfield=e7,color=col128,colorbackfield=e6,color=col127,colorbackfield=e5,color=col126,colorbackfield=e4,color=col122,colorbackfield=e3,color=col120,colorbackfield=e2,color=col109,colorbackfield=e1,color=col125,colorbackfield=f8,color=col126,colorbackfield=f7,color=col126,colorbackfield=f6,color=col124,colorbackfield=f5,color=col120,colorbackfield=f4,color=col114,colorbackfield=f3,color=col113,colorbackfield=f2,color=col103,colorbackfield=f1,color=col127,colorbackfield=g8,color=col125,colorbackfield=g7,color=col124,colorbackfield=g6,color=col122,colorbackfield=g5,color=col115,colorbackfield=g4,color=col106,colorbackfield=g3,color=col105,colorbackfield=g2,color=col97,colorbackfield=g1,color=col127,colorbackfield=h8,color=col123,colorbackfield=h7,color=col122,colorbackfield=h6,color=col120,colorbackfield=h5,color=col112,colorbackfield=h4,color=col102,colorbackfield=h3,color=col101,colorbackfield=h2,color=col94,colorbackfield=h1]};
|
||||||
|
\node (king) at (king) {\chessboard[color=col55,colorbackfield=a8,color=col56,colorbackfield=a7,color=col56,colorbackfield=a6,color=col56,colorbackfield=a5,color=col56,colorbackfield=a4,color=col55,colorbackfield=a3,color=col58,colorbackfield=a2,color=col65,colorbackfield=a1,color=col57,colorbackfield=b8,color=col59,colorbackfield=b7,color=col59,colorbackfield=b6,color=col58,colorbackfield=b5,color=col58,colorbackfield=b4,color=col57,colorbackfield=b3,color=col60,colorbackfield=b2,color=col69,colorbackfield=b1,color=col58,colorbackfield=c8,color=col58,colorbackfield=c7,color=col58,colorbackfield=c6,color=col58,colorbackfield=c5,color=col58,colorbackfield=c4,color=col58,colorbackfield=c3,color=col60,colorbackfield=c2,color=col65,colorbackfield=c1,color=col57,colorbackfield=d8,color=col58,colorbackfield=d7,color=col58,colorbackfield=d6,color=col57,colorbackfield=d5,color=col57,colorbackfield=d4,color=col57,colorbackfield=d3,color=col59,colorbackfield=d2,color=col64,colorbackfield=d1,color=col57,colorbackfield=e8,color=col58,colorbackfield=e7,color=col58,colorbackfield=e6,color=col58,colorbackfield=e5,color=col58,colorbackfield=e4,color=col58,colorbackfield=e3,color=col60,colorbackfield=e2,color=col64,colorbackfield=e1,color=col62,colorbackfield=f8,color=col62,colorbackfield=f7,color=col62,colorbackfield=f6,color=col61,colorbackfield=f5,color=col61,colorbackfield=f4,color=col61,colorbackfield=f3,color=col63,colorbackfield=f2,color=col67,colorbackfield=f1,color=col61,colorbackfield=g8,color=col62,colorbackfield=g7,color=col62,colorbackfield=g6,color=col62,colorbackfield=g5,color=col62,colorbackfield=g4,color=col61,colorbackfield=g3,color=col63,colorbackfield=g2,color=col70,colorbackfield=g1,color=col60,colorbackfield=h8,color=col61,colorbackfield=h7,color=col61,colorbackfield=h6,color=col60,colorbackfield=h5,color=col60,colorbackfield=h4,color=col60,colorbackfield=h3,color=col61,colorbackfield=h2,color=col64,colorbackfield=h1]};
|
||||||
|
|
||||||
|
\node[anchor = north, yshift = -0.4em] at (pawn.north) {Pawn};
|
||||||
|
\node[anchor = north, yshift = -0.4em] at (knight.north) {Knight};
|
||||||
|
\node[anchor = north, yshift = -0.4em] at (bishop.north) {Bishop};
|
||||||
|
\node[anchor = north, yshift = -0.4em] at (rook.north) {Rook};
|
||||||
|
\node[anchor = north, yshift = -0.4em] at (queen.north) {Queen};
|
||||||
|
\node[anchor = north, yshift = -0.4em] at (king.north) {King};
|
||||||
|
|
||||||
|
\end{tikzpicture}
|
||||||
|
\end{document}
|
|
@ -0,0 +1,95 @@
|
||||||
|
%%% R code to generate the input data files from corresponding simulation logs
|
||||||
|
% R> setwd("~/Work/tensorPredictors")
|
||||||
|
% R>
|
||||||
|
% R> for (sim.name in c("2a")) {
|
||||||
|
% R> pattern <- paste0("sim\\_", sim.name, "\\_ising\\-[0-9T]*\\.csv")
|
||||||
|
% R> log.file <- sort(
|
||||||
|
% R> list.files(path = "sim/", pattern = pattern, full.names = TRUE),
|
||||||
|
% R> decreasing = TRUE
|
||||||
|
% R> )[[1]]
|
||||||
|
% R>
|
||||||
|
% R> sim <- read.csv(log.file)
|
||||||
|
% R>
|
||||||
|
% R> aggr <- aggregate(sim[, names(sim) != "sample.size"], list(sample.size = sim$sample.size), mean)
|
||||||
|
% R>
|
||||||
|
% R> write.table(aggr, file = paste0("LaTeX/plots/aggr-", sim.name, "-ising.csv"), row.names = FALSE, quote = FALSE)
|
||||||
|
% R> }
|
||||||
|
\documentclass[border=0cm]{standalone}
|
||||||
|
|
||||||
|
\usepackage{tikz}
|
||||||
|
\usepackage{pgfplots}
|
||||||
|
\usepackage{amssymb, bm}
|
||||||
|
|
||||||
|
\definecolor{exact}{RGB}{230,0,0}
|
||||||
|
\definecolor{MC}{RGB}{30,180,30}
|
||||||
|
\definecolor{MCthrd}{RGB}{0,0,230}
|
||||||
|
|
||||||
|
\pgfplotsset{
|
||||||
|
compat=newest,
|
||||||
|
grid=both,
|
||||||
|
grid style={gray!15}
|
||||||
|
}
|
||||||
|
\tikzset{
|
||||||
|
legend entry/.style={
|
||||||
|
mark = *,
|
||||||
|
mark size = 1pt,
|
||||||
|
mark indices = {2},
|
||||||
|
line width=0.8pt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
\begin{document}
|
||||||
|
\begin{tikzpicture}[>=latex]
|
||||||
|
|
||||||
|
\begin{axis}[
|
||||||
|
name=perft,
|
||||||
|
xmode = normal, ymode = log,
|
||||||
|
xtick={10, 30, 50, 70, 90, 110, 130},
|
||||||
|
mark = *,
|
||||||
|
mark size = 1pt,
|
||||||
|
ytick={1,1e1,1e2,1e3,1e4,1e5,1e6,1e7},
|
||||||
|
yticklabels={$1\mu s$,$10\mu s$,$100\mu s$,$1 ms$,$10 ms$,$100 ms$,$1 s$,$10 s$},
|
||||||
|
xlabel = {Dimension $p$},
|
||||||
|
ylabel = {Time}
|
||||||
|
]
|
||||||
|
\addplot[
|
||||||
|
only marks,
|
||||||
|
color = exact
|
||||||
|
] table[x = dim, y = exact] {aggr-ising-perft-m2.csv};
|
||||||
|
\addplot[
|
||||||
|
only marks,
|
||||||
|
color = MC
|
||||||
|
] table[x = dim, y = MC] {aggr-ising-perft-m2.csv};
|
||||||
|
\addplot[
|
||||||
|
only marks,
|
||||||
|
color = MCthrd
|
||||||
|
] table[x = dim, y = MCthrd] {aggr-ising-perft-m2.csv};
|
||||||
|
|
||||||
|
\addplot[smooth, domain = 8:25, color = exact, samples = 64] { 2^(x - 3) };
|
||||||
|
\addplot[smooth, domain = 1:130, color = MC, samples = 64] {
|
||||||
|
-18340.8 + 7557.0 * x + 200.3 * x^2
|
||||||
|
};
|
||||||
|
\addplot[smooth, domain = 1:130, color = MCthrd, samples = 64] {
|
||||||
|
8413.21 + 3134.98 * x + 41.87 * x^2
|
||||||
|
};
|
||||||
|
\end{axis}
|
||||||
|
|
||||||
|
\matrix[anchor = west] at (perft.east) {
|
||||||
|
\draw[color=exact, only marks, mark = *, mark size = 1pt, mark indices = {2}] plot coordinates {(0, 0) (.3, 0) (.6, 0)};
|
||||||
|
& \node[anchor=west] {exact}; \\
|
||||||
|
\draw[color=exact, line width = 0.8pt] plot coordinates {(0, 0) (.4, 0)};
|
||||||
|
& \node[anchor=west] {$\mathcal{O}(2^p)$}; \\
|
||||||
|
\draw[color=MC, only marks, mark = *, mark size = 1pt, mark indices = {2}] plot coordinates {(0, 0) (.3, 0) (.6, 0)};
|
||||||
|
& \node[anchor=west] {MC}; \\
|
||||||
|
\draw[color=MC, line width = 0.8pt] plot coordinates {(0, 0) (.4, 0)};
|
||||||
|
& \node[anchor=west] {$\mathcal{O}(p^2)$}; \\
|
||||||
|
\draw[color=MCthrd, only marks, mark = *, mark size = 1pt, mark indices = {2}] plot coordinates {(0, 0) (.3, 0) (.6, 0)};
|
||||||
|
& \node[anchor=west] {MC (8 thrd)}; \\
|
||||||
|
\draw[color=MCthrd, line width = 0.8pt] plot coordinates {(0, 0) (.4, 0)};
|
||||||
|
& \node[anchor=west] {$\mathcal{O}(p^2)$}; \\
|
||||||
|
};
|
||||||
|
|
||||||
|
\node[anchor = south] at (current bounding box.north) {Ising Second Moment Performance Test};
|
||||||
|
|
||||||
|
\end{tikzpicture}
|
||||||
|
\end{document}
|
|
@ -1,7 +1,7 @@
|
||||||
%%% R code to generate the input data files from corresponding simulation logs
|
%%% R code to generate the input data files from corresponding simulation logs
|
||||||
% R> setwd("~/Work/tensorPredictors")
|
% R> setwd("~/Work/tensorPredictors")
|
||||||
% R>
|
% R>
|
||||||
% R> for (sim.name in c("2a")) {
|
% R> for (sim.name in c("2a", "2b", "2c", "2d")) {
|
||||||
% R> pattern <- paste0("sim\\_", sim.name, "\\_ising\\-[0-9T]*\\.csv")
|
% R> pattern <- paste0("sim\\_", sim.name, "\\_ising\\-[0-9T]*\\.csv")
|
||||||
% R> log.file <- sort(
|
% R> log.file <- sort(
|
||||||
% R> list.files(path = "sim/", pattern = pattern, full.names = TRUE),
|
% R> list.files(path = "sim/", pattern = pattern, full.names = TRUE),
|
||||||
|
@ -66,7 +66,7 @@
|
||||||
\addplot[color = clpca] table[x = sample.size, y = dist.subspace.clpca] {aggr-2a-ising.csv};
|
\addplot[color = clpca] table[x = sample.size, y = dist.subspace.clpca] {aggr-2a-ising.csv};
|
||||||
\end{axis}
|
\end{axis}
|
||||||
\node[anchor = base west, yshift = 0.3em] at (sim-2a.north west) {
|
\node[anchor = base west, yshift = 0.3em] at (sim-2a.north west) {
|
||||||
a: small
|
a: linear dependence on $\mathcal{F}_y \equiv y$
|
||||||
};
|
};
|
||||||
\begin{axis}[
|
\begin{axis}[
|
||||||
name=sim-2b,
|
name=sim-2b,
|
||||||
|
@ -83,7 +83,7 @@
|
||||||
\addplot[color = clpca] table[x = sample.size, y = dist.subspace.clpca] {aggr-2b-ising.csv};
|
\addplot[color = clpca] table[x = sample.size, y = dist.subspace.clpca] {aggr-2b-ising.csv};
|
||||||
\end{axis}
|
\end{axis}
|
||||||
\node[anchor = base west, yshift = 0.3em] at (sim-2b.north west) {
|
\node[anchor = base west, yshift = 0.3em] at (sim-2b.north west) {
|
||||||
b:
|
b: quadratic dependence on $y$
|
||||||
};
|
};
|
||||||
\begin{axis}[
|
\begin{axis}[
|
||||||
name=sim-2c,
|
name=sim-2c,
|
||||||
|
@ -99,7 +99,7 @@
|
||||||
\addplot[color = clpca] table[x = sample.size, y = dist.subspace.clpca] {aggr-2c-ising.csv};
|
\addplot[color = clpca] table[x = sample.size, y = dist.subspace.clpca] {aggr-2c-ising.csv};
|
||||||
\end{axis}
|
\end{axis}
|
||||||
\node[anchor = base west, yshift = 0.3em] at (sim-2c.north west) {
|
\node[anchor = base west, yshift = 0.3em] at (sim-2c.north west) {
|
||||||
c:
|
c: rank 1 $\boldsymbol{\beta}$'s
|
||||||
};
|
};
|
||||||
|
|
||||||
\begin{axis}[
|
\begin{axis}[
|
||||||
|
@ -116,83 +116,21 @@
|
||||||
\addplot[color = clpca] table[x = sample.size, y = dist.subspace.clpca] {aggr-2d-ising.csv};
|
\addplot[color = clpca] table[x = sample.size, y = dist.subspace.clpca] {aggr-2d-ising.csv};
|
||||||
\end{axis}
|
\end{axis}
|
||||||
\node[anchor = base west, yshift = 0.3em] at (sim-2d.north west) {
|
\node[anchor = base west, yshift = 0.3em] at (sim-2d.north west) {
|
||||||
d:
|
d: interaction constraints via $\boldsymbol{\Omega}$'s
|
||||||
};
|
};
|
||||||
|
|
||||||
% \begin{axis}[
|
|
||||||
% name=sim-1b,
|
|
||||||
% anchor=north west, at={(sim-2a.right of north east)}, xshift=0.1cm,
|
|
||||||
% xticklabel=\empty,
|
|
||||||
% ylabel near ticks, yticklabel pos=right
|
|
||||||
% ]
|
|
||||||
% \addplot[color = pca] table[x = sample.size, y = dist.subspace.pca] {aggr-1b-normal.csv};
|
|
||||||
% \addplot[color = hopca] table[x = sample.size, y = dist.subspace.hopca] {aggr-1b-normal.csv};
|
|
||||||
% \addplot[color = tsir] table[x = sample.size, y = dist.subspace.tsir] {aggr-1b-normal.csv};
|
|
||||||
% \addplot[color = mgcca] table[x = sample.size, y = dist.subspace.mgcca] {aggr-1b-normal.csv};
|
|
||||||
% \addplot[color = gmlm, line width=1pt] table[x = sample.size, y = dist.subspace.gmlm] {aggr-1b-normal.csv};
|
|
||||||
% \end{axis}
|
|
||||||
% \node[anchor = base west, yshift = 0.3em] at (sim-1b.north west) {
|
|
||||||
% b: cubic dependence on $y$
|
|
||||||
% };
|
|
||||||
|
|
||||||
% \begin{axis}[
|
|
||||||
% name=sim-1c,
|
|
||||||
% anchor=north west, at={(sim-2a.below south west)}, yshift=-.8em,
|
|
||||||
% xticklabel=\empty
|
|
||||||
% ]
|
|
||||||
% \addplot[color = pca] table[x = sample.size, y = dist.subspace.pca] {aggr-1c-normal.csv};
|
|
||||||
% \addplot[color = hopca] table[x = sample.size, y = dist.subspace.hopca] {aggr-1c-normal.csv};
|
|
||||||
% \addplot[color = tsir] table[x = sample.size, y = dist.subspace.tsir] {aggr-1c-normal.csv};
|
|
||||||
% \addplot[color = mgcca] table[x = sample.size, y = dist.subspace.mgcca] {aggr-1c-normal.csv};
|
|
||||||
% \addplot[color = gmlm, line width=1pt] table[x = sample.size, y = dist.subspace.gmlm] {aggr-1c-normal.csv};
|
|
||||||
% \end{axis}
|
|
||||||
% \node[anchor = base west, yshift = 0.3em] at (sim-1c.north west) {
|
|
||||||
% c: rank $1$ $\boldsymbol{\beta}$'s
|
|
||||||
% };
|
|
||||||
|
|
||||||
% \begin{axis}[
|
|
||||||
% name=sim-1d,
|
|
||||||
% anchor=north west, at={(sim-1c.right of north east)}, xshift=0.1cm,
|
|
||||||
% ylabel near ticks, yticklabel pos=right
|
|
||||||
% ]
|
|
||||||
% \addplot[color = pca] table[x = sample.size, y = dist.subspace.pca] {aggr-1d-normal.csv};
|
|
||||||
% \addplot[color = hopca] table[x = sample.size, y = dist.subspace.hopca] {aggr-1d-normal.csv};
|
|
||||||
% \addplot[color = tsir] table[x = sample.size, y = dist.subspace.tsir] {aggr-1d-normal.csv};
|
|
||||||
% \addplot[color = mgcca] table[x = sample.size, y = dist.subspace.mgcca] {aggr-1d-normal.csv};
|
|
||||||
% \addplot[color = gmlm, line width=1pt] table[x = sample.size, y = dist.subspace.gmlm] {aggr-1d-normal.csv};
|
|
||||||
% \end{axis}
|
|
||||||
% \node[anchor = base west, yshift = 0.3em] at (sim-1d.north west) {
|
|
||||||
% d: tri-diagonal $\boldsymbol{\Omega}$'s
|
|
||||||
% };
|
|
||||||
|
|
||||||
% \begin{axis}[
|
|
||||||
% name=sim-1e,
|
|
||||||
% anchor=north west, at={(sim-1c.below south west)}, yshift=-.8em
|
|
||||||
% ]
|
|
||||||
% \addplot[color = pca] table[x = sample.size, y = dist.subspace.pca] {aggr-1e-normal.csv};
|
|
||||||
% \addplot[color = hopca] table[x = sample.size, y = dist.subspace.hopca] {aggr-1e-normal.csv};
|
|
||||||
% \addplot[color = tsir] table[x = sample.size, y = dist.subspace.tsir] {aggr-1e-normal.csv};
|
|
||||||
% \addplot[color = mgcca] table[x = sample.size, y = dist.subspace.mgcca] {aggr-1e-normal.csv};
|
|
||||||
% \addplot[color = gmlm, line width=1pt] table[x = sample.size, y = dist.subspace.gmlm] {aggr-1e-normal.csv};
|
|
||||||
% \end{axis}
|
|
||||||
% \node[anchor = base west, yshift = 0.3em] at (sim-1e.north west) {
|
|
||||||
% e: missspecified
|
|
||||||
% };
|
|
||||||
|
|
||||||
|
|
||||||
\matrix[anchor = west] at (sim-2a.right of east) {
|
|
||||||
\draw[color=gmlm, legend entry, line width=1pt] plot coordinates {(0, 0) (.3, 0) (.6, 0)}; & \node[anchor=west] {GMLM}; \\
|
|
||||||
\draw[color=tsir, legend entry] plot coordinates {(0, 0) (.3, 0) (.6, 0)}; & \node[anchor=west] {TSIR}; \\
|
|
||||||
\draw[color=mgcca, legend entry] plot coordinates {(0, 0) (.3, 0) (.6, 0)}; & \node[anchor=west] {MGCCA}; \\
|
|
||||||
\draw[color=hopca, legend entry] plot coordinates {(0, 0) (.3, 0) (.6, 0)}; & \node[anchor=west] {HOPCA}; \\
|
|
||||||
\draw[color=pca, legend entry] plot coordinates {(0, 0) (.3, 0) (.6, 0)}; & \node[anchor=west] {PCA}; \\
|
|
||||||
\draw[color=lpca, legend entry] plot coordinates {(0, 0) (.3, 0) (.6, 0)}; & \node[anchor=west] {LPCA}; \\
|
|
||||||
\draw[color=clpca, legend entry] plot coordinates {(0, 0) (.3, 0) (.6, 0)}; & \node[anchor=west] {CLPCA}; \\
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
\node[anchor = north] at (current bounding box.south) {Sample Size $n$};
|
\node[anchor = north] at (current bounding box.south) {Sample Size $n$};
|
||||||
|
|
||||||
|
\matrix[anchor = north] at (current bounding box.south) {
|
||||||
|
\draw[color=gmlm, legend entry, line width=1pt] plot coordinates {(0, 0) (.3, 0) (.6, 0)}; & \node[anchor=west] {GMLM}; &
|
||||||
|
\draw[color=tsir, legend entry] plot coordinates {(0, 0) (.3, 0) (.6, 0)}; & \node[anchor=west] {TSIR}; &
|
||||||
|
\draw[color=mgcca, legend entry] plot coordinates {(0, 0) (.3, 0) (.6, 0)}; & \node[anchor=west] {MGCCA}; &
|
||||||
|
\draw[color=hopca, legend entry] plot coordinates {(0, 0) (.3, 0) (.6, 0)}; & \node[anchor=west] {HOPCA}; &
|
||||||
|
\draw[color=pca, legend entry] plot coordinates {(0, 0) (.3, 0) (.6, 0)}; & \node[anchor=west] {PCA}; &
|
||||||
|
\draw[color=lpca, legend entry] plot coordinates {(0, 0) (.3, 0) (.6, 0)}; & \node[anchor=west] {LPCA}; &
|
||||||
|
\draw[color=clpca, legend entry] plot coordinates {(0, 0) (.3, 0) (.6, 0)}; & \node[anchor=west] {CLPCA}; \\
|
||||||
|
};
|
||||||
|
|
||||||
\node[anchor = south, rotate = 90] at (current bounding box.west) {Subspace Distance $d(\boldsymbol{B}, \hat{\boldsymbol{B}})$};
|
\node[anchor = south, rotate = 90] at (current bounding box.west) {Subspace Distance $d(\boldsymbol{B}, \hat{\boldsymbol{B}})$};
|
||||||
\node[anchor = south, rotate = 270] at (current bounding box.east) {\phantom{Subspace Distance $d(\boldsymbol{B}, \hat{\boldsymbol{B}})$}};
|
\node[anchor = south, rotate = 270] at (current bounding box.east) {\phantom{Subspace Distance $d(\boldsymbol{B}, \hat{\boldsymbol{B}})$}};
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,114 @@
|
||||||
|
%%% R code to generate the input data files from corresponding simulation logs
|
||||||
|
% R> setwd("~/Work/tensorPredictors")
|
||||||
|
% R>
|
||||||
|
% R> sim <- read.csv("sim/sim-tsir-20231123T1155.csv")
|
||||||
|
% R>
|
||||||
|
% R> aggr <- aggregate(sim[startsWith(names(sim), "dist")], sim[c("rho", "order", "beta.version")], mean)
|
||||||
|
% R>
|
||||||
|
% R> write.table(aggr, file = "LaTeX/plots/aggr-tsir.csv", row.names = FALSE, quote = FALSE)
|
||||||
|
\documentclass[border=0cm]{standalone}
|
||||||
|
|
||||||
|
\usepackage{tikz}
|
||||||
|
\usepackage{pgfplots}
|
||||||
|
\usepackage{bm}
|
||||||
|
|
||||||
|
\definecolor{gmlm}{RGB}{0,0,0}
|
||||||
|
% \definecolor{mgcca}{RGB}{86,180,233}
|
||||||
|
\definecolor{tsir}{RGB}{0,158,115}
|
||||||
|
\definecolor{sir}{RGB}{86,180,233}
|
||||||
|
% \definecolor{hopca}{RGB}{230,159,0}
|
||||||
|
% \definecolor{pca}{RGB}{240,228,66}
|
||||||
|
% \definecolor{lpca}{RGB}{0,114,178}
|
||||||
|
% \definecolor{clpca}{RGB}{213,94,0}
|
||||||
|
|
||||||
|
\pgfplotsset{
|
||||||
|
compat=newest,
|
||||||
|
every axis/.style={
|
||||||
|
% grid,
|
||||||
|
% grid style={gray, dotted},
|
||||||
|
3d box = complete*,
|
||||||
|
xtick = {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8},
|
||||||
|
ytick = {2, 3, 4},
|
||||||
|
ymin = 1.5, ymax = 4.5,
|
||||||
|
ztick = {0.2, 0.4, 0.6, 0.8, 1},
|
||||||
|
zmin = 0, zmax = 1,
|
||||||
|
xlabel = $\rho$,
|
||||||
|
ylabel = $r$
|
||||||
|
}
|
||||||
|
}
|
||||||
|
\tikzset{
|
||||||
|
legend entry/.style={
|
||||||
|
mark = *,
|
||||||
|
mark size = 1pt,
|
||||||
|
mark indices = {2},
|
||||||
|
line width=0.8pt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
\newcommand{\addarea}[4]{
|
||||||
|
\addplot3[
|
||||||
|
fill = #1,
|
||||||
|
draw = none,
|
||||||
|
fill opacity = 0.65,
|
||||||
|
y filter/.append expression={
|
||||||
|
\thisrow{order} == #3 && \thisrow{beta.version} == #2 ? #4 : NaN
|
||||||
|
}
|
||||||
|
] table [x = rho, y expr = NaN, z = dist.subspace.#1] {aggr-tsir.csv} \closedcycle;
|
||||||
|
\addplot3[
|
||||||
|
color = #1,
|
||||||
|
line width = 0.8pt,
|
||||||
|
y filter/.append expression={
|
||||||
|
\thisrow{order} == #3 && \thisrow{beta.version} == #2 ? #4 : NaN
|
||||||
|
}
|
||||||
|
] table [x = rho, y expr = NaN, z = dist.subspace.#1] {aggr-tsir.csv};
|
||||||
|
}
|
||||||
|
|
||||||
|
\begin{document}
|
||||||
|
\begin{tikzpicture}[>=latex]
|
||||||
|
|
||||||
|
\begin{axis}[
|
||||||
|
name=v1
|
||||||
|
]
|
||||||
|
\addarea{sir}{1}{4}{4.05};
|
||||||
|
\addarea{tsir}{1}{4}{4.0};
|
||||||
|
\addarea{gmlm}{1}{4}{3.95};
|
||||||
|
|
||||||
|
\addarea{sir}{1}{3}{3.05};
|
||||||
|
\addarea{tsir}{1}{3}{3.0};
|
||||||
|
\addarea{gmlm}{1}{3}{2.95};
|
||||||
|
|
||||||
|
\addarea{sir}{1}{2}{2.05};
|
||||||
|
\addarea{tsir}{1}{2}{2.0};
|
||||||
|
\addarea{gmlm}{1}{2}{1.95};
|
||||||
|
\end{axis}
|
||||||
|
\node[anchor = south] at (v1.north) {V1};
|
||||||
|
|
||||||
|
\begin{axis}[
|
||||||
|
name=v2,
|
||||||
|
xshift = 8cm
|
||||||
|
]
|
||||||
|
\addarea{sir}{2}{4}{4.05};
|
||||||
|
\addarea{tsir}{2}{4}{4.0};
|
||||||
|
\addarea{gmlm}{2}{4}{3.95};
|
||||||
|
|
||||||
|
\addarea{sir}{2}{3}{3.05};
|
||||||
|
\addarea{tsir}{2}{3}{3.0};
|
||||||
|
\addarea{gmlm}{2}{3}{2.95};
|
||||||
|
|
||||||
|
\addarea{sir}{2}{2}{2.05};
|
||||||
|
\addarea{tsir}{2}{2}{2.0};
|
||||||
|
\addarea{gmlm}{2}{2}{1.95};
|
||||||
|
\end{axis}
|
||||||
|
\node[anchor = south] at (v2.north) {V2};
|
||||||
|
|
||||||
|
\matrix[anchor = north] at (current bounding box.south) {
|
||||||
|
\fill[color = gmlm, opacity = 0.65] (0, -.1) rectangle (.6, .1); \draw[color=gmlm, line width = 0.8pt] (0, .1) -- (.6, .1); & \node[anchor=west] {GMLM}; &
|
||||||
|
\fill[color = tsir, opacity = 0.65] (0, -.1) rectangle (.6, .1); \draw[color=tsir, line width = 0.8pt] (0, .1) -- (.6, .1); & \node[anchor=west] {TSIR}; &
|
||||||
|
\fill[color = sir, opacity = 0.65] (0, -.1) rectangle (.6, .1); \draw[color=sir, line width = 0.8pt] (0, .1) -- (.6, .1); & \node[anchor=west] {SIR}; \\
|
||||||
|
};
|
||||||
|
|
||||||
|
\node[anchor = south, rotate = 90, yshift = 1.5em] at (v1.west) {Subspace Distance $d(\boldsymbol{B}, \widehat{\boldsymbol{B}})$};
|
||||||
|
\node[anchor = south, rotate = -90, yshift = 1.5em] at (v2.east) {\phantom{Subspace Distance $d(\boldsymbol{B}, \widehat{\boldsymbol{B}})$}};
|
||||||
|
|
||||||
|
\end{tikzpicture}
|
||||||
|
\end{document}
|
|
@ -0,0 +1,26 @@
|
||||||
|
# Source dependency setup
|
||||||
|
INCLUDE = Rchess/inst/include/SchachHoernchen
|
||||||
|
SRC = $(notdir $(wildcard $(INCLUDE)/*.cpp))
|
||||||
|
OBJ = $(SRC:.cpp=.o)
|
||||||
|
|
||||||
|
# Compiler config
|
||||||
|
CC = g++
|
||||||
|
FLAGS = -I$(INCLUDE) -Wall -Wextra -Wpedantic -pedantic -pthread -O3 -march=native -mtune=native
|
||||||
|
CPPFLAGS = $(FLAGS) -std=c++17
|
||||||
|
LDFLAGS = $(FLAGS)
|
||||||
|
|
||||||
|
.PHONY: all clean
|
||||||
|
|
||||||
|
%.o: $(INCLUDE)/%.cpp
|
||||||
|
$(CC) $(CPPFLAGS) -c $< -o $(notdir $@)
|
||||||
|
|
||||||
|
pgn2fen.o: pgn2fen.cpp
|
||||||
|
$(CC) $(CPPFLAGS) -c $< -o $@
|
||||||
|
|
||||||
|
pgn2fen: pgn2fen.o $(OBJ)
|
||||||
|
$(CC) $(LDFLAGS) -o $@ $^
|
||||||
|
|
||||||
|
all: pgn2fen
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm -f *.out *.o *.h.gch pgn2fen
|
|
@ -0,0 +1,15 @@
|
||||||
|
Package: Rchess
|
||||||
|
Type: Package
|
||||||
|
Title: Wrapper to the SchachHoernchen engine
|
||||||
|
Version: 1.0
|
||||||
|
Date: 2022-06-12
|
||||||
|
Author: loki
|
||||||
|
Maintainer: Your Name <your@email.com>
|
||||||
|
Description: Basic wrapper to the underlying C++ code of the SchachHoernchen
|
||||||
|
engine. Primarely intended to provide chess specific data processing.
|
||||||
|
Encoding: UTF-8
|
||||||
|
License: GPL (>= 2)
|
||||||
|
Imports: Rcpp (>= 1.0.8)
|
||||||
|
LinkingTo: Rcpp
|
||||||
|
SystemRequirements: c++17
|
||||||
|
RoxygenNote: 7.2.0
|
|
@ -0,0 +1,4 @@
|
||||||
|
useDynLib(Rchess, .registration=TRUE)
|
||||||
|
importFrom(Rcpp, evalCpp)
|
||||||
|
exportPattern("^[[:alpha:]]+")
|
||||||
|
S3method(print, board)
|
|
@ -0,0 +1,124 @@
|
||||||
|
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
|
||||||
|
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
|
||||||
|
|
||||||
|
#' Human Crafted Evaluation
|
||||||
|
HCE <- function(positions) {
|
||||||
|
.Call(`_Rchess_HCE`, positions)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Given a FEN (position) determines if its whites turn
|
||||||
|
isWhiteTurn <- function(positions) {
|
||||||
|
.Call(`_Rchess_isWhiteTurn`, positions)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Check if current side to move is in check
|
||||||
|
isCheck <- function(positions) {
|
||||||
|
.Call(`_Rchess_isCheck`, positions)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Check if the current position is a quiet position (no piece is attacked)
|
||||||
|
isQuiet <- function(positions) {
|
||||||
|
.Call(`_Rchess_isQuiet`, positions)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Check if position is terminal
|
||||||
|
#'
|
||||||
|
#' Checks if the position is a terminal position, meaning if the game ended
|
||||||
|
#' by mate, stale mate or the 50 modes rule. Three-Fold repetition is NOT
|
||||||
|
#' checked, therefore a seperate game history is required which the board
|
||||||
|
#' does NOT track.
|
||||||
|
#'
|
||||||
|
isTerminal <- function(positions) {
|
||||||
|
.Call(`_Rchess_isTerminal`, positions)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Check if checkmate is possible by material on the board
|
||||||
|
#'
|
||||||
|
#' Checks if there is sufficient mating material on the board, meaning if it
|
||||||
|
#' possible for any side to deliver a check mate. More specifically, it
|
||||||
|
#' checks if the pieces on the board are KK, KNK or KBK.
|
||||||
|
#'
|
||||||
|
isInsufficient <- function(positions) {
|
||||||
|
.Call(`_Rchess_isInsufficient`, positions)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Specialized version of `read_cyclic.cpp` taylored to work in conjunction with
|
||||||
|
#' `gmlm_chess()` as data generator to provide random draws from a FEN data set
|
||||||
|
#' with scores filtered to be in in the range `score_min` to `score_max`.
|
||||||
|
#'
|
||||||
|
data.gen <- function(file, sample_size, score_min = -5.0, score_max = +5.0, quiet = FALSE, min_ply_count = 10L, white_only = TRUE) {
|
||||||
|
.Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max, quiet, min_ply_count, white_only)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Human Crafted Evaluation
|
||||||
|
eval.psqt <- function(positions, psqt, pawn_structure = FALSE, eval_rooks = FALSE, eval_king = FALSE) {
|
||||||
|
.Call(`_Rchess_eval_psqt`, positions, psqt, pawn_structure, eval_rooks, eval_king)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Convert a legal FEN string to a 3D binary (integer with 0-1 entries) array
|
||||||
|
fen2int <- function(boards) {
|
||||||
|
.Call(`_Rchess_fen2int`, boards)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Reads lines from a text file with recycling.
|
||||||
|
#'
|
||||||
|
read.cyclic <- function(file, nrows = 1000L, skip = 100L, start = 1L, line_len = 64L) {
|
||||||
|
.Call(`_Rchess_read_cyclic`, file, nrows, skip, start, line_len)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Samples a legal move from a given position
|
||||||
|
sample.move <- function(pos) {
|
||||||
|
.Call(`_Rchess_sample_move`, pos)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Samples a random FEN (position) by applying `ply` random moves to the start
|
||||||
|
#' position.
|
||||||
|
#'
|
||||||
|
#' @param nr number of positions to sample
|
||||||
|
#' @param min_depth minimum number of random ply's to generate random positions
|
||||||
|
#' @param max_depth maximum number of random ply's to generate random positions
|
||||||
|
sample.fen <- function(nr, min_depth = 4L, max_depth = 20L) {
|
||||||
|
.Call(`_Rchess_sample_fen`, nr, min_depth, max_depth)
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Converts a FEN string to a Board (position) or return the current internal state
|
||||||
|
board <- function(fen = "") {
|
||||||
|
.Call(`_Rchess_board`, fen)
|
||||||
|
}
|
||||||
|
|
||||||
|
print.board <- function(fen = "") {
|
||||||
|
invisible(.Call(`_Rchess_print_board`, fen))
|
||||||
|
}
|
||||||
|
|
||||||
|
print.moves <- function(fen = "") {
|
||||||
|
invisible(.Call(`_Rchess_print_moves`, fen))
|
||||||
|
}
|
||||||
|
|
||||||
|
print.bitboards <- function(fen = "") {
|
||||||
|
invisible(.Call(`_Rchess_print_bitboards`, fen))
|
||||||
|
}
|
||||||
|
|
||||||
|
position <- function(pos, moves, san = FALSE) {
|
||||||
|
.Call(`_Rchess_position`, pos, moves, san)
|
||||||
|
}
|
||||||
|
|
||||||
|
perft <- function(depth = 6L) {
|
||||||
|
invisible(.Call(`_Rchess_perft`, depth))
|
||||||
|
}
|
||||||
|
|
||||||
|
go <- function(depth = 6L) {
|
||||||
|
.Call(`_Rchess_go`, depth)
|
||||||
|
}
|
||||||
|
|
||||||
|
ucinewgame <- function() {
|
||||||
|
invisible(.Call(`_Rchess_ucinewgame`))
|
||||||
|
}
|
||||||
|
|
||||||
|
.onLoad <- function(libname, pkgname) {
|
||||||
|
invisible(.Call(`_Rchess_onLoad`, libname, pkgname))
|
||||||
|
}
|
||||||
|
|
||||||
|
.onUnload <- function(libpath) {
|
||||||
|
invisible(.Call(`_Rchess_onUnload`, libpath))
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
#ifndef INCLUDE_GUARD_RCHESS_TYPES
|
||||||
|
#define INCLUDE_GUARD_RCHESS_TYPES
|
||||||
|
|
||||||
|
#include <RcppCommon.h>
|
||||||
|
|
||||||
|
#include "SchachHoernchen/Move.h"
|
||||||
|
#include "SchachHoernchen/Board.h"
|
||||||
|
#include "SchachHoernchen/uci.h"
|
||||||
|
|
||||||
|
namespace Rcpp {
|
||||||
|
template <> Move as(SEXP);
|
||||||
|
template <> SEXP wrap(const Move&);
|
||||||
|
|
||||||
|
template <> Board as(SEXP);
|
||||||
|
template <> SEXP wrap(const Board&);
|
||||||
|
} /* namespace Rcpp */
|
||||||
|
|
||||||
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
namespace Rcpp {
|
||||||
|
// Convert a coordinate encoded move string into a Move
|
||||||
|
template <>
|
||||||
|
Move as(SEXP obj) {
|
||||||
|
// parse (and validate) the move
|
||||||
|
bool parseError = false;
|
||||||
|
Move move = UCI::parseMove(Rcpp::as<std::string>(obj), parseError);
|
||||||
|
if (parseError) {
|
||||||
|
Rcpp::stop("Error parsing move");
|
||||||
|
}
|
||||||
|
return move;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert a Move into an `R` character
|
||||||
|
template <>
|
||||||
|
SEXP wrap(const Move& move) {
|
||||||
|
return Rcpp::CharacterVector::create(UCI::formatMove(move));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert a FEN string to a board
|
||||||
|
template <>
|
||||||
|
Board as(SEXP obj) {
|
||||||
|
bool parseError = false;
|
||||||
|
Board board;
|
||||||
|
std::string fen = Rcpp::as<std::string>(obj);
|
||||||
|
if (fen != "startpos") {
|
||||||
|
board.init(fen, parseError);
|
||||||
|
}
|
||||||
|
if (parseError) {
|
||||||
|
Rcpp::stop("Parsing FEN failed");
|
||||||
|
}
|
||||||
|
return board;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert board to `R` board class (as character FEN string)
|
||||||
|
template <>
|
||||||
|
SEXP wrap(const Board& board) {
|
||||||
|
auto obj = Rcpp::CharacterVector::create(board.fen());
|
||||||
|
obj.attr("class") = "board";
|
||||||
|
return obj;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert a character vector or list to a vector of Boards
|
||||||
|
template <>
|
||||||
|
std::vector<Board> as(SEXP obj) {
|
||||||
|
// Convert SEXP to be a vector of string
|
||||||
|
auto fens = Rcpp::as<std::vector<std::string>>(obj);
|
||||||
|
// Try to parse every string as a Board from a FEN
|
||||||
|
std::vector<Board> boards(fens.size());
|
||||||
|
for (int i = 0; i < fens.size(); ++i) {
|
||||||
|
bool parseError = false;
|
||||||
|
if (fens[i] != "startpos") {
|
||||||
|
boards[i].init(fens[i], parseError);
|
||||||
|
}
|
||||||
|
if (parseError) {
|
||||||
|
Rcpp::stop("Parsing FEN nr. %d failed", i + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return boards;
|
||||||
|
}
|
||||||
|
} /* namespace Rcpp */
|
||||||
|
|
||||||
|
#endif /* INCLUDE_GUARD_RCHESS_TYPES */
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,224 @@
|
||||||
|
#ifndef INCLUDE_GUARD_BOARD_H
|
||||||
|
#define INCLUDE_GUARD_BOARD_H
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <array>
|
||||||
|
#include <functional> // for std::hash
|
||||||
|
|
||||||
|
#include "types.h"
|
||||||
|
|
||||||
|
// Forward declarations
|
||||||
|
class Move;
|
||||||
|
class MoveList;
|
||||||
|
|
||||||
|
// pseudo-random-number-generator to fill Zobrist lookup hash table
|
||||||
|
static constexpr u64 rot(u64 val, int shift) {
|
||||||
|
return (val << shift) | (val >> (64 - shift));
|
||||||
|
}
|
||||||
|
static constexpr u64 random(u64& a, u64& b, u64& c, u64& d) {
|
||||||
|
u64 e = a - rot(b, 7);
|
||||||
|
a = b ^ rot(b, 13);
|
||||||
|
b = c + rot(d, 37);
|
||||||
|
c = d + e;
|
||||||
|
d = e + a;
|
||||||
|
return d;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* BitBoard indexing;
|
||||||
|
* @verbatim
|
||||||
|
* rank rankIndex
|
||||||
|
* | index |
|
||||||
|
* v +-------------------------+ v
|
||||||
|
* 8 | 0 1 2 3 4 5 6 7 | 0
|
||||||
|
* 7 | 8 9 10 11 12 13 14 15 | 1
|
||||||
|
* 6 | 16 17 18 19 20 21 22 23 | 2
|
||||||
|
* 5 | 24 25 26 27 28 29 30 31 | 3
|
||||||
|
* 4 | 32 33 34 35 36 37 38 39 | 4
|
||||||
|
* 3 | 40 41 42 43 44 45 46 47 | 5
|
||||||
|
* 2 | 48 49 50 51 52 53 54 55 | 6
|
||||||
|
* 1 | 56 57 58 59 60 61 62 63 | 7
|
||||||
|
* +-------------------------+
|
||||||
|
* a b c d e f g h <- file
|
||||||
|
* 0 1 2 3 4 5 6 7 <- fileIndex
|
||||||
|
* @endverbatim
|
||||||
|
*/
|
||||||
|
class Board {
|
||||||
|
public:
|
||||||
|
Board()
|
||||||
|
: _castle{true, true, true, true}
|
||||||
|
, _enPassant{static_cast<Index>(64)}
|
||||||
|
, _halfMoveClock{0}
|
||||||
|
, _plyCount{1}
|
||||||
|
, _bitBoard{
|
||||||
|
0xFFFF000000000000, // white
|
||||||
|
0x000000000000FFFF, // black
|
||||||
|
0x00FF00000000FF00, // pawns
|
||||||
|
0x4200000000000042, // knights
|
||||||
|
0x2400000000000024, // bishops
|
||||||
|
0x8100000000000081, // rooks
|
||||||
|
0x0800000000000008, // queens
|
||||||
|
0x1000000000000010, // kings
|
||||||
|
} {
|
||||||
|
_hash = rehash(); // This makes it non default constructible
|
||||||
|
}; // but the main performance requirement is in the
|
||||||
|
// copy constructor and assignment operator
|
||||||
|
|
||||||
|
// copy constructor
|
||||||
|
Board(const Board& board)
|
||||||
|
: _castle{board._castle}
|
||||||
|
, _enPassant{board._enPassant}
|
||||||
|
, _halfMoveClock{board._halfMoveClock}
|
||||||
|
, _plyCount{board._plyCount}
|
||||||
|
, _hash{board._hash}
|
||||||
|
{
|
||||||
|
for (int i = 0; i < 8; i++)
|
||||||
|
_bitBoard[i] = board._bitBoard[i];
|
||||||
|
};
|
||||||
|
// copy assignment operator
|
||||||
|
Board& operator=(const Board& board) = default;
|
||||||
|
|
||||||
|
// accessor functions for not directly given values
|
||||||
|
enum piece piece(const Index) const;
|
||||||
|
enum piece color(const Index) const;
|
||||||
|
// accessor functions for internal states (private, only altered by Board)
|
||||||
|
u64 bb(enum piece piece) const { return _bitBoard[piece]; }
|
||||||
|
bool castleRight(enum piece color, enum location side) const;
|
||||||
|
Index enPassant() const { return _enPassant; }
|
||||||
|
unsigned halfMoveClock() const { return _halfMoveClock; }
|
||||||
|
unsigned plyCount() const { return _plyCount; }
|
||||||
|
u64 knightMoves(Index from) const { return knightMoveLookup[from]; }
|
||||||
|
u64 kingMoves(Index from) const { return kingMoveLookup[from]; }
|
||||||
|
|
||||||
|
// Ply Count interpretation of who's move it is
|
||||||
|
bool isWhiteTurn() const { return static_cast<bool>(_plyCount % 2); }
|
||||||
|
enum piece sideToMove() const {
|
||||||
|
return static_cast<enum piece>(!(_plyCount % 2));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Formatting functions of pieces, squares, board, ...?
|
||||||
|
char formatPiece(Index fileIndex, Index rankIndex) const;
|
||||||
|
char formatPiece(Index squareIndex) const;
|
||||||
|
std::string format() const;
|
||||||
|
|
||||||
|
// Get board hash (Zobrist-hash), for complete computation see `rehash`
|
||||||
|
u64 hash() const { return _hash; };
|
||||||
|
// Get FEN string from board
|
||||||
|
std::string fen(const bool withCounts = true) const;
|
||||||
|
// Set board state given by FEN string
|
||||||
|
void init(const std::string& fen, bool& parseError);
|
||||||
|
// Complete (new) calculation of Zobrist hash for the current board position
|
||||||
|
u64 rehash() const;
|
||||||
|
|
||||||
|
// Validate a move (all moves from input checked before passed to `make`)
|
||||||
|
// and returns a legal (extended with all move info's) or non-move if case of
|
||||||
|
// illegal moves.
|
||||||
|
Move isLegal(const Move) const;
|
||||||
|
// Make a given move (assumed legal, otherwise undefined behavior)
|
||||||
|
void make(const Move); // passing move by value
|
||||||
|
|
||||||
|
// Generate all legal moves
|
||||||
|
void moves(MoveList&, bool = false) const;
|
||||||
|
void moves(MoveList&, const enum piece, bool = false) const;
|
||||||
|
|
||||||
|
// Check if current side to move is in check
|
||||||
|
bool isCheck() const;
|
||||||
|
// Check if the current position is a quiet position (no piece is attacked)
|
||||||
|
bool isQuiet() const;
|
||||||
|
// Checks if the position is a terminal position, meaning if the game ended
|
||||||
|
// by mate, stale mate or the 50 modes rule. Three-Fold repetition is NOT
|
||||||
|
// checked, therefore a seperate game history is required which the board
|
||||||
|
// does NOT track.
|
||||||
|
bool isTerminal() const;
|
||||||
|
// Checks if there is sufficient mating material on the board, meaning if it
|
||||||
|
// possible for any side to deliver a check mate. More specifically, it
|
||||||
|
// checks if the pieces on the board are KK, KNK or KBK.
|
||||||
|
bool isInsufficient() const;
|
||||||
|
// TODO: validate further possibilities like KBKB for two same color bishobs?!
|
||||||
|
|
||||||
|
// Board evaluation, gives a heuristic score of the current board position
|
||||||
|
// viewed from whites side
|
||||||
|
Score evalMaterial(enum piece) const;
|
||||||
|
Score evalPawns(enum piece) const;
|
||||||
|
Score evalKingSafety(enum piece) const;
|
||||||
|
Score evalRooks(enum piece) const;
|
||||||
|
Score evaluate() const;
|
||||||
|
// Static Exchange Evaluation
|
||||||
|
Score see(const Move) const;
|
||||||
|
|
||||||
|
// Subroutines for move generation, ...
|
||||||
|
u64 attacks(const enum piece, const u64) const;
|
||||||
|
u64 pinned(const enum piece, const u64, const u64) const;
|
||||||
|
u64 checkers(const enum piece, const u64, const u64) const;
|
||||||
|
|
||||||
|
// Equality operator used for comparing boards (exact equality in contrast
|
||||||
|
// to hash equality)
|
||||||
|
friend bool operator==(const Board& lhs, const Board& rhs) {
|
||||||
|
bool isEqual = true;
|
||||||
|
for (int i = 0; i < 8; i++) {
|
||||||
|
isEqual &= (lhs._bitBoard[i] == rhs._bitBoard[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return isEqual
|
||||||
|
&& (lhs._castle.whiteKingSide == rhs._castle.whiteKingSide)
|
||||||
|
&& (lhs._castle.whiteQueenSide == rhs._castle.whiteQueenSide)
|
||||||
|
&& (lhs._castle.blackKingSide == rhs._castle.blackKingSide)
|
||||||
|
&& (lhs._castle.blackQueenSide == rhs._castle.blackQueenSide)
|
||||||
|
&& (lhs._enPassant == rhs._enPassant);
|
||||||
|
}
|
||||||
|
friend bool operator!=(const Board& lhs, const Board& rhs) {
|
||||||
|
return !(lhs == rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct {
|
||||||
|
bool whiteKingSide : 1;
|
||||||
|
bool whiteQueenSide : 1;
|
||||||
|
bool blackKingSide : 1;
|
||||||
|
bool blackQueenSide : 1;
|
||||||
|
} _castle;
|
||||||
|
Index _enPassant;
|
||||||
|
unsigned _halfMoveClock;
|
||||||
|
unsigned _plyCount;
|
||||||
|
u64 _hash;
|
||||||
|
u64 _bitBoard[8];
|
||||||
|
|
||||||
|
// Lookup table for "random" hash value for piece-square, ... used in
|
||||||
|
// Zobrist board hashing.
|
||||||
|
static constexpr std::array<u64, 781> hashTable = []() {
|
||||||
|
std::array<u64, 781> table{};
|
||||||
|
|
||||||
|
u64 a = 0xABCD1729, b = 0, c = 0, d = 0;
|
||||||
|
for (int i = -11; i < 781; i++) {
|
||||||
|
u64 rand = random(a, b, c, d);
|
||||||
|
if (i >= 0) table[i] = rand;
|
||||||
|
}
|
||||||
|
return table;
|
||||||
|
}();
|
||||||
|
// Lookup helper for Zobrist hash lookup table
|
||||||
|
// hash index for 6 pieces of both colors for all squares
|
||||||
|
constexpr Index hashIndex(enum piece piece, enum piece color, Index sq) const {
|
||||||
|
return 12 * sq + (6 * (color == black)) + piece - 2;
|
||||||
|
}
|
||||||
|
// hash index the e.p. target square (file)
|
||||||
|
constexpr Index hashIndex(Index sq) const {
|
||||||
|
return 768 + (sq % 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
enum piece sqPiece(const u64) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Inject specialization hash<Board> into standard name space
|
||||||
|
namespace std {
|
||||||
|
template <>
|
||||||
|
struct hash<Board> {
|
||||||
|
using result_type = u64;
|
||||||
|
|
||||||
|
u64 operator()(const Board& board) const noexcept {
|
||||||
|
return board.hash();
|
||||||
|
};
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif /* INCLUDE_GUARD_BOARD_H */
|
|
@ -0,0 +1,150 @@
|
||||||
|
#ifndef UCI_GUARD_HASHTABLE_H
|
||||||
|
#define UCI_GUARD_HASHTABLE_H
|
||||||
|
|
||||||
|
#include "types.h"
|
||||||
|
#include <utility> // std::pair
|
||||||
|
#include <functional> // std::hash
|
||||||
|
#include <type_traits>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
// Default policy; replace everything with the newest entry
|
||||||
|
template <typename T>
|
||||||
|
struct policy {
|
||||||
|
constexpr bool operator()(const T&, const T&) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
typename Key,
|
||||||
|
typename Entry,
|
||||||
|
typename Policy = policy<Entry>,
|
||||||
|
typename Hashing = std::hash<Key>
|
||||||
|
>
|
||||||
|
class HashTable {
|
||||||
|
public:
|
||||||
|
using Hash = typename Hashing::result_type;
|
||||||
|
using Line = typename std::pair<Hash, Entry>;
|
||||||
|
|
||||||
|
HashTable() : _size{0}, _occupied{0}, _table{nullptr}, _hash{}, _policy{} { };
|
||||||
|
HashTable(Index size)
|
||||||
|
: _size{size}
|
||||||
|
, _occupied{0}
|
||||||
|
, _table{new (std::nothrow) Line[size]}
|
||||||
|
, _hash{}
|
||||||
|
, _policy{}
|
||||||
|
{
|
||||||
|
if (_table) {
|
||||||
|
for (Index i = 0; i < _size; i++) {
|
||||||
|
_table[i].first = static_cast<Hash>(0);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_size = 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
~HashTable() { if (_table) { delete[] _table; } }
|
||||||
|
|
||||||
|
Index size() const { return _size; }
|
||||||
|
Index used() const { return 1000 * _occupied / _size; }
|
||||||
|
|
||||||
|
bool reserve(Index size) noexcept {
|
||||||
|
_occupied = 0;
|
||||||
|
if (size != _size) {
|
||||||
|
if (_table) { delete[] _table; }
|
||||||
|
_table = new (std::nothrow) Line[size];
|
||||||
|
}
|
||||||
|
if (_table) {
|
||||||
|
_size = size;
|
||||||
|
for (Index i = 0; i < _size; i++) {
|
||||||
|
_table[i].first = static_cast<Hash>(0);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
_size = 0;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear() noexcept {
|
||||||
|
_occupied = 0;
|
||||||
|
for (Index i = 0; i < _size; i++) {
|
||||||
|
_table[i].first = static_cast<Hash>(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void erase() {
|
||||||
|
_occupied = 0;
|
||||||
|
_size = 0;
|
||||||
|
if (_table) { delete[] _table; }
|
||||||
|
_table = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void insert(const Key& key, const Entry& entry) {
|
||||||
|
assert(0 < _size);
|
||||||
|
|
||||||
|
Hash hash = _hash(key);
|
||||||
|
Index index = static_cast<Index>(hash) % _size;
|
||||||
|
if (_table[index].first == static_cast<Hash>(0)) {
|
||||||
|
_occupied++;
|
||||||
|
_table[index] = Line(hash, entry);
|
||||||
|
}
|
||||||
|
if (_policy(_table[index].second, entry)) {
|
||||||
|
_table[index] = Line(hash, entry);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Iter {
|
||||||
|
using iterator_category = std::forward_iterator_tag;
|
||||||
|
using difference_type = std::ptrdiff_t;
|
||||||
|
using value_type = Line;
|
||||||
|
using pointer = value_type*;
|
||||||
|
using reference = Entry&;
|
||||||
|
|
||||||
|
Iter(pointer ptr) : _ptr{ptr} { };
|
||||||
|
|
||||||
|
const Hash& hash() const { return _ptr->first; };
|
||||||
|
const Entry& value() const { return _ptr->second; };
|
||||||
|
|
||||||
|
// dereference operators
|
||||||
|
reference operator*() { return _ptr->second; };
|
||||||
|
pointer operator->() { return _ptr; };
|
||||||
|
|
||||||
|
// comparison operators
|
||||||
|
friend bool operator==(const Iter& lhs, const Iter& rhs) {
|
||||||
|
return lhs._ptr == rhs._ptr;
|
||||||
|
};
|
||||||
|
friend bool operator!=(const Iter& lhs, const Iter& rhs) {
|
||||||
|
return lhs._ptr != rhs._ptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
pointer _ptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
Iter begin() noexcept { return Iter(_table); }
|
||||||
|
Iter end() noexcept { return Iter(_table + _size); }
|
||||||
|
|
||||||
|
Iter find(const Key& key) noexcept {
|
||||||
|
if (_size) {
|
||||||
|
Hash hash = _hash(key);
|
||||||
|
if (hash == static_cast<Hash>(0)) {
|
||||||
|
return Iter(_table + _size);
|
||||||
|
}
|
||||||
|
Index index = static_cast<Index>(hash) % _size;
|
||||||
|
if (_table[index].first == hash) {
|
||||||
|
return Iter(_table + index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Iter(_table + _size);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Index _size;
|
||||||
|
Index _occupied;
|
||||||
|
Line* _table;
|
||||||
|
Hashing _hash;
|
||||||
|
Policy _policy;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif /* UCI_GUARD_HASHTABLE_H */
|
|
@ -0,0 +1,305 @@
|
||||||
|
#ifndef INCLUDE_GUARD_MOVE_H
|
||||||
|
#define INCLUDE_GUARD_MOVE_H
|
||||||
|
|
||||||
|
#include <iterator>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "utils.h"
|
||||||
|
#include "types.h"
|
||||||
|
|
||||||
|
class Move {
|
||||||
|
public:
|
||||||
|
using base_type = uint64_t;
|
||||||
|
|
||||||
|
private:
|
||||||
|
base_type _bits;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Move() = default; // required by MoveList
|
||||||
|
// Reverse of cast to base_type
|
||||||
|
Move(base_type bits) : _bits{bits} { };
|
||||||
|
|
||||||
|
explicit Move(Index from, Index to)
|
||||||
|
: Move(from, to, static_cast<enum piece>(0)) { };
|
||||||
|
explicit Move(Index from, Index to, enum piece promotion)
|
||||||
|
: Move(static_cast<enum piece>(0), static_cast<enum piece>(0), from, to,
|
||||||
|
static_cast<enum piece>(0), promotion) { };
|
||||||
|
explicit Move(enum piece color, enum piece piece, Index from, Index to,
|
||||||
|
enum piece victim)
|
||||||
|
: Move(color, piece, from, to, victim, static_cast<enum piece>(0)) { };
|
||||||
|
// General Move constructor
|
||||||
|
explicit Move(enum piece color, enum piece piece, Index from, Index to,
|
||||||
|
enum piece victim, enum piece promotion)
|
||||||
|
: _bits{static_cast<base_type>(
|
||||||
|
( color << 21)
|
||||||
|
| ( victim << 18)
|
||||||
|
| ( piece << 15)
|
||||||
|
| (promotion << 12)
|
||||||
|
| ( to << 6)
|
||||||
|
| from
|
||||||
|
)}
|
||||||
|
{
|
||||||
|
assert(((color == white)
|
||||||
|
|| (color == black)));
|
||||||
|
assert(((victim == knight)
|
||||||
|
|| (victim == bishop)
|
||||||
|
|| (victim == rook)
|
||||||
|
|| (victim == queen)
|
||||||
|
|| (victim == pawn)
|
||||||
|
|| !victim));
|
||||||
|
assert(((piece == knight)
|
||||||
|
|| (piece == bishop)
|
||||||
|
|| (piece == rook)
|
||||||
|
|| (piece == queen)
|
||||||
|
|| (piece == pawn)
|
||||||
|
|| (piece == king)
|
||||||
|
|| !piece));
|
||||||
|
assert(((promotion == knight)
|
||||||
|
|| (promotion == bishop)
|
||||||
|
|| (promotion == rook)
|
||||||
|
|| (promotion == queen)
|
||||||
|
|| !promotion));
|
||||||
|
assert(( to < 64));
|
||||||
|
assert((from < 64));
|
||||||
|
};
|
||||||
|
|
||||||
|
Index from() const { return _bits & 63U; }
|
||||||
|
Index to() const { return (_bits >> 6) & 63U; }
|
||||||
|
enum piece promote() const {
|
||||||
|
return static_cast<enum piece>((_bits >> 12) & 7U);
|
||||||
|
}
|
||||||
|
enum piece piece() const {
|
||||||
|
return static_cast<enum piece>((_bits >> 15) & 7U);
|
||||||
|
}
|
||||||
|
enum piece victim() const {
|
||||||
|
return static_cast<enum piece>((_bits >> 18) & 7U);
|
||||||
|
}
|
||||||
|
enum piece color() const {
|
||||||
|
return static_cast<enum piece>((_bits >> 21) & 7U);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t score() const {
|
||||||
|
return static_cast<uint32_t>(_bits >> 32);
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr uint32_t killerScore[2] = { 4096U, 4095U };
|
||||||
|
void setScore(const uint32_t val) {
|
||||||
|
constexpr base_type mask = (1ULL << 22) - 1ULL;
|
||||||
|
_bits = (_bits & mask) | (static_cast<base_type>(val) << 32);
|
||||||
|
}
|
||||||
|
uint32_t calcScore() const {
|
||||||
|
const uint32_t mvv_lva = ((victim() << 3) + (7U - piece()));
|
||||||
|
const bool winning = victim() > piece();
|
||||||
|
// add 128 to ensure the PST values are positive
|
||||||
|
const Index t = color() == white ? to() : 63 - to();
|
||||||
|
const Index f = color() == white ? from() : 63 - from();
|
||||||
|
const uint32_t pst = PSQT[piece()][t] - PSQT[piece()][f] + 128;
|
||||||
|
|
||||||
|
return ((static_cast<bool>(victim()) * mvv_lva) << (14 + winning * 6)) + pst;
|
||||||
|
}
|
||||||
|
// Allows to be cast to the base type (as number)
|
||||||
|
operator base_type() { return _bits; }
|
||||||
|
|
||||||
|
bool operator!() const { return this->from() == this->to(); }
|
||||||
|
// Comparison of <from><to>[<promotion>] part of the move. Everything else
|
||||||
|
// is redundent and allows simple parse move routine (use Board::isLegal
|
||||||
|
// for move legality test and augmentation)
|
||||||
|
bool operator==(const Move& rhs) const {
|
||||||
|
constexpr base_type mask = (1ULL << 22) - 1ULL;
|
||||||
|
return (_bits & mask) == (rhs._bits & mask);
|
||||||
|
}
|
||||||
|
bool operator!=(const Move& rhs) const {
|
||||||
|
constexpr base_type mask = (1ULL << 22) - 1ULL;
|
||||||
|
return (_bits & mask) != (rhs._bits & mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
// default sort order is decreasing (for scored moves the best move first)
|
||||||
|
friend bool operator<(const Move& lhs, const Move& rhs) {
|
||||||
|
return lhs._bits > rhs._bits;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class MoveList {
|
||||||
|
private:
|
||||||
|
static constexpr unsigned _max_size = 256;
|
||||||
|
unsigned _size;
|
||||||
|
Move _moves[_max_size]; // apparently, 218 are the max nr. of moves
|
||||||
|
|
||||||
|
public:
|
||||||
|
MoveList() : _size{0} { };
|
||||||
|
|
||||||
|
// Copy Constructor
|
||||||
|
MoveList(const MoveList& moveList) : _size{moveList._size} {
|
||||||
|
assert(_size <= _max_size);
|
||||||
|
|
||||||
|
std::copy(moveList._moves, moveList._moves + _size, _moves);
|
||||||
|
}
|
||||||
|
// Copy Assignement Operator
|
||||||
|
MoveList& operator=(const MoveList&) = default;
|
||||||
|
|
||||||
|
unsigned size() const { return _size; };
|
||||||
|
bool empty() const { return !static_cast<bool>(_size); };
|
||||||
|
static constexpr unsigned max_size() { return _max_size; };
|
||||||
|
void clear() { _size = 0; };
|
||||||
|
Move* data() { return _moves; };
|
||||||
|
const Move* data() const { return _moves; };
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
void emplace_back(Args&&... args) {
|
||||||
|
assert(_size < _max_size);
|
||||||
|
|
||||||
|
_moves[_size++] = Move(std::forward<Args>(args)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
void push_back(const Move& move) {
|
||||||
|
assert(_size < _max_size);
|
||||||
|
|
||||||
|
_moves[_size++] = move;
|
||||||
|
}
|
||||||
|
|
||||||
|
void append(const MoveList& list) {
|
||||||
|
assert((_size + list._size) < _max_size);
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < list._size; i++) {
|
||||||
|
_moves[_size++] = list._moves[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool contains(const Move& move) const {
|
||||||
|
for (unsigned i = 0; i < _size; i++) {
|
||||||
|
if (_moves[i] == move) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Searches for an occurence of `move` in the list and places the `move` as
|
||||||
|
* the first element in the list. If the provided `move` is not an element
|
||||||
|
* no opperation is performed and `false` will be returned.
|
||||||
|
*
|
||||||
|
* @param move a move to be searched for and placed as first element.
|
||||||
|
* @return `true` if `move` is found and placed as first element,
|
||||||
|
* `false` otherwise.
|
||||||
|
*/
|
||||||
|
bool move_front(const Move& move) {
|
||||||
|
for (unsigned i = 0; i < _size; i++) {
|
||||||
|
if (_moves[i] == move) {
|
||||||
|
std::swap(_moves[0], _moves[i]);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
Move& operator[](unsigned i) { return _moves[i]; }
|
||||||
|
const Move& operator[](unsigned i) const { return _moves[i]; }
|
||||||
|
|
||||||
|
struct MoveIter {
|
||||||
|
using iterator_category = std::forward_iterator_tag;
|
||||||
|
using difference_type = std::ptrdiff_t;
|
||||||
|
using value_type = Move;
|
||||||
|
using pointer = value_type*;
|
||||||
|
using reference = value_type&;
|
||||||
|
|
||||||
|
MoveIter(pointer ptr) : _ptr{ptr} { };
|
||||||
|
|
||||||
|
// Dereference operators
|
||||||
|
reference operator*() { return *_ptr; };
|
||||||
|
pointer operator->() { return _ptr; };
|
||||||
|
|
||||||
|
// Arithmetic operators
|
||||||
|
MoveIter& operator++() { _ptr++; return *this; };
|
||||||
|
MoveIter operator++(int) {
|
||||||
|
MoveIter copy = *this; ++(*this); return copy;
|
||||||
|
};
|
||||||
|
MoveIter& operator--() { _ptr--; return *this; };
|
||||||
|
MoveIter operator--(int) {
|
||||||
|
MoveIter copy = *this; --(*this); return copy;
|
||||||
|
};
|
||||||
|
MoveIter& operator+=(int i) {
|
||||||
|
_ptr += i; return *this;
|
||||||
|
};
|
||||||
|
friend MoveIter operator+(MoveIter it, int i) { return (it += i); };
|
||||||
|
MoveIter& operator-=(int i) {
|
||||||
|
_ptr -= i; return *this;
|
||||||
|
};
|
||||||
|
friend MoveIter operator-(MoveIter it, int i) { return (it -= i); };
|
||||||
|
|
||||||
|
friend difference_type operator-(const MoveIter& lhs, const MoveIter& rhs) {
|
||||||
|
return lhs._ptr - rhs._ptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Comparison operators
|
||||||
|
friend bool operator==(const MoveIter& lhs, const MoveIter& rhs) {
|
||||||
|
return lhs._ptr == rhs._ptr;
|
||||||
|
};
|
||||||
|
friend bool operator<=(const MoveIter& lhs, const MoveIter& rhs) {
|
||||||
|
return lhs._ptr <= rhs._ptr;
|
||||||
|
};
|
||||||
|
friend bool operator>=(const MoveIter& lhs, const MoveIter& rhs) {
|
||||||
|
return lhs._ptr >= rhs._ptr;
|
||||||
|
};
|
||||||
|
friend bool operator!=(const MoveIter& lhs, const MoveIter& rhs) {
|
||||||
|
return lhs._ptr != rhs._ptr;
|
||||||
|
};
|
||||||
|
friend bool operator<(const MoveIter& lhs, const MoveIter& rhs) {
|
||||||
|
return lhs._ptr < rhs._ptr;
|
||||||
|
};
|
||||||
|
friend bool operator>(const MoveIter& lhs, const MoveIter& rhs) {
|
||||||
|
return lhs._ptr > rhs._ptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
private:
|
||||||
|
pointer _ptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
MoveIter begin() { return MoveIter(_moves); };
|
||||||
|
MoveIter end() { return MoveIter(_moves + _size); };
|
||||||
|
|
||||||
|
struct ConstMoveIter {
|
||||||
|
using iterator_category = std::forward_iterator_tag;
|
||||||
|
using difference_type = std::ptrdiff_t;
|
||||||
|
using value_type = const Move;
|
||||||
|
using pointer = const value_type*;
|
||||||
|
using reference = const value_type&;
|
||||||
|
|
||||||
|
ConstMoveIter(pointer ptr) : _ptr{ptr} { };
|
||||||
|
|
||||||
|
// Dereference operators
|
||||||
|
reference operator*() { return *_ptr; };
|
||||||
|
pointer operator->() { return _ptr; };
|
||||||
|
|
||||||
|
// Arithmetic operators
|
||||||
|
ConstMoveIter& operator++() { _ptr++; return *this; };
|
||||||
|
ConstMoveIter operator++(int) {
|
||||||
|
ConstMoveIter copy = *this; ++(*this); return copy;
|
||||||
|
};
|
||||||
|
ConstMoveIter& operator--() { _ptr--; return *this; };
|
||||||
|
ConstMoveIter operator--(int) {
|
||||||
|
ConstMoveIter copy = *this; --(*this); return copy;
|
||||||
|
};
|
||||||
|
ConstMoveIter& operator+=(int i) {
|
||||||
|
_ptr += i; return *this;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Comparison operators
|
||||||
|
friend bool operator==(const ConstMoveIter& lhs, const ConstMoveIter& rhs) {
|
||||||
|
return lhs._ptr == rhs._ptr;
|
||||||
|
};
|
||||||
|
friend bool operator!=(const ConstMoveIter& lhs, const ConstMoveIter& rhs) {
|
||||||
|
return lhs._ptr != rhs._ptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
pointer _ptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
ConstMoveIter begin() const { return ConstMoveIter(_moves); };
|
||||||
|
ConstMoveIter end() const { return ConstMoveIter(_moves + _size); };
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif /* INCLUDE_GUARD_MOVE_H */
|
|
@ -0,0 +1,63 @@
|
||||||
|
#ifndef INCLUDE_GUARD_BIT_H
|
||||||
|
#define INCLUDE_GUARD_BIT_H
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
|
||||||
|
// Polyfill for `<bit>`, it compiled with proper standard this falls back to
|
||||||
|
// the STL implementation. Only provides functionality if not supported.
|
||||||
|
|
||||||
|
// TODO: extend to others than gcc (and maybe clang?)
|
||||||
|
// Inject polyfill for `std::endian` (since C++20)
|
||||||
|
namespace std {
|
||||||
|
enum class endian {
|
||||||
|
little = __ORDER_LITTLE_ENDIAN__,
|
||||||
|
big = __ORDER_BIG_ENDIAN__,
|
||||||
|
native = __BYTE_ORDER__
|
||||||
|
};
|
||||||
|
|
||||||
|
// Polyfill `std::byteswap` (since C++23)
|
||||||
|
template <typename T>
|
||||||
|
constexpr T byteswap(T x) noexcept {
|
||||||
|
#ifdef __GNUC__
|
||||||
|
if constexpr (sizeof(T) == 8) {
|
||||||
|
uint64_t swapped = __builtin_bswap64(*reinterpret_cast<uint64_t*>(&x));
|
||||||
|
return *reinterpret_cast<double*>(&swapped);
|
||||||
|
} else if constexpr (sizeof(T) == 4) {
|
||||||
|
uint32_t swapped = __builtin_bswap32(*reinterpret_cast<uint32_t*>(&x));
|
||||||
|
return *reinterpret_cast<double*>(&swapped);
|
||||||
|
} else if constexpr (sizeof(T) == 2) {
|
||||||
|
uint16_t swapped = __builtin_bswap16(*reinterpret_cast<uint16_t*>(&x));
|
||||||
|
return *reinterpret_cast<double*>(&swapped);
|
||||||
|
} else {
|
||||||
|
static_assert(sizeof(T) == 1);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
if constexpr (sizeof(T) == 8) {
|
||||||
|
uint64_t swapped = *reinterpret_cast<uint64_t*>(&x);
|
||||||
|
swapped = (((swapped & static_cast<uint64_t>(0xFF00FF00FF00FF00ULL)) >> 8)
|
||||||
|
| ((swapped & static_cast<uint64_t>(0x00FF00FF00FF00FFULL)) << 8));
|
||||||
|
swapped = (((swapped & static_cast<uint64_t>(0xFFFF0000FFFF0000ULL)) >> 16)
|
||||||
|
| ((swapped & static_cast<uint64_t>(0x0000FFFF0000FFFFULL)) << 16));
|
||||||
|
swapped = ((swapped >> 32) | (swapped << 32));
|
||||||
|
return *reinterpret_cast<T*>(&swapped);
|
||||||
|
} else if constexpr (sizeof(T) == 4) {
|
||||||
|
uint32_t swapped = *reinterpret_cast<uint32_t*>(&x);
|
||||||
|
swapped = (((swapped & static_cast<uint32_t>(0xFF00FF00)) >> 8)
|
||||||
|
| ((swapped & static_cast<uint32_t>(0x00FF00FF)) << 8));
|
||||||
|
swapped = ((swapped >> 16) | (swapped << 16));
|
||||||
|
return *reinterpret_cast<T*>(&swapped);
|
||||||
|
} else if constexpr (sizeof(T) == 2) {
|
||||||
|
uint16_t swapped = *reinterpret_cast<uint16_t*>(&x);
|
||||||
|
swapped = ((swapped << 8) | (swapped >> 8));
|
||||||
|
return *reinterpret_cast<T*>(&swapped);
|
||||||
|
} else {
|
||||||
|
static_assert(sizeof(T) == 1);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
} /* namespace std */
|
||||||
|
|
||||||
|
#endif /* INCLUDE_GUARD_BIT_H */
|
|
@ -0,0 +1,17 @@
|
||||||
|
#ifndef INCLUDE_GUARD_NULLSTREAM_H
|
||||||
|
#define INCLUDE_GUARD_NULLSTREAM_H
|
||||||
|
|
||||||
|
#include <ostream>
|
||||||
|
|
||||||
|
class nullstream : public std::ostream {
|
||||||
|
public:
|
||||||
|
nullstream() : std::ostream(nullptr) { }
|
||||||
|
nullstream(const nullstream&) : std::ostream(nullptr) { }
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
nullstream& operator<<(const T& rhs) { return *this; }
|
||||||
|
|
||||||
|
nullstream& operator<<(std::ostream& (*fn)(std::ostream&)) { return *this; }
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif /* INCLUDE_GUARD_NULLSTREAM_H */
|
|
@ -0,0 +1,700 @@
|
||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <thread>
|
||||||
|
#include <atomic>
|
||||||
|
#include <chrono>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "types.h"
|
||||||
|
#include "Board.h"
|
||||||
|
#include "Move.h"
|
||||||
|
#include "uci.h"
|
||||||
|
#include "syncstream.h"
|
||||||
|
#include "search.h"
|
||||||
|
#include "HashTable.h"
|
||||||
|
|
||||||
|
namespace Search {
|
||||||
|
|
||||||
|
// Node of a transposition table and its type
|
||||||
|
enum TTType { exact, upper_bound, lower_bound, unknown };
|
||||||
|
struct TTEntry {
|
||||||
|
enum TTType type;
|
||||||
|
unsigned age;
|
||||||
|
int depth;
|
||||||
|
Score score;
|
||||||
|
Move move;
|
||||||
|
};
|
||||||
|
// Transposition Table collision policy
|
||||||
|
struct TTPolicy {
|
||||||
|
bool operator()(const TTEntry& entry, const TTEntry& candidate) {
|
||||||
|
return (entry.age < candidate.age)
|
||||||
|
|| ((entry.type == candidate.type) && (entry.depth <= candidate.depth))
|
||||||
|
|| (candidate.type == exact);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Implementation specific global search state, workers and concurrency controls
|
||||||
|
State state;
|
||||||
|
HashTable<Board, TTEntry, TTPolicy> tt;
|
||||||
|
std::atomic<bool> isRunning;
|
||||||
|
std::vector<std::thread> workers;
|
||||||
|
|
||||||
|
// Set initial static age of search
|
||||||
|
template <typename Evaluator>
|
||||||
|
unsigned PVS<Evaluator>::_age = 0;
|
||||||
|
|
||||||
|
void init(unsigned ttSize) {
|
||||||
|
// Init internal state
|
||||||
|
state.depth = 16; // default max search depth
|
||||||
|
state.wtime = 24 * 60 * 60 * 1000; // 1 day in milliseconds
|
||||||
|
state.btime = 24 * 60 * 60 * 1000; // 1 day in milliseconds
|
||||||
|
state.winc = 0;
|
||||||
|
state.binc = 0;
|
||||||
|
state.movetime = 0;
|
||||||
|
|
||||||
|
// Init Transposition Table
|
||||||
|
constexpr unsigned ttLineSize = sizeof(HashTable<Board, TTEntry, TTPolicy>::Line);
|
||||||
|
// Reserve/resize tt to ttSize in MB of memory
|
||||||
|
if (tt.reserve(ttSize * 1024U * 1024U / ttLineSize)) {
|
||||||
|
// Report error of reserve failed
|
||||||
|
osyncstream(cout) << "info string error Unable to resize "
|
||||||
|
"Transposition Table to " << ttSize << "MB" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
double ttSize() {
|
||||||
|
// Init Transposition Table
|
||||||
|
constexpr unsigned ttLineSize = sizeof(HashTable<Board, TTEntry, TTPolicy>::Line);
|
||||||
|
// compute size in MB from number of lines (size)
|
||||||
|
return static_cast<double>(tt.size() * ttLineSize) / 1048576.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void newgame() {
|
||||||
|
// Reset internal state
|
||||||
|
state.depth = 16; // default max search depth
|
||||||
|
state.wtime = 24 * 60 * 60 * 1000; // 1 day in milliseconds
|
||||||
|
state.btime = 24 * 60 * 60 * 1000; // 1 day in milliseconds
|
||||||
|
state.winc = 0;
|
||||||
|
state.binc = 0;
|
||||||
|
state.movetime = 0;
|
||||||
|
|
||||||
|
// Delete all data from TT table
|
||||||
|
tt.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perft implementation
|
||||||
|
unsigned perft_subroutine(Board& board, int depth) {
|
||||||
|
if (depth <= 0) { return 1; }
|
||||||
|
|
||||||
|
Index nodeCount = 0;
|
||||||
|
MoveList searchmoves;
|
||||||
|
board.moves(searchmoves);
|
||||||
|
|
||||||
|
if (depth == 1) {
|
||||||
|
nodeCount = searchmoves.size();
|
||||||
|
} else {
|
||||||
|
Board boardCopy = board;
|
||||||
|
for (Move move : searchmoves) {
|
||||||
|
board.make(move);
|
||||||
|
nodeCount += perft_subroutine(board, depth - 1);
|
||||||
|
board = boardCopy; // unmake move
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nodeCount;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Prints for all moves in passed move list (assuming legal moves) number of
|
||||||
|
* `depth - 1` performance test node count of this child node count.
|
||||||
|
*
|
||||||
|
* @param board Reference root node board position.
|
||||||
|
* @param depth search depth.
|
||||||
|
* @param moveList List of legal moves at current board position to search.
|
||||||
|
* If moveList is empty, search all legal moves (Note: In this case, the
|
||||||
|
* moveList is filled with all legal moves manipulating the list in place).
|
||||||
|
*
|
||||||
|
* @example Output generated from start position for depth `5` with moveList
|
||||||
|
* containing the moves "d4", "Nf3" and "Nc3".
|
||||||
|
* position startpos
|
||||||
|
* go perft 5 searchmoves d2d4 g1f3 b1c3
|
||||||
|
* d2d4: 361790
|
||||||
|
* g1f3: 233491
|
||||||
|
* b1c3: 234656
|
||||||
|
*
|
||||||
|
* Nodes searched: 829937
|
||||||
|
*/
|
||||||
|
void perft(Board& board, int depth, MoveList& searchmoves) {
|
||||||
|
if (depth <= 0) {
|
||||||
|
osyncstream(cout) << "Nodes searched: 0" << std::endl;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!searchmoves.size()) {
|
||||||
|
board.moves(searchmoves);
|
||||||
|
}
|
||||||
|
|
||||||
|
Index totalCount = 0;
|
||||||
|
|
||||||
|
Board boardCopy = board;
|
||||||
|
for (Move move : searchmoves) {
|
||||||
|
// check if shutdown/stopping is requested
|
||||||
|
if (!isRunning.load(std::memory_order_relaxed)) { break; }
|
||||||
|
// continue traversing
|
||||||
|
board.make(move);
|
||||||
|
Index nodeCount;
|
||||||
|
nodeCount = perft_subroutine(board, depth - 1);
|
||||||
|
totalCount += nodeCount;
|
||||||
|
board = boardCopy; // unmake move
|
||||||
|
|
||||||
|
// report moves of node
|
||||||
|
osyncstream(cout) << move << ": " << nodeCount << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
osyncstream(cout)
|
||||||
|
<< std::endl << "Nodes searched: " << totalCount << std::endl;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Specialization of `perft(Board&, int, MoveList&)` for all legal moves.
|
||||||
|
*/
|
||||||
|
void perft(Board& board, int depth) {
|
||||||
|
MoveList moveList; // init empty move list -> perft for all legal moves
|
||||||
|
perft(board, depth, moveList);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Search initialization
|
||||||
|
*/
|
||||||
|
template <typename Evaluator>
|
||||||
|
PVS<Evaluator>::PVS(const std::vector<Board>& game, const State& config)
|
||||||
|
: _root(game.back())
|
||||||
|
// Restrict max search depth by the principal variation capacity
|
||||||
|
// Upper depth bound allows more than 200 plys which is enough.
|
||||||
|
, _max_depth{std::min(config.depth, static_cast<int>(MoveList::max_size()))}
|
||||||
|
// Init selective depth (max reached pvs + qsearch depth, max search ply)
|
||||||
|
, _seldepth{0}
|
||||||
|
// Set visited nodes to zero
|
||||||
|
, _nodes{0}
|
||||||
|
// break condition flag
|
||||||
|
, _isStopped{false}
|
||||||
|
// search start time for time dependend break condition and reporting
|
||||||
|
, _start_time{clock::now()}
|
||||||
|
, _searchmoves(config.searchmoves)
|
||||||
|
{
|
||||||
|
// Fill search moves with all legal moves if no moves are specified
|
||||||
|
if (_searchmoves.empty()) {
|
||||||
|
_root.moves(_searchmoves);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment age (for TT entries)
|
||||||
|
++_age;
|
||||||
|
|
||||||
|
// Time control; set time to stop searching (ignoring potental loss of 2 ms)
|
||||||
|
unsigned half_move_time;
|
||||||
|
if (config.movetime) {
|
||||||
|
half_move_time = config.movetime / 2U;
|
||||||
|
} else {
|
||||||
|
if (_root.isWhiteTurn()) {
|
||||||
|
half_move_time = std::max(config.wtime / 60U, 50U);
|
||||||
|
} else {
|
||||||
|
half_move_time = std::max(config.btime / 60U, 50U);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_mid_time = _start_time + milliseconds( half_move_time);
|
||||||
|
_end_time = _start_time + milliseconds(2U * half_move_time);
|
||||||
|
|
||||||
|
// fill repitition detection history
|
||||||
|
_historySize = _root.halfMoveClock() + 1 < game.size()
|
||||||
|
? _root.halfMoveClock() + 1
|
||||||
|
: game.size();
|
||||||
|
auto gameIter = game.rbegin();
|
||||||
|
for (size_t i = _historySize; i-- > 0; ) {
|
||||||
|
_history[i] = (*gameIter++).hash();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Principal Variation Search at the Root Position
|
||||||
|
*/
|
||||||
|
template <typename Evaluator>
|
||||||
|
Score PVS<Evaluator>::operator()() {
|
||||||
|
using std::chrono::duration_cast;
|
||||||
|
|
||||||
|
// Color keeps track of static evaluation sign for the current player
|
||||||
|
// in the negated maximization (negaMax) framework
|
||||||
|
Score color = _root.isWhiteTurn() ? +1 : -1;
|
||||||
|
|
||||||
|
// Tracks principal variation
|
||||||
|
MoveList pv_line;
|
||||||
|
|
||||||
|
// set killers all to !move (TODO: is this required?)
|
||||||
|
std::fill_n(_killers[0], MoveList::max_size(), Move(0));
|
||||||
|
std::fill_n(_killers[1], MoveList::max_size(), Move(0));
|
||||||
|
|
||||||
|
// set history heuristic to zero // TODO: can I replace that by zero-initialization?!
|
||||||
|
for (Index p = 0; p < 6; ++p) {
|
||||||
|
std::fill_n(_historyTable[0][p], 64, 0);
|
||||||
|
std::fill_n(_historyTable[1][p], 64, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterative deepening
|
||||||
|
Score prevScore = 0;
|
||||||
|
for (int depth = 1; depth <= _max_depth; depth++) {
|
||||||
|
// Clear temporary pv line
|
||||||
|
pv_line.clear();
|
||||||
|
|
||||||
|
// Aspiration window search
|
||||||
|
Score alpha, score, beta;
|
||||||
|
Score delta = 34;
|
||||||
|
// Set initial aspriation window bounds (shallow searches -> full width)
|
||||||
|
if (depth < 3) {
|
||||||
|
alpha = limits<Score>::lower();
|
||||||
|
beta = limits<Score>::upper();
|
||||||
|
} else {
|
||||||
|
alpha = prevScore - delta;
|
||||||
|
beta = prevScore + delta;
|
||||||
|
}
|
||||||
|
while (true) {
|
||||||
|
// search current depth and aspiration window
|
||||||
|
score = pvs(_root, depth, 0, alpha, beta, color, pv_line);
|
||||||
|
|
||||||
|
// Increase aspiration window
|
||||||
|
delta += 2 * delta + 5;
|
||||||
|
|
||||||
|
// Increase window (iff fail high/low)
|
||||||
|
if (score <= alpha) {
|
||||||
|
alpha = std::max(prevScore - delta, limits<Score>::lower() - 1);
|
||||||
|
} else if (beta <= score) {
|
||||||
|
beta = std::max(prevScore + delta, limits<Score>::upper() + 1);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Getting deeper, set previous iteration score in iterative deepening
|
||||||
|
prevScore = score;
|
||||||
|
|
||||||
|
// Check search break condition (before updating the pv and reporting
|
||||||
|
// partial results)
|
||||||
|
if (_isStopped
|
||||||
|
|| !isRunning.load(std::memory_order_relaxed)
|
||||||
|
|| (_end_time < clock::now())) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy principal variation
|
||||||
|
_pv.clear();
|
||||||
|
_pv.append(pv_line);
|
||||||
|
|
||||||
|
auto now = clock::now();
|
||||||
|
// Time spend in search till now (only for reporting)
|
||||||
|
auto duration = duration_cast<milliseconds>(now - _start_time);
|
||||||
|
// Check if found mate and report mate instead of centi pawn score
|
||||||
|
if (limits<Score>::upper() <= (std::abs(score) + depth)) {
|
||||||
|
// Note: +-1 for converting 0-indexed ply count to a 1-indexed
|
||||||
|
// move count
|
||||||
|
if (score < 0) {
|
||||||
|
score = limits<Score>::lower() - score - 1;
|
||||||
|
} else {
|
||||||
|
score = limits<Score>::upper() - score + 1;
|
||||||
|
}
|
||||||
|
// ply to move count
|
||||||
|
score /= 2;
|
||||||
|
// Report search stats and mate in score moves
|
||||||
|
osyncstream(cout)
|
||||||
|
<< "info depth " << depth
|
||||||
|
<< " seldepth " << _seldepth
|
||||||
|
<< " score mate " << score
|
||||||
|
<< " time " << duration.count()
|
||||||
|
<< " nodes " << _nodes
|
||||||
|
<< " pv " << _pv
|
||||||
|
<< std::endl;
|
||||||
|
// stop iterative deepening, found check mate -> no need to continue
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
// Report search progress
|
||||||
|
osyncstream(cout)
|
||||||
|
<< "info depth " << depth
|
||||||
|
<< " seldepth " << _seldepth
|
||||||
|
<< " score cp " << score
|
||||||
|
<< " time " << duration.count()
|
||||||
|
<< " nodes " << _nodes
|
||||||
|
<< " pv " << _pv
|
||||||
|
<< std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check current time againt search mid time, iff passed the mid point
|
||||||
|
// a deeper iteration in the iterative deepening will most likely not
|
||||||
|
// finish. Therefore, stop and save the time for later.
|
||||||
|
if (_mid_time <= now) { break; }
|
||||||
|
}
|
||||||
|
|
||||||
|
if (_pv.size()) {
|
||||||
|
auto out = osyncstream(cout);
|
||||||
|
out << "bestmove " << _pv[0];
|
||||||
|
if (1 < _pv.size()) {
|
||||||
|
out << " ponder " << _pv[1];
|
||||||
|
}
|
||||||
|
out << std::endl;
|
||||||
|
} else {
|
||||||
|
// If there is no move in _pv (can happen in extreme short time control)
|
||||||
|
// just make a random (first) move, better than failing.
|
||||||
|
_root.moves(_pv); // _pv is empty
|
||||||
|
if (_pv.size()) {
|
||||||
|
osyncstream(cout) << "bestmove " << _pv[0] << std::endl;
|
||||||
|
} else {
|
||||||
|
osyncstream(cout) << "bestmove !move" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Before leaving search, update search state (remaining time)
|
||||||
|
auto duration = duration_cast<milliseconds>(clock::now() - _start_time);
|
||||||
|
if (_root.isWhiteTurn()) {
|
||||||
|
state.wtime += state.winc - (duration.count() + 1);
|
||||||
|
} else {
|
||||||
|
state.btime += state.binc - (duration.count() + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset running flag
|
||||||
|
isRunning.store(false, std::memory_order_relaxed);
|
||||||
|
|
||||||
|
return (_root.isWhiteTurn() ? 1 : -1) * prevScore;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Evaluator>
|
||||||
|
Score PVS<Evaluator>::qSearch(const PVS<Evaluator>::Position& pos, int ply,
|
||||||
|
Score alpha, Score beta, Score color
|
||||||
|
) {
|
||||||
|
// Track selective search depth counter (reported while searching)
|
||||||
|
_seldepth = std::max(_seldepth, ply);
|
||||||
|
|
||||||
|
// Static evaluation
|
||||||
|
Score score = color * pos.eval();
|
||||||
|
|
||||||
|
bool capturesOnly = true;
|
||||||
|
// Check if in check, if in check generate all moves (check evations).
|
||||||
|
// Otherwise, check if not moving at all (stand pat) is already too good.
|
||||||
|
if (pos.isCheck()) {
|
||||||
|
capturesOnly = false;
|
||||||
|
} else if (beta <= score) {
|
||||||
|
++_nodes;
|
||||||
|
return beta;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate all captures (captures only, if not in check)
|
||||||
|
MoveList moveList;
|
||||||
|
pos.moves(moveList, capturesOnly);
|
||||||
|
|
||||||
|
// Check for terminal node
|
||||||
|
if (moveList.empty()) {
|
||||||
|
++_nodes;
|
||||||
|
// If there are no captures left, return static evaluation score
|
||||||
|
if (capturesOnly) {
|
||||||
|
return score;
|
||||||
|
// When generated all moves (in case of check evations), the fact that
|
||||||
|
// there are no moves means mate.
|
||||||
|
} else {
|
||||||
|
return color * (pos.isWhiteTurn() ? +1 : -1)
|
||||||
|
* (limits<Score>::lower() + ply);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure capture is better than stand pat
|
||||||
|
alpha = std::max(alpha, score);
|
||||||
|
|
||||||
|
// if (capturesOnly) {
|
||||||
|
// for (Move& move : moveList) {
|
||||||
|
// // add 65538 to ensure the set move score is positive
|
||||||
|
// move.setScore(pos.see(move) + 65538);
|
||||||
|
// }
|
||||||
|
// } else {
|
||||||
|
for (Move& move : moveList) {
|
||||||
|
// MVV-LVA capture move order
|
||||||
|
move.setScore(move.calcScore());
|
||||||
|
}
|
||||||
|
// }
|
||||||
|
std::sort(moveList.begin(), moveList.end());
|
||||||
|
|
||||||
|
// Recursively capture pieces
|
||||||
|
Position board = pos;
|
||||||
|
for (Move move : moveList) {
|
||||||
|
// SEE pruning (drop losing captures)
|
||||||
|
if (capturesOnly && (pos.see(move) < 0)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
board.make(move);
|
||||||
|
|
||||||
|
score = -qSearch(board, ply + 1, -beta, -alpha, -color);
|
||||||
|
|
||||||
|
if (beta <= score) {
|
||||||
|
return beta; // fail-hard beta-cutoff
|
||||||
|
}
|
||||||
|
alpha = std::max(alpha, score);
|
||||||
|
|
||||||
|
board = pos; // unmake
|
||||||
|
}
|
||||||
|
|
||||||
|
return alpha;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Principal Variation Search routine (all non-root) positions
|
||||||
|
*/
|
||||||
|
template <class Evaluator>
|
||||||
|
Score PVS<Evaluator>::pvs(const PVS<Evaluator>::Position& pos, int depth, int ply,
|
||||||
|
Score alpha, Score beta, Score color, MoveList& pv_line
|
||||||
|
) {
|
||||||
|
// Check for search shutdown and in case of a repetition return draw score
|
||||||
|
if (_isStopped || isRepetition(pos)) { return 0; }
|
||||||
|
// Check break condition every few thousend nodes
|
||||||
|
if (!(_nodes % (32 * 1024))) {
|
||||||
|
// Just return 0, partial search will be discarded
|
||||||
|
if (!isRunning.load(std::memory_order_relaxed)
|
||||||
|
|| (_end_time < clock::now())) {
|
||||||
|
_isStopped = true;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save original alpha bound before TT entry update
|
||||||
|
Score original_alpha = alpha;
|
||||||
|
|
||||||
|
// Move list containing all moves to be searched (usually all legal moves)
|
||||||
|
MoveList moveList;
|
||||||
|
if (!ply) {
|
||||||
|
moveList.append(_searchmoves);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transposition table lookup
|
||||||
|
auto tt_line = tt.find(pos);
|
||||||
|
if ((tt_line != tt.end()) && ((*tt_line).depth >= depth)) {
|
||||||
|
// Need to create all moves here to check if the TT move is legal
|
||||||
|
// (or restricted to searchmoves)
|
||||||
|
if (moveList.empty()) {
|
||||||
|
pos.moves(moveList);
|
||||||
|
}
|
||||||
|
// Only use the TT entry if TT move is legal
|
||||||
|
if (moveList.contains((*tt_line).move)) {
|
||||||
|
switch ((*tt_line).type) {
|
||||||
|
case lower_bound:
|
||||||
|
alpha = std::max(alpha, (*tt_line).score);
|
||||||
|
break;
|
||||||
|
case exact:
|
||||||
|
++_nodes;
|
||||||
|
pv_line.clear();
|
||||||
|
pv_line.push_back((*tt_line).move);
|
||||||
|
return (*tt_line).score;
|
||||||
|
case upper_bound:
|
||||||
|
beta = std::min(beta, (*tt_line).score);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
assert((false && "Unknown TT node type entry in TT lookup"));
|
||||||
|
}
|
||||||
|
// Check altered alpha/beta bounds
|
||||||
|
if (beta <= alpha) {
|
||||||
|
++_nodes;
|
||||||
|
return (*tt_line).score;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Reset TT entry to "not-found"
|
||||||
|
tt_line = tt.end();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Max search depth reached
|
||||||
|
if (depth <= 0) {
|
||||||
|
return qSearch(pos, ply, alpha, beta, color);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate all moves (if not already generated or constraint)
|
||||||
|
if (moveList.empty()) {
|
||||||
|
pos.moves(moveList);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check terminal node
|
||||||
|
if (moveList.empty()) {
|
||||||
|
// Increment node count
|
||||||
|
++_nodes;
|
||||||
|
// Evaluate terminal node score (check/stale mate)
|
||||||
|
if (pos.isCheck()) {
|
||||||
|
return color * (pos.isWhiteTurn() ? +1 : -1)
|
||||||
|
* (limits<Score>::lower() + ply);
|
||||||
|
} else {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort moves, iff hash move its the first, then captures in MVV-LVA
|
||||||
|
// followed by the rest
|
||||||
|
auto moveOrder = moveList.begin();
|
||||||
|
if (tt_line != tt.end()) {
|
||||||
|
if (moveList.move_front((*tt_line).move)) {
|
||||||
|
++moveOrder;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// score moves
|
||||||
|
for (Move& move : moveList) {
|
||||||
|
auto mScore = move.calcScore();
|
||||||
|
// Add killer move score -> killers after captures
|
||||||
|
mScore += (move == _killers[0][ply]) * Move::killerScore[0];
|
||||||
|
mScore += (move == _killers[1][ply]) * Move::killerScore[1];
|
||||||
|
// add history heuristic score -> sort of non-captures
|
||||||
|
mScore += static_cast<bool>(move.victim())
|
||||||
|
* _historyTable[move.color()][move.piece() - 2][move.to()];
|
||||||
|
move.setScore(mScore);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: best move selection instead of sort
|
||||||
|
std::sort(moveOrder, moveList.end(),
|
||||||
|
[](Move::base_type lhs, Move::base_type rhs) { return lhs > rhs; });
|
||||||
|
|
||||||
|
|
||||||
|
// New child pv line
|
||||||
|
MoveList cur_line;
|
||||||
|
|
||||||
|
// New TT entry (tracks "best" move)
|
||||||
|
TTEntry tt_entry{unknown, _age, depth, 0, Move{}};
|
||||||
|
|
||||||
|
// Current board
|
||||||
|
Position board = pos;
|
||||||
|
|
||||||
|
Index moveCount = 0;
|
||||||
|
Score score = limits<Score>::lower();
|
||||||
|
for (Move move : moveList) {
|
||||||
|
board.make(move);
|
||||||
|
pushHistory(board);
|
||||||
|
|
||||||
|
Score value;
|
||||||
|
if (!moveCount) {
|
||||||
|
value = -pvs(board, depth - 1, ply + 1,
|
||||||
|
-beta, -alpha, -color, cur_line);
|
||||||
|
} else {
|
||||||
|
// LMR (Late Move Reduction), reduce search depth of later moves
|
||||||
|
int R = 0; // TODO: disabled!!!
|
||||||
|
// int R = (moveCount > 3) & (ply > 2) & !move.victim() & !pos.isCheck();
|
||||||
|
|
||||||
|
// Zero-window search (with possible reduction)
|
||||||
|
value = -pvs(board, depth - 1 - R, ply + 1,
|
||||||
|
-(alpha + 1), -alpha, -color, cur_line);
|
||||||
|
if (alpha < value) {
|
||||||
|
// re-search (with clean current pv line and without reduction)
|
||||||
|
cur_line.clear();
|
||||||
|
value = -pvs(board, depth - 1, ply + 1,
|
||||||
|
-beta, -alpha, -color, cur_line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
++moveCount;
|
||||||
|
score = std::max(score, value);
|
||||||
|
|
||||||
|
// Check for beta cutoff
|
||||||
|
if (beta <= score) {
|
||||||
|
// Set possible lower bound hash hash move
|
||||||
|
tt_entry.move = move;
|
||||||
|
// Store killer move which is a quiet move causing a beta cut-off
|
||||||
|
if (!move.victim()) {
|
||||||
|
// If move isn't the first killer move we insert which ensures
|
||||||
|
// (assuming killers are different or !move) that two different
|
||||||
|
// killers are stored
|
||||||
|
if (_killers[0][ply] != move) {
|
||||||
|
_killers[1][ply] = _killers[0][ply];
|
||||||
|
_killers[0][ply] = move;
|
||||||
|
}
|
||||||
|
// Increment history heuristic
|
||||||
|
_historyTable[move.color()][move.piece() - 2][move.to()] += depth * depth;
|
||||||
|
}
|
||||||
|
popHistory();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// Check if alpha raised; iff add new pv/best move
|
||||||
|
if (alpha < score) {
|
||||||
|
alpha = score;
|
||||||
|
// Track pv (implicitly the best move)
|
||||||
|
pv_line.clear();
|
||||||
|
pv_line.push_back(move);
|
||||||
|
pv_line.append(cur_line);
|
||||||
|
// Track best move as hash move
|
||||||
|
tt_entry.move = move;
|
||||||
|
}
|
||||||
|
|
||||||
|
board = pos; // unmake move
|
||||||
|
popHistory();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add TT entry
|
||||||
|
tt_entry.score = score;
|
||||||
|
if (score <= original_alpha) {
|
||||||
|
tt_entry.type = upper_bound;
|
||||||
|
} else if (beta <= score) {
|
||||||
|
tt_entry.type = lower_bound;
|
||||||
|
} else {
|
||||||
|
tt_entry.type = exact;
|
||||||
|
}
|
||||||
|
tt.insert(pos, tt_entry);
|
||||||
|
|
||||||
|
return score;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Starts searching the given board position by dispatching a worker thread
|
||||||
|
void start(std::vector<Board>& game, State& config) {
|
||||||
|
if (isRunning.load(std::memory_order_consume)) {
|
||||||
|
// Ignore further search start attempts if allready searching, requires
|
||||||
|
// a stop first!
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
// Dispose of finished/stopped workers
|
||||||
|
for (std::thread& worker : workers) {
|
||||||
|
if (worker.joinable()) {
|
||||||
|
worker.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
workers.clear();
|
||||||
|
}
|
||||||
|
// sets worker thread stop condition to false (before dispatch)
|
||||||
|
isRunning.store(true, std::memory_order_release);
|
||||||
|
|
||||||
|
// Dispatch search worker
|
||||||
|
switch (config.search) {
|
||||||
|
case State::Type::perft:
|
||||||
|
state.search = State::Type::perft;
|
||||||
|
workers.emplace_back([&]() {
|
||||||
|
// Copy working variables, not subject to change
|
||||||
|
Board pos(game.back());
|
||||||
|
MoveList searchmoves = config.searchmoves;
|
||||||
|
// Start perft
|
||||||
|
perft(pos, config.depth, searchmoves);
|
||||||
|
// Reset running flag
|
||||||
|
isRunning.store(false, std::memory_order_relaxed);
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
case State::Type::search:
|
||||||
|
std::cout << "info string HCE evaluation" << std::endl;
|
||||||
|
workers.emplace_back(PVS<Board>(game, config));
|
||||||
|
break;
|
||||||
|
case State::Type::ponder:
|
||||||
|
std::cerr << "info string error pondering not implemented!" << std::endl;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
assert((false && "Search::start got request for unknown search type."));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void stop() {
|
||||||
|
// revoke isRunning flag -> workers stop as soon as possible
|
||||||
|
isRunning.store(false, std::memory_order_relaxed);
|
||||||
|
|
||||||
|
// then join and dispose all workers
|
||||||
|
for (std::thread& worker : workers) {
|
||||||
|
if (worker.joinable()) {
|
||||||
|
worker.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
workers.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Explicit instantiations of PVS types using different evaluators
|
||||||
|
template class PVS<Board>;
|
||||||
|
|
||||||
|
} /* namespace Search */
|
|
@ -0,0 +1,203 @@
|
||||||
|
#ifndef INCLUDE_GUARD_SEARCH_H
|
||||||
|
#define INCLUDE_GUARD_SEARCH_H
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <vector>
|
||||||
|
#include <chrono>
|
||||||
|
#include <atomic>
|
||||||
|
|
||||||
|
#include "types.h"
|
||||||
|
#include "Move.h"
|
||||||
|
#include "Board.h"
|
||||||
|
|
||||||
|
namespace Search {
|
||||||
|
|
||||||
|
struct State {
|
||||||
|
enum class Type { perft, search, ponder };
|
||||||
|
|
||||||
|
enum Type search;
|
||||||
|
int depth;
|
||||||
|
unsigned wtime;
|
||||||
|
unsigned btime;
|
||||||
|
unsigned winc;
|
||||||
|
unsigned binc;
|
||||||
|
unsigned movetime;
|
||||||
|
MoveList searchmoves;
|
||||||
|
|
||||||
|
State() = default;
|
||||||
|
State(const State& state)
|
||||||
|
: search{Type::search}
|
||||||
|
, depth{state.depth}
|
||||||
|
, wtime{state.wtime}
|
||||||
|
, btime{state.btime}
|
||||||
|
, winc{state.winc}
|
||||||
|
, binc{state.binc}
|
||||||
|
, movetime{state.movetime}
|
||||||
|
, searchmoves() { };
|
||||||
|
};
|
||||||
|
|
||||||
|
extern State state;
|
||||||
|
extern std::atomic<bool> isRunning;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialization of search specific (global) variables
|
||||||
|
*/
|
||||||
|
void init(unsigned ttSize = 32);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Diagnostic routine to get the allocated TT size in MB
|
||||||
|
*/
|
||||||
|
double ttSize();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Resets all search internal variables to initial configuration state
|
||||||
|
*/
|
||||||
|
void newgame();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Performance Test Subroutine, simple perft function without any I/O.
|
||||||
|
*
|
||||||
|
* Used in `perft.cpp` testing utility (therefore accessable).
|
||||||
|
*
|
||||||
|
* @param board Reference root node board position.
|
||||||
|
* @param depth search depth.
|
||||||
|
*
|
||||||
|
* @return number of moves/nodes from root node of search depth.
|
||||||
|
*/
|
||||||
|
unsigned perft_subroutine(Board& board, int depth);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Board wrapper in combination with evaluation to abstract the different
|
||||||
|
* position evaluation options (HCE Board evaluate and NNUE evaluate)
|
||||||
|
*
|
||||||
|
* @note The NNUE evaluation is removed from this code base!
|
||||||
|
*/
|
||||||
|
template <class CRTP>
|
||||||
|
class GameState : public Board {
|
||||||
|
public:
|
||||||
|
GameState(const Board& pos) : Board(pos) { };
|
||||||
|
|
||||||
|
Score eval() { return static_cast<CRTP*>(this)->eval(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specialized CRTP derived HCE Board evaluation position
|
||||||
|
*/
|
||||||
|
class BoardState : public GameState<BoardState> {
|
||||||
|
public:
|
||||||
|
BoardState(const Board& pos) : GameState<BoardState>(pos) { };
|
||||||
|
|
||||||
|
Score eval() const { return this->evaluate(); }
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Search routine class handling all search relevant data and executes the search
|
||||||
|
*
|
||||||
|
* This is a Principal Variabtion Search with alpha-beta pruning and a
|
||||||
|
* transposition table implemented in an iterative deepening framework.
|
||||||
|
*
|
||||||
|
* There are two Evaluation types, ether the HCE Board evaluation `PVS<Board>`
|
||||||
|
* or the NNUE evaluation `PVS<NNUE>`.
|
||||||
|
*
|
||||||
|
* @note The NNUE evaluation is removed from this code base!
|
||||||
|
*/
|
||||||
|
template <typename Evaluator>
|
||||||
|
class PVS {
|
||||||
|
public:
|
||||||
|
PVS(const std::vector<Board>&, const State&);
|
||||||
|
|
||||||
|
Score operator()();
|
||||||
|
|
||||||
|
Move bestMove() const { return _pv[0]; }
|
||||||
|
|
||||||
|
// TODO: there is a BUG collecting the PV!
|
||||||
|
const MoveList& pv() const { return _pv; }
|
||||||
|
|
||||||
|
// Analytic routine to get the number of searched leave nodes
|
||||||
|
unsigned nodes() const { return _nodes; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
using clock = std::chrono::high_resolution_clock;
|
||||||
|
using milliseconds = std::chrono::milliseconds;
|
||||||
|
using time_point = decltype(clock::now());
|
||||||
|
using Position = BoardState;
|
||||||
|
|
||||||
|
// Root search position (HCE Board state)
|
||||||
|
Position _root;
|
||||||
|
// Max search depth (excluding QSearch)
|
||||||
|
int _max_depth;
|
||||||
|
|
||||||
|
// Selective search counter (needs tracking)
|
||||||
|
int _seldepth;
|
||||||
|
// Number of nodes visited (leave nodes)
|
||||||
|
unsigned _nodes;
|
||||||
|
// Set to exit the search as soon as possible (breaks out of all recursions)
|
||||||
|
bool _isStopped;
|
||||||
|
// Start time of the search (for break condition check)
|
||||||
|
time_point _start_time;
|
||||||
|
// Middle time between _start_time and _end_time used to check if an
|
||||||
|
time_point _mid_time;
|
||||||
|
// Time till the search should stop (if now() is bigger, stop the search)
|
||||||
|
time_point _end_time;
|
||||||
|
// Moves to be searched, if empty all legal moves are searched
|
||||||
|
MoveList _searchmoves;
|
||||||
|
// Principal Variation
|
||||||
|
MoveList _pv;
|
||||||
|
// Killer Moves (initialized to contain !move entries)
|
||||||
|
Move _killers[2][MoveList::max_size()];
|
||||||
|
// History Heuristic Table (for move ordering)
|
||||||
|
Score _historyTable[2][6][64]; // [color][piece][to Square]
|
||||||
|
// Number of elements in `_history`
|
||||||
|
size_t _historySize;
|
||||||
|
// Game history list storing previous board hashes for repetition detection
|
||||||
|
u64 _history[100 + MoveList::max_size()];
|
||||||
|
// Increasing age for TT entry
|
||||||
|
static unsigned _age;
|
||||||
|
|
||||||
|
// adds a position to the history
|
||||||
|
void pushHistory(const Position& pos) {
|
||||||
|
_history[_historySize++] = pos.hash();
|
||||||
|
}
|
||||||
|
// and removes the last position from the history
|
||||||
|
void popHistory() { --_historySize; }
|
||||||
|
|
||||||
|
// two-fold repetition detection
|
||||||
|
bool isRepetition(const Position& pos) const {
|
||||||
|
u64 hash = pos.hash();
|
||||||
|
|
||||||
|
for (size_t i = _historySize - 1; i-- > 0; ) {
|
||||||
|
if (_history[i] == hash) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Principal Variation Search routine (for not root nodes)
|
||||||
|
Score pvs(const Position& pos, int depth, int ply,
|
||||||
|
Score alpha, Score beta, Score color, MoveList& pv_line);
|
||||||
|
// Quiesence Search
|
||||||
|
Score qSearch(const Position& pos, int ply,
|
||||||
|
Score alpha, Score beta, Score color);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Starts searching on the given board position.
|
||||||
|
*/
|
||||||
|
void start(std::vector<Board>& game, struct State& config);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stops all workers.
|
||||||
|
*
|
||||||
|
* Tells all workers to stop as soon as possible and closes working threads.
|
||||||
|
*
|
||||||
|
* @note: This function blocks
|
||||||
|
*/
|
||||||
|
void stop();
|
||||||
|
|
||||||
|
|
||||||
|
} /* namespace Search */
|
||||||
|
|
||||||
|
#endif /* INCLUDE_GUARD_SEARCH_H */
|
|
@ -0,0 +1,38 @@
|
||||||
|
#ifndef INCLUDE_GUARD_SYNCSTREAM_H
|
||||||
|
#define INCLUDE_GUARD_SYNCSTREAM_H
|
||||||
|
|
||||||
|
#include <ostream>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
|
class osyncstream {
|
||||||
|
public:
|
||||||
|
osyncstream(std::ostream& base_ostream)
|
||||||
|
: _base_ostream(base_ostream)
|
||||||
|
{
|
||||||
|
mtx().lock();
|
||||||
|
}
|
||||||
|
|
||||||
|
~osyncstream() {
|
||||||
|
mtx().unlock();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
osyncstream& operator<<(const T& rhs) {
|
||||||
|
_base_ostream << rhs;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
osyncstream& operator<<(std::ostream& (*fn)(std::ostream&)) {
|
||||||
|
_base_ostream << fn;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::mutex& mtx() {
|
||||||
|
static std::mutex _mtx;
|
||||||
|
return _mtx;
|
||||||
|
};
|
||||||
|
std::ostream& _base_ostream;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif /* INCLUDE_GUARD_SYNCSTREAM_H */
|
|
@ -0,0 +1,208 @@
|
||||||
|
#ifndef INCLUDE_GUARD_TYPES_H
|
||||||
|
#define INCLUDE_GUARD_TYPES_H
|
||||||
|
|
||||||
|
#include <cstdint> // uint64_t
|
||||||
|
#include <limits> // std::numeric_limits
|
||||||
|
|
||||||
|
/** square, file and rank index (index > 63 indicates illegal or off board) */
|
||||||
|
using Index = unsigned;
|
||||||
|
/** Bit board, exactly 64 bits (one bit per square) */
|
||||||
|
using u64 = uint64_t; // easy on the eyes (well, my eyes)
|
||||||
|
/**
|
||||||
|
* Board position score from white point of view in centipawns.
|
||||||
|
* (1 pawn ~ 100 centipawns)
|
||||||
|
*/
|
||||||
|
using Score = int;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct limits {
|
||||||
|
static constexpr T upper();
|
||||||
|
static constexpr T lower();
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct limits<Score> {
|
||||||
|
static constexpr Score upper() { return static_cast<Score>(+32768); };
|
||||||
|
static constexpr Score lower() { return static_cast<Score>(-32768); };
|
||||||
|
static constexpr bool isMate(const Score score) {
|
||||||
|
constexpr Score mateBound = upper() - 512;
|
||||||
|
return (score < -mateBound) || (mateBound < score);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
enum piece {
|
||||||
|
none = 0,
|
||||||
|
white = 0,
|
||||||
|
black = 1,
|
||||||
|
pawn = 2,
|
||||||
|
knight = 3,
|
||||||
|
bishop = 4,
|
||||||
|
rook = 5,
|
||||||
|
queen = 6,
|
||||||
|
king = 7,
|
||||||
|
kingEG = 8 // Lookup index for king end game PSQT
|
||||||
|
};
|
||||||
|
|
||||||
|
enum square : Index {
|
||||||
|
a8, b8, c8, d8, e8, f8, g8, h8,
|
||||||
|
a7, b7, c7, d7, e7, f7, g7, h7,
|
||||||
|
a6, b6, c6, d6, e6, f6, g6, h6,
|
||||||
|
a5, b5, c5, d5, e5, f5, g5, h5,
|
||||||
|
a4, b4, c4, d4, e4, f4, g4, h4,
|
||||||
|
a3, b3, c3, d3, e3, f3, g3, h3,
|
||||||
|
a2, b2, c2, d2, e2, f2, g2, h2,
|
||||||
|
a1, b1, c1, d1, e1, f1, g1, h1
|
||||||
|
};
|
||||||
|
|
||||||
|
enum location {
|
||||||
|
Square,
|
||||||
|
Up, Down, Left, Right,
|
||||||
|
QueenSide = Left, KingSide = Right,
|
||||||
|
File, Rank,
|
||||||
|
Diag, AntiDiag,
|
||||||
|
RightUp, RightDown,
|
||||||
|
LeftUp, LeftDown,
|
||||||
|
Plus, Cross, Star,
|
||||||
|
WhiteSquares, BlackSquares
|
||||||
|
};
|
||||||
|
|
||||||
|
// Material weighting per piece
|
||||||
|
constexpr Score pieceValues[8] = {
|
||||||
|
0, 0, // white, black (irrelevant)
|
||||||
|
100, // pawn
|
||||||
|
295, // knight
|
||||||
|
315, // bishop
|
||||||
|
450, // rook
|
||||||
|
870, // queen
|
||||||
|
0 // king (irrelevant, always 2 opposite kings)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Move lookup tables for knights and kings
|
||||||
|
constexpr u64 knightMoveLookup[64] = {
|
||||||
|
0x0000000000020400, 0x0000000000050800, 0x00000000000a1100, 0x0000000000142200,
|
||||||
|
0x0000000000284400, 0x0000000000508800, 0x0000000000a01000, 0x0000000000402000,
|
||||||
|
0x0000000002040004, 0x0000000005080008, 0x000000000a110011, 0x0000000014220022,
|
||||||
|
0x0000000028440044, 0x0000000050880088, 0x00000000a0100010, 0x0000000040200020,
|
||||||
|
0x0000000204000402, 0x0000000508000805, 0x0000000a1100110a, 0x0000001422002214,
|
||||||
|
0x0000002844004428, 0x0000005088008850, 0x000000a0100010a0, 0x0000004020002040,
|
||||||
|
0x0000020400040200, 0x0000050800080500, 0x00000a1100110a00, 0x0000142200221400,
|
||||||
|
0x0000284400442800, 0x0000508800885000, 0x0000a0100010a000, 0x0000402000204000,
|
||||||
|
0x0002040004020000, 0x0005080008050000, 0x000a1100110a0000, 0x0014220022140000,
|
||||||
|
0x0028440044280000, 0x0050880088500000, 0x00a0100010a00000, 0x0040200020400000,
|
||||||
|
0x0204000402000000, 0x0508000805000000, 0x0a1100110a000000, 0x1422002214000000,
|
||||||
|
0x2844004428000000, 0x5088008850000000, 0xa0100010a0000000, 0x4020002040000000,
|
||||||
|
0x0400040200000000, 0x0800080500000000, 0x1100110a00000000, 0x2200221400000000,
|
||||||
|
0x4400442800000000, 0x8800885000000000, 0x100010a000000000, 0x2000204000000000,
|
||||||
|
0x0004020000000000, 0x0008050000000000, 0x00110a0000000000, 0x0022140000000000,
|
||||||
|
0x0044280000000000, 0x0088500000000000, 0x0010a00000000000, 0x0020400000000000
|
||||||
|
};
|
||||||
|
constexpr u64 kingMoveLookup[64] = {
|
||||||
|
0x0000000000000302, 0x0000000000000705, 0x0000000000000E0A, 0x0000000000001C14,
|
||||||
|
0x0000000000003828, 0x0000000000007050, 0x000000000000E0A0, 0x000000000000C040,
|
||||||
|
0x0000000000030203, 0x0000000000070507, 0x00000000000E0A0E, 0x00000000001C141C,
|
||||||
|
0x0000000000382838, 0x0000000000705070, 0x0000000000E0A0E0, 0x0000000000C040C0,
|
||||||
|
0x0000000003020300, 0x0000000007050700, 0x000000000E0A0E00, 0x000000001C141C00,
|
||||||
|
0x0000000038283800, 0x0000000070507000, 0x00000000E0A0E000, 0x00000000C040C000,
|
||||||
|
0x0000000302030000, 0x0000000705070000, 0x0000000E0A0E0000, 0x0000001C141C0000,
|
||||||
|
0x0000003828380000, 0x0000007050700000, 0x000000E0A0E00000, 0x000000C040C00000,
|
||||||
|
0x0000030203000000, 0x0000070507000000, 0x00000E0A0E000000, 0x00001C141C000000,
|
||||||
|
0x0000382838000000, 0x0000705070000000, 0x0000E0A0E0000000, 0x0000C040C0000000,
|
||||||
|
0x0003020300000000, 0x0007050700000000, 0x000E0A0E00000000, 0x001C141C00000000,
|
||||||
|
0x0038283800000000, 0x0070507000000000, 0x00E0A0E000000000, 0x00C040C000000000,
|
||||||
|
0x0302030000000000, 0x0705070000000000, 0x0E0A0E0000000000, 0x1C141C0000000000,
|
||||||
|
0x3828380000000000, 0x7050700000000000, 0xE0A0E00000000000, 0xC040C00000000000,
|
||||||
|
0x0203000000000000, 0x0507000000000000, 0x0A0E000000000000, 0x141C000000000000,
|
||||||
|
0x2838000000000000, 0x5070000000000000, 0xA0E0000000000000, 0x40C0000000000000
|
||||||
|
};
|
||||||
|
|
||||||
|
// Declare I/O streams (allows to globaly replace the I/O streams)
|
||||||
|
#ifdef RCPP_RCOUT
|
||||||
|
#include <Rcpp.h>
|
||||||
|
// Set I/O streams to Rcpp I/O streams
|
||||||
|
static Rcpp::Rostream<true> cout;
|
||||||
|
static Rcpp::Rostream<false> cerr;
|
||||||
|
#elif NULLSTREAM
|
||||||
|
#include "nullstream.h"
|
||||||
|
// Set I/O streams to "null"
|
||||||
|
static nullstream cout;
|
||||||
|
static nullstream cerr;
|
||||||
|
#else
|
||||||
|
#include <iostream>
|
||||||
|
// Default STL I/O streams
|
||||||
|
using std::cout;
|
||||||
|
using std::cerr;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Piece SQuare Tables (partially automated tuned tables via supervised
|
||||||
|
// optimization using stockfish [https://stockfishchess.org/] evaluated positions
|
||||||
|
// from the lichess database [https://database.lichess.org/])
|
||||||
|
// endgame table: https://www.chessprogramming.org/Simplified_Evaluation_Function
|
||||||
|
// Which is addapted by adding 50. then scaled by 2 / 3 and rounded.
|
||||||
|
constexpr Score PSQT[9][64] = {
|
||||||
|
{ }, { }, // white, black (empty)
|
||||||
|
{ // pawn (white)
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
|
109, 82, 89, 25, 25, 89, 82, 109,
|
||||||
|
21, 18, -3, 18, 18, -3, 18, 21,
|
||||||
|
-12, -1, -19, 6, 6, -19, -1, -12,
|
||||||
|
-25, -15, -22, 9, 9, -22, -15, -25,
|
||||||
|
-25, -11, -27, -23, -23, -27, -11, -25,
|
||||||
|
-25, -13, -23, -29, -29, -23, -13, -25,
|
||||||
|
0, 0, 0, 0, 0, 0, 0, 0 },
|
||||||
|
{ // knight (white)
|
||||||
|
-90, -80, -18, 26, 26, -18, -80, -90,
|
||||||
|
-40, -13, 21, -22, -22, 21, -13, -40,
|
||||||
|
6, 2, 32, 38, 38, 32, 2, 6,
|
||||||
|
-9, -11, 22, 20, 20, 22, -11, -9,
|
||||||
|
-13, -11, 14, 2, 2, 14, -11, -13,
|
||||||
|
-25, -10, 2, 3, 3, 2, -10, -25,
|
||||||
|
-21, -54, -12, -8, -8, -12, -54, -21,
|
||||||
|
-76, -21, -38, -34, -34, -38, -21, -76 },
|
||||||
|
{ // bishop (white)
|
||||||
|
-7, 19, 3, -21, -21, 3, 19, -7,
|
||||||
|
-15, -5, 6, 40, 40, 6, -5, -15,
|
||||||
|
12, 14, 18, 32, 32, 18, 14, 12,
|
||||||
|
-5, -2, 17, 26, 26, 17, -2, -5,
|
||||||
|
-19, -2, 2, 8, 8, 2, -2, -19,
|
||||||
|
2, 4, 2, 8, 8, 2, 4, 2,
|
||||||
|
-4, 8, 3, 1, 1, 3, 8, -4,
|
||||||
|
-31, -13, -7, -20, -20, -7, -13, -31 },
|
||||||
|
{ // rook (white)
|
||||||
|
-5, -2, 23, 40, 40, 23, -2, -5,
|
||||||
|
18, 17, 42, 25, 25, 42, 17, 18,
|
||||||
|
22, 14, 33, 40, 40, 33, 14, 22,
|
||||||
|
21, 16, 20, 28, 28, 20, 16, 21,
|
||||||
|
-4, -13, -5, 3, 3, -5, -13, -4,
|
||||||
|
-20, -2, -3, -2, -2, -3, -2, -20,
|
||||||
|
-11, -13, 0, -6, -6, 0, -13, -11,
|
||||||
|
-17, -4, 0, 7, 7, 0, -4, -17 },
|
||||||
|
{ // queen (white)
|
||||||
|
-55, -29, 59, 19, 19, 59, -29, -55,
|
||||||
|
12, -18, 34, 85, 85, 34, -18, 12,
|
||||||
|
33, 17, 31, 34, 34, 31, 17, 33,
|
||||||
|
51, 16, 21, 18, 18, 21, 16, 51,
|
||||||
|
-3, 24, 18, 26, 26, 18, 24, -3,
|
||||||
|
11, 14, 24, 2, 2, 24, 14, 11,
|
||||||
|
28, 5, 17, 15, 15, 17, 5, 28,
|
||||||
|
1, -10, -14, 18, 18, -14, -10, 1 },
|
||||||
|
{ // king middle game (white)
|
||||||
|
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||||
|
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||||
|
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||||
|
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||||
|
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||||
|
-5, -5, -5, -5, -5, -5, -5, -5,
|
||||||
|
-4, -4, -4, -4, -4, -4, -4, -4,
|
||||||
|
24, 13, 3, -28, 2, -14, 15, 1 },
|
||||||
|
{ // king end game (white) // TODO: self/supervised tuning
|
||||||
|
0, 7, 13, 20, 20, 13, 7, 0,
|
||||||
|
13, 20, 27, 33, 33, 27, 20, 13,
|
||||||
|
13, 27, 47, 53, 53, 47, 27, 13,
|
||||||
|
13, 27, 53, 60, 60, 53, 27, 13,
|
||||||
|
13, 27, 53, 60, 60, 53, 27, 13,
|
||||||
|
13, 27, 47, 53, 53, 47, 27, 13,
|
||||||
|
13, 13, 33, 33, 33, 33, 13, 13,
|
||||||
|
0, 13, 13, 13, 13, 13, 13, 0 }
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif /* INCLUDE_GUARD_TYPES_H */
|
|
@ -0,0 +1,964 @@
|
||||||
|
#include <iostream>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <string>
|
||||||
|
#include <sstream>
|
||||||
|
#include <vector>
|
||||||
|
#include <tuple>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "types.h"
|
||||||
|
#include "Move.h"
|
||||||
|
#include "Board.h"
|
||||||
|
#include "uci.h"
|
||||||
|
#include "search.h"
|
||||||
|
#if __cplusplus > 201703L
|
||||||
|
#include <syncstream>
|
||||||
|
using osyncstream = std::osyncstream;
|
||||||
|
#else
|
||||||
|
#include "syncstream.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Inject into std namespace
|
||||||
|
namespace std {
|
||||||
|
ostream& operator<<(ostream& out, const Move& move) {
|
||||||
|
if (!move) {
|
||||||
|
out << "!move";
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (UCI::writeLAN) {
|
||||||
|
switch (move.piece()) {
|
||||||
|
case pawn: break;
|
||||||
|
case knight: out << 'N'; break;
|
||||||
|
case bishop: out << 'B'; break;
|
||||||
|
case rook: out << 'R'; break;
|
||||||
|
case queen: out << 'Q'; break;
|
||||||
|
case king: out << 'K'; break;
|
||||||
|
default: out << '?';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out << static_cast<char>('a' + (move.from() % 8))
|
||||||
|
<< static_cast<char>('8' - (move.from() / 8));
|
||||||
|
|
||||||
|
if (UCI::writeLAN) {
|
||||||
|
switch (move.victim()) {
|
||||||
|
case pawn: out << 'x'; break;
|
||||||
|
case knight: out << "xN"; break;
|
||||||
|
case bishop: out << "xB"; break;
|
||||||
|
case rook: out << "xR"; break;
|
||||||
|
case queen: out << "xQ"; break;
|
||||||
|
default: out << '-';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out << static_cast<char>('a' + (move.to() % 8))
|
||||||
|
<< static_cast<char>('8' - (move.to() / 8));
|
||||||
|
|
||||||
|
if (UCI::writeLAN) {
|
||||||
|
switch (move.promote()) {
|
||||||
|
case rook: out << 'R'; break;
|
||||||
|
case knight: out << 'N'; break;
|
||||||
|
case bishop: out << 'B'; break;
|
||||||
|
case queen: out << 'Q'; break;
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
switch (move.promote()) {
|
||||||
|
case rook: out << 'r'; break;
|
||||||
|
case knight: out << 'n'; break;
|
||||||
|
case bishop: out << 'b'; break;
|
||||||
|
case queen: out << 'q'; break;
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
ostream& operator<<(ostream& out, const MoveList& moveList) {
|
||||||
|
if (moveList.size()) {
|
||||||
|
out << moveList[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned i = 1; i < moveList.size(); i++) {
|
||||||
|
out << ' ' << moveList[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
namespace UCI {
|
||||||
|
|
||||||
|
// General UCI protocol configuration
|
||||||
|
bool readSAN = false; // if moves should be read in SAN
|
||||||
|
bool writeLAN = false; // moves written in Long Algebraic Notation
|
||||||
|
|
||||||
|
Index parseSquare(const std::string& str, bool& parseError) {
|
||||||
|
if (str.length() != 2) {
|
||||||
|
parseError = true;
|
||||||
|
return static_cast<Index>(64);
|
||||||
|
}
|
||||||
|
if (str[0] < 'a' || 'h' < str[0] || str[1] < '1' || '8' < str[1]) {
|
||||||
|
parseError = true;
|
||||||
|
return static_cast<Index>(64);
|
||||||
|
}
|
||||||
|
|
||||||
|
Index fileIndex = static_cast<Index>(str[0] - 'a');
|
||||||
|
Index rankIndex = static_cast<Index>('8' - str[1]);
|
||||||
|
|
||||||
|
return 8 * rankIndex + fileIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
Move parseMove(const std::string& str, bool& parseError) {
|
||||||
|
// check length <from><to> = 4 plus 1 if promotion
|
||||||
|
if ((str.length() < 4) || (5 < str.length())) {
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse <from><to> squares
|
||||||
|
Index fromSq = parseSquare(str.substr(0, 2), parseError);
|
||||||
|
if (parseError) { return Move(0); }
|
||||||
|
Index toSq = parseSquare(str.substr(2, 2), parseError);
|
||||||
|
if (parseError) { return Move(0); }
|
||||||
|
|
||||||
|
// handle promotions
|
||||||
|
if (str.length() == 5) {
|
||||||
|
switch (str[4]) {
|
||||||
|
case 'r': return Move(fromSq, toSq, rook);
|
||||||
|
case 'n': return Move(fromSq, toSq, knight);
|
||||||
|
case 'b': return Move(fromSq, toSq, bishop);
|
||||||
|
case 'q': return Move(fromSq, toSq, queen);
|
||||||
|
default: parseError = true; break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Move(fromSq, toSq);
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string formatMove(const Move move) {
|
||||||
|
// redirect parsing to `<<` overload
|
||||||
|
std::ostringstream out;
|
||||||
|
out << move;
|
||||||
|
// return as a string
|
||||||
|
return out.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
// <SAN piece moves> ::= <Piece symbol>[<from file>|<from rank>|<from square>]['x']<to square>
|
||||||
|
// <SAN pawn captures> ::= <from file>[<from rank>] 'x' <to square>[<promoted to>]
|
||||||
|
// <SAN pawn push> ::= <to square>[<promoted to>]
|
||||||
|
Move parseSAN(const std::string& str, const Board& pos, bool& parseError) {
|
||||||
|
// check length, at least 2 characters for pawn pushes and max 6 plus 2
|
||||||
|
// additional characters as move annotations (for example '!?', '+' or '#').
|
||||||
|
if ((str.length() < 2) || (9 < str.length())) {
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start with extracting the piece type.
|
||||||
|
// The first letter (strictly speaking upper case) gives the moving
|
||||||
|
// piece type or a pawn if not given.
|
||||||
|
enum piece color = pos.isWhiteTurn() ? white : black;
|
||||||
|
enum piece piece = pawn;
|
||||||
|
switch(str[0]) {
|
||||||
|
// Piece moves (all normal piece moves)
|
||||||
|
case 'R': piece = rook; break;
|
||||||
|
case 'N': piece = knight; break;
|
||||||
|
case 'B': piece = bishop; break;
|
||||||
|
case 'Q': piece = queen; break;
|
||||||
|
case 'K': piece = king; break;
|
||||||
|
// Castling
|
||||||
|
case 'O': case '0':
|
||||||
|
if (startsWith(str, "O-O-O") || startsWith(str, "0-0-0")) {
|
||||||
|
if (color == white) {
|
||||||
|
return Move(e1, c1);
|
||||||
|
} else {
|
||||||
|
return Move(e8, c8);
|
||||||
|
}
|
||||||
|
} else if (startsWith(str, "O-O") || startsWith(str, "0-0")) {
|
||||||
|
if (color == white) {
|
||||||
|
return Move(e1, g1);
|
||||||
|
} else {
|
||||||
|
return Move(e8, g8);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
// Pawn moves always begin with from/to file
|
||||||
|
case 'a': case 'b': case 'c': case 'd':
|
||||||
|
case 'e': case 'f': case 'g': case 'h':
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize moves.
|
||||||
|
// Drop annotation symbols '?', '+', ... as well as '=', '/' for promotions.
|
||||||
|
// Thats quite generes but as long as the reminder leads to pseudo legal
|
||||||
|
// moves, I'm happy.
|
||||||
|
char part[6];
|
||||||
|
int part_length = 0;
|
||||||
|
for (char c : str) {
|
||||||
|
switch (c) {
|
||||||
|
case '+': case '#': case '!': case '?': case '=': case '/':
|
||||||
|
continue;
|
||||||
|
default:
|
||||||
|
part[part_length++] = c;
|
||||||
|
}
|
||||||
|
if (part_length >= 6) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (piece == pawn) {
|
||||||
|
// Pawn push
|
||||||
|
if ((part_length == 2) || (part_length == 3)) {
|
||||||
|
Index toSq = parseSquare(std::string{part[0], part[1]}, parseError);
|
||||||
|
if (parseError) { return Move(0); }
|
||||||
|
u64 from;
|
||||||
|
if (color == white) {
|
||||||
|
from = fill7dump<Down>(bitMask<Square>(toSq), ~pos.bb(pawn));
|
||||||
|
from &= pos.bb(white);
|
||||||
|
} else {
|
||||||
|
from = fill7dump<Up>(bitMask<Square>(toSq), ~pos.bb(pawn));
|
||||||
|
from &= pos.bb(black);
|
||||||
|
}
|
||||||
|
if (bitCount(from) != 1) {
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
Index fromSq = bitScanLS(from);
|
||||||
|
if (part_length == 3) {
|
||||||
|
switch(toLower(part[2])) {
|
||||||
|
case 'q': return Move(fromSq, toSq, queen);
|
||||||
|
case 'r': return Move(fromSq, toSq, rook);
|
||||||
|
case 'b': return Move(fromSq, toSq, bishop);
|
||||||
|
case 'n': return Move(fromSq, toSq, knight);
|
||||||
|
default:
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Move(fromSq, toSq);
|
||||||
|
// Pawn Capture
|
||||||
|
} else if ((part_length == 4) || (part_length == 5)) {
|
||||||
|
if (part[1] != 'x') {
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
Index toSq = parseSquare(std::string{part[2], part[3]}, parseError);
|
||||||
|
if (parseError) { return Move(0); }
|
||||||
|
u64 from;
|
||||||
|
if (('a' <= part[0]) && (part[0] <= 'h')) {
|
||||||
|
from = bitMask<File>(static_cast<Index>(str[0] - 'a'));
|
||||||
|
} else {
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
if (color == white) {
|
||||||
|
from &= shift<Down>(bitMask<Rank>(toSq));
|
||||||
|
from &= pos.bb(white) & pos.bb(pawn);
|
||||||
|
} else {
|
||||||
|
from &= shift<Up>(bitMask<Rank>(toSq));
|
||||||
|
from &= pos.bb(black) & pos.bb(pawn);
|
||||||
|
}
|
||||||
|
if (bitCount(from) != 1) {
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
Index fromSq = bitScanLS(from);
|
||||||
|
if (part_length == 5) {
|
||||||
|
switch(toLower(part[4])) {
|
||||||
|
case 'q': return Move(fromSq, toSq, queen);
|
||||||
|
case 'r': return Move(fromSq, toSq, rook);
|
||||||
|
case 'b': return Move(fromSq, toSq, bishop);
|
||||||
|
case 'n': return Move(fromSq, toSq, knight);
|
||||||
|
default:
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Move(fromSq, toSq);
|
||||||
|
}
|
||||||
|
// Illegal length, ether 2, 3 for pawn push or 4, 5 for pawn captures
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only normal piece moves remain (castling and pawn moves handled above)
|
||||||
|
if ((part_length < 3) || (6 < part_length)) {
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
Index toSq = parseSquare(
|
||||||
|
std::string{part[part_length - 2], part[part_length - 1]}, parseError);
|
||||||
|
if (parseError) { return Move(0); }
|
||||||
|
u64 to = bitMask<Square>(toSq);
|
||||||
|
u64 from = pos.bb(color) & pos.bb(piece);
|
||||||
|
u64 empty = ~(pos.bb(white) | pos.bb(black));
|
||||||
|
switch(piece) {
|
||||||
|
case queen:
|
||||||
|
from &= fill7dump<Left> (to, empty)
|
||||||
|
| fill7dump<Right> (to, empty)
|
||||||
|
| fill7dump<Up> (to, empty)
|
||||||
|
| fill7dump<Down> (to, empty)
|
||||||
|
| fill7dump<LeftUp> (to, empty)
|
||||||
|
| fill7dump<RightUp> (to, empty)
|
||||||
|
| fill7dump<LeftDown> (to, empty)
|
||||||
|
| fill7dump<RightDown>(to, empty);
|
||||||
|
break;
|
||||||
|
case rook:
|
||||||
|
from &= fill7dump<Left> (to, empty)
|
||||||
|
| fill7dump<Right>(to, empty)
|
||||||
|
| fill7dump<Up> (to, empty)
|
||||||
|
| fill7dump<Down> (to, empty);
|
||||||
|
break;
|
||||||
|
case bishop:
|
||||||
|
from &= fill7dump<LeftUp> (to, empty)
|
||||||
|
| fill7dump<RightUp> (to, empty)
|
||||||
|
| fill7dump<LeftDown> (to, empty)
|
||||||
|
| fill7dump<RightDown>(to, empty);
|
||||||
|
break;
|
||||||
|
case knight:
|
||||||
|
from &= pos.knightMoves(toSq);
|
||||||
|
break;
|
||||||
|
case king:
|
||||||
|
from &= pos.kingMoves(toSq);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
assert(false && "Unreachable!");
|
||||||
|
}
|
||||||
|
// Capture Move
|
||||||
|
if (part[part_length - 3] == 'x') {
|
||||||
|
Index fromSq;
|
||||||
|
switch (part_length) {
|
||||||
|
case 4:
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
if (('1' <= part[1]) && (part[1] <= '8')) {
|
||||||
|
from &= bitMask<Rank>(8 * static_cast<Index>('8' - part[1]));
|
||||||
|
} else if (('a' <= part[1]) && (part[1] <= 'h')) {
|
||||||
|
from &= bitMask<File>(static_cast<Index>(part[1] - 'a'));
|
||||||
|
} else {
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
fromSq = parseSquare(std::string{part[1], part[2]}, parseError);
|
||||||
|
if (parseError) { return Move(0); }
|
||||||
|
from &= bitMask<Square>(fromSq);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
// Non Capture / Quiet Move
|
||||||
|
} else {
|
||||||
|
Index fromSq;
|
||||||
|
switch (part_length) {
|
||||||
|
case 3:
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
if (('1' <= part[1]) && (part[1] <= '8')) {
|
||||||
|
from &= bitMask<Rank>(8 * static_cast<Index>('8' - part[1]));
|
||||||
|
} else if (('a' <= part[1]) && (part[1] <= 'h')) {
|
||||||
|
from &= bitMask<File>(static_cast<Index>(part[1] - 'a'));
|
||||||
|
} else {
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
fromSq = parseSquare(std::string{part[1], part[2]}, parseError);
|
||||||
|
if (parseError) { return Move(0); }
|
||||||
|
from &= bitMask<Square>(fromSq);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (bitCount(from) != 1) {
|
||||||
|
// If there is still an umbiguity, some candidates might be pinned.
|
||||||
|
// Remove pinned candidates unable to reach the target square.
|
||||||
|
u64 cKing = pos.bb(color) & pos.bb(king);
|
||||||
|
Index cKingSq = bitScanLS(cKing);
|
||||||
|
u64 pinMask = pos.pinned((color == white) ? black : white, cKing, empty);
|
||||||
|
for (u64 can = from; can; can &= can - 1) { // can ... candidates
|
||||||
|
Index canSq = bitScanLS(can);
|
||||||
|
// check if current candidate is pinned
|
||||||
|
if ((can & -can) & pinMask) {
|
||||||
|
// validate if tha candidate is able to reach the target square
|
||||||
|
if (!(lineMask(cKingSq, canSq) & to)) {
|
||||||
|
// Remove candidate
|
||||||
|
from ^= can & -can;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If there are still multiple candidates or none -> parse error
|
||||||
|
if (bitCount(from) != 1) {
|
||||||
|
parseError = true;
|
||||||
|
return Move(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Finaly, create the move
|
||||||
|
return Move(bitScanLS(from), toSq);
|
||||||
|
}
|
||||||
|
|
||||||
|
void parseMoves(std::istream& istream, const Board& pos, MoveList& moveList,
|
||||||
|
bool& parseError
|
||||||
|
) {
|
||||||
|
// setup movelist of all legal moves to validate move lagality
|
||||||
|
MoveList legalMoves;
|
||||||
|
pos.moves(legalMoves);
|
||||||
|
// holds input tokens to be parsed as moves
|
||||||
|
std::string word;
|
||||||
|
istream >> std::skipws;
|
||||||
|
// For each input token
|
||||||
|
while (istream >> word) {
|
||||||
|
// parse token as a move (depending on the move format)
|
||||||
|
Move move;
|
||||||
|
if (readSAN) {
|
||||||
|
move = parseSAN(word, pos, parseError);
|
||||||
|
} else {
|
||||||
|
move = parseMove(word, parseError);
|
||||||
|
}
|
||||||
|
// stop in case of a parse error
|
||||||
|
if (parseError) { return; }
|
||||||
|
// validate legality of the (successfully parsed) move and add the
|
||||||
|
// generated (not the parsed move) to the move list containing all the
|
||||||
|
// additional move information
|
||||||
|
bool isLegal = false;
|
||||||
|
for (Move legal : legalMoves) {
|
||||||
|
if ((move.from() == legal.from())
|
||||||
|
&& (move.to() == legal.to())
|
||||||
|
&& (move.promote() == legal.promote()))
|
||||||
|
{
|
||||||
|
isLegal = true;
|
||||||
|
moveList.push_back(legal);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if the parse move was not found in the legal moves list stop and
|
||||||
|
// report a parse error
|
||||||
|
if (!isLegal) {
|
||||||
|
parseError = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
char formatPiece(enum piece piece) {
|
||||||
|
switch (piece) {
|
||||||
|
case rook: return 'R';
|
||||||
|
case knight: return 'N';
|
||||||
|
case bishop: return 'B';
|
||||||
|
case queen: return 'Q';
|
||||||
|
case king: return 'K';
|
||||||
|
case pawn: return 'P';
|
||||||
|
default: return '?';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void printBoard(const Board& board) {
|
||||||
|
osyncstream out(cout);
|
||||||
|
|
||||||
|
const char* rankNames = "87654321";
|
||||||
|
const char* fileNames = "abcdefgh";
|
||||||
|
|
||||||
|
for (Index line = 0; line < 17; line++) {
|
||||||
|
if (line % 2) {
|
||||||
|
Index rankIndex = line / 2;
|
||||||
|
out << rankNames[rankIndex];
|
||||||
|
for (Index fileIndex = 0; fileIndex < 8; fileIndex++) {
|
||||||
|
Index squareIndex = 8 * rankIndex + fileIndex;
|
||||||
|
if (board.color(squareIndex) == black) {
|
||||||
|
out << " | \033[1m\033[94m"
|
||||||
|
<< formatPiece(board.piece(squareIndex)) << "\033[0m";
|
||||||
|
} else if (board.color(squareIndex) == white) {
|
||||||
|
out << " | \033[1m\033[97m"
|
||||||
|
<< formatPiece(board.piece(squareIndex)) << "\033[0m";
|
||||||
|
} else {
|
||||||
|
out << " | ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out << " |";
|
||||||
|
} else {
|
||||||
|
out << " +---+---+---+---+---+---+---+---+";
|
||||||
|
}
|
||||||
|
out << " ";
|
||||||
|
|
||||||
|
switch (line) {
|
||||||
|
case 1:
|
||||||
|
out << "How's turn: "
|
||||||
|
<< (board.isWhiteTurn() ? "white" : "black");
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
out << "Move Count: " << ((board.plyCount() + 1U) / 2U);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
out << "Half move clock: " << board.halfMoveClock();
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
out << "Castling Rights: ";
|
||||||
|
if (board.castleRight(white, KingSide)) { out << 'K'; };
|
||||||
|
if (board.castleRight(white, QueenSide)) { out << 'Q'; };
|
||||||
|
if (board.castleRight(black, KingSide)) { out << 'k'; };
|
||||||
|
if (board.castleRight(black, QueenSide)) { out << 'q'; };
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
out << "en-passange target: ";
|
||||||
|
if (board.enPassant() < 64) {
|
||||||
|
out << static_cast<char>('a' + (board.enPassant() % 8))
|
||||||
|
<< static_cast<char>('8' - (board.enPassant() / 8));
|
||||||
|
}
|
||||||
|
else { out << "-"; };
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
out << "evaluate: " << std::dec << board.evaluate();
|
||||||
|
break;
|
||||||
|
case 7:
|
||||||
|
out << "hash: " << std::hex << board.hash() << std::dec;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
out << "\033[0K\n"; // clear rest of line (remove potential leftovers)
|
||||||
|
}
|
||||||
|
out << " ";
|
||||||
|
for (Index fileIndex = 0; fileIndex < 8; fileIndex++) {
|
||||||
|
out << " " << fileNames[fileIndex];
|
||||||
|
}
|
||||||
|
out << std::endl;
|
||||||
|
}
|
||||||
|
void printBitBoards(const Board& board) {
|
||||||
|
osyncstream out(cout);
|
||||||
|
|
||||||
|
std::string lRanks(" 8\n 7\n 6\n 5\n 4\n 3\n 2\n 1");
|
||||||
|
std::string rRanks("8\n7\n6\n5\n4\n3\n2\n1");
|
||||||
|
|
||||||
|
out << std::endl << " "
|
||||||
|
<< std::left << std::setw(18) << "white"
|
||||||
|
<< std::left << std::setw(18) << "black"
|
||||||
|
<< std::left << std::setw(18) << "pawns"
|
||||||
|
<< std::left << std::setw(18) << "kings"
|
||||||
|
<< std::endl << " " << std::setfill('0') << std::hex
|
||||||
|
<< std::right << std::setw(16) << board.bb(white) << " "
|
||||||
|
<< std::right << std::setw(16) << board.bb(black) << " "
|
||||||
|
<< std::right << std::setw(16) << board.bb(pawn) << " "
|
||||||
|
<< std::right << std::setw(16) << board.bb(king) << " "
|
||||||
|
<< std::left << std::setfill(' ') << std::dec << std::endl
|
||||||
|
<< aside(
|
||||||
|
lRanks,
|
||||||
|
rbits(board.bb(white), '\n', ' '),
|
||||||
|
rbits(board.bb(black), '\n', ' '),
|
||||||
|
rbits(board.bb(pawn), '\n', ' '),
|
||||||
|
rbits(board.bb(king), '\n', ' '),
|
||||||
|
rRanks
|
||||||
|
) << std::endl
|
||||||
|
<< " a b c d e f g h a b c d e f g h"
|
||||||
|
<< " a b c d e f g h a b c d e f g h" << std::endl;
|
||||||
|
out << std::endl << " "
|
||||||
|
<< std::left << std::setw(18) << "rooks"
|
||||||
|
<< std::left << std::setw(18) << "knights"
|
||||||
|
<< std::left << std::setw(18) << "bishops"
|
||||||
|
<< std::left << std::setw(18) << "queens"
|
||||||
|
<< std::endl << " " << std::setfill('0') << std::hex
|
||||||
|
<< std::right << std::setw(16) << board.bb(rook) << " "
|
||||||
|
<< std::right << std::setw(16) << board.bb(knight) << " "
|
||||||
|
<< std::right << std::setw(16) << board.bb(bishop) << " "
|
||||||
|
<< std::right << std::setw(16) << board.bb(queen) << " "
|
||||||
|
<< std::left << std::setfill(' ') << std::dec << std::endl
|
||||||
|
<< aside(
|
||||||
|
lRanks,
|
||||||
|
rbits(board.bb(rook), '\n', ' '),
|
||||||
|
rbits(board.bb(knight), '\n', ' '),
|
||||||
|
rbits(board.bb(bishop), '\n', ' '),
|
||||||
|
rbits(board.bb(queen), '\n', ' '),
|
||||||
|
rRanks
|
||||||
|
) << std::endl
|
||||||
|
<< " a b c d e f g h a b c d e f g h"
|
||||||
|
<< " a b c d e f g h a b c d e f g h" << std::endl;
|
||||||
|
|
||||||
|
const enum piece color = board.isWhiteTurn() ? white : black;
|
||||||
|
const enum piece enemy = (color == white) ? black : white;
|
||||||
|
const u64 empty = ~(board.bb(white) | board.bb(black));
|
||||||
|
const u64 cKing = board.bb(color) & board.bb(king);
|
||||||
|
|
||||||
|
out << std::endl << " "
|
||||||
|
<< std::left << std::setw(18) << "attacks"
|
||||||
|
<< std::left << std::setw(18) << "pinned"
|
||||||
|
<< std::left << std::setw(18) << "checkers"
|
||||||
|
<< std::endl
|
||||||
|
<< aside(
|
||||||
|
lRanks,
|
||||||
|
rbits(board.attacks(enemy, empty), '\n', ' '),
|
||||||
|
rbits(board.pinned(enemy, cKing, empty), '\n', ' '),
|
||||||
|
rbits(board.checkers(enemy, cKing, empty), '\n', ' '),
|
||||||
|
rRanks
|
||||||
|
) << std::endl
|
||||||
|
<< " a b c d e f g h a b c d e f g h"
|
||||||
|
<< " a b c d e f g h" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void printEval(const Board& board) {
|
||||||
|
osyncstream out(cout);
|
||||||
|
|
||||||
|
// Partial Scores
|
||||||
|
Score pScoreWhite, pScoreBlack;
|
||||||
|
|
||||||
|
out << " | White | Black | Total\n"
|
||||||
|
<< "-------------+---------+---------+---------\n"
|
||||||
|
<< " Material | ";
|
||||||
|
pScoreWhite = board.evalMaterial(white);
|
||||||
|
pScoreBlack = board.evalMaterial(black);
|
||||||
|
out << std::setw(7) << std::right << pScoreWhite << " | "
|
||||||
|
<< std::setw(7) << std::right << pScoreBlack << " | "
|
||||||
|
<< std::setw(7) << std::right << pScoreWhite - pScoreBlack << "\n"
|
||||||
|
<< " Pawns | ";
|
||||||
|
pScoreWhite = board.evalPawns(white);
|
||||||
|
pScoreBlack = board.evalPawns(black);
|
||||||
|
out << std::setw(7) << std::right << pScoreWhite << " | "
|
||||||
|
<< std::setw(7) << std::right << pScoreBlack << " | "
|
||||||
|
<< std::setw(7) << std::right << pScoreWhite - pScoreBlack << "\n"
|
||||||
|
<< " King Safety | ";
|
||||||
|
pScoreWhite = board.evalKingSafety(white);
|
||||||
|
pScoreBlack = board.evalKingSafety(black);
|
||||||
|
out << std::setw(7) << std::right << pScoreWhite << " | "
|
||||||
|
<< std::setw(7) << std::right << pScoreBlack << " | "
|
||||||
|
<< std::setw(7) << std::right << pScoreWhite - pScoreBlack << "\n"
|
||||||
|
<< " Rooks | ";
|
||||||
|
pScoreWhite = board.evalRooks(white);
|
||||||
|
pScoreBlack = board.evalRooks(black);
|
||||||
|
out << std::setw(7) << std::right << pScoreWhite << " | "
|
||||||
|
<< std::setw(7) << std::right << pScoreBlack << " | "
|
||||||
|
<< std::setw(7) << std::right << pScoreWhite - pScoreBlack << "\n";
|
||||||
|
|
||||||
|
out << "\nTotal: " << board.evaluate() << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void printMoves(const Board& board, bool capturesOnly) {
|
||||||
|
MoveList moveList;
|
||||||
|
board.moves(moveList, capturesOnly);
|
||||||
|
for (Move& move : moveList) {
|
||||||
|
move.setScore(move.calcScore());
|
||||||
|
}
|
||||||
|
std::sort(moveList.begin(), moveList.end());
|
||||||
|
osyncstream out(cout);
|
||||||
|
if (capturesOnly) {
|
||||||
|
out << "info string captures";
|
||||||
|
} else {
|
||||||
|
out << "info string moves";
|
||||||
|
}
|
||||||
|
for (Move& move : moveList) {
|
||||||
|
out << ' ' << move;
|
||||||
|
}
|
||||||
|
out << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// position [fen <fenstring> | startpos ] moves <move1> <move2> .... <moveN>
|
||||||
|
// \______________________ remainder in cmd ______________________/
|
||||||
|
void position(std::vector<Board>& game, std::istream& cmd, bool& parseError) {
|
||||||
|
std::string word;
|
||||||
|
|
||||||
|
// setup a new game in case of a parse error (no changes in case of an error)
|
||||||
|
std::vector<Board> newGame;
|
||||||
|
|
||||||
|
// Setup position (and consume moves cmd)
|
||||||
|
Board pos;
|
||||||
|
cmd >> word;
|
||||||
|
if (word == "startpos") {
|
||||||
|
// Consume and check "moves" word
|
||||||
|
if (cmd >> word && word != "moves") {
|
||||||
|
parseError = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// set position to start position
|
||||||
|
pos = Board();
|
||||||
|
newGame.push_back(pos);
|
||||||
|
} else if (word == "fen") {
|
||||||
|
std::string fen;
|
||||||
|
while (cmd >> word && word != "moves") {
|
||||||
|
fen += word + " ";
|
||||||
|
}
|
||||||
|
pos.init(fen, parseError);
|
||||||
|
if (parseError) { return; }
|
||||||
|
newGame.push_back(pos);
|
||||||
|
} else if (word == "this") { // (UCI extention)
|
||||||
|
if (game.empty()) {
|
||||||
|
parseError = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
pos = game.back();
|
||||||
|
// Consume and check "moves" word
|
||||||
|
if (cmd >> word && word != "moves") {
|
||||||
|
parseError = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// keep the game
|
||||||
|
newGame = game;
|
||||||
|
} else {
|
||||||
|
parseError = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply (legal) moves (if any) and append new positions
|
||||||
|
while (cmd >> word) {
|
||||||
|
Move move;
|
||||||
|
if (readSAN) {
|
||||||
|
move = UCI::parseSAN(word, pos, parseError);
|
||||||
|
} else {
|
||||||
|
move = UCI::parseMove(word, parseError);
|
||||||
|
}
|
||||||
|
// validate legality (and extend the move with piece, victim, ... info)
|
||||||
|
move = pos.isLegal(move); // validate move legality and extend move info
|
||||||
|
if (parseError || !move) {
|
||||||
|
parseError = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
pos.make(move);
|
||||||
|
newGame.push_back(pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finaly replace game with new game (no errors occured)
|
||||||
|
game = newGame;
|
||||||
|
}
|
||||||
|
|
||||||
|
void setoption(std::istream& cmd, bool& parseError) {
|
||||||
|
std::string word;
|
||||||
|
|
||||||
|
// Consume "name"
|
||||||
|
if ((!(cmd >> word)) || (word != "name")) {
|
||||||
|
parseError = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// read option id
|
||||||
|
std::string id = "";
|
||||||
|
while ((cmd >> word) && (word != "value")) {
|
||||||
|
id += word;
|
||||||
|
}
|
||||||
|
|
||||||
|
// set option value
|
||||||
|
if (id == "Hash") {
|
||||||
|
Index ttSize = 0;
|
||||||
|
cmd >> ttSize;
|
||||||
|
if ((0 < ttSize) && (ttSize < 1025)) {
|
||||||
|
Search::init(ttSize);
|
||||||
|
} else {
|
||||||
|
parseError = true;
|
||||||
|
}
|
||||||
|
} else if (id == "readSAN") {
|
||||||
|
cmd >> word;
|
||||||
|
if (word == "true") {
|
||||||
|
readSAN = true;
|
||||||
|
} else if (word == "false") {
|
||||||
|
readSAN = false;
|
||||||
|
} else {
|
||||||
|
parseError = true;
|
||||||
|
}
|
||||||
|
} else if (id == "writeLAN") {
|
||||||
|
cmd >> word;
|
||||||
|
if (word == "true") {
|
||||||
|
writeLAN = true;
|
||||||
|
} else if (word == "false") {
|
||||||
|
writeLAN = false;
|
||||||
|
} else {
|
||||||
|
parseError = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
parseError = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// go <cmd>
|
||||||
|
// with <cmd>;
|
||||||
|
// ...
|
||||||
|
// TODO: implement
|
||||||
|
//
|
||||||
|
void go(std::vector<Board>& game, std::istream& cmd, bool& parseError) {
|
||||||
|
std::string word;
|
||||||
|
Search::State config(Search::state);
|
||||||
|
|
||||||
|
std::string IGNORED; // TODO: complete UCI implementation
|
||||||
|
|
||||||
|
while ((cmd >> word) && !parseError) {
|
||||||
|
if (word == "perft") {
|
||||||
|
config.search = Search::State::Type::perft;
|
||||||
|
if (!(cmd >> config.depth)) {
|
||||||
|
parseError = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (word == "searchmoves") {
|
||||||
|
parseMoves(cmd, game.back(), config.searchmoves, parseError);
|
||||||
|
}
|
||||||
|
else if (word == "depth") { cmd >> config.depth; }
|
||||||
|
else if (word == "wtime") { cmd >> config.wtime; }
|
||||||
|
else if (word == "btime") { cmd >> config.btime; }
|
||||||
|
|
||||||
|
else if (word == "winc") { cmd >> config.winc; }
|
||||||
|
else if (word == "binc") { cmd >> config.binc; }
|
||||||
|
else if (word == "movestogo") { cmd >> IGNORED; }
|
||||||
|
else if (word == "nodes") { cmd >> IGNORED; }
|
||||||
|
else if (word == "mate") { cmd >> IGNORED; }
|
||||||
|
else if (word == "movetime") { cmd >> config.movetime; }
|
||||||
|
else if (word == "ponder") { cmd >> IGNORED; }
|
||||||
|
else if (word == "infinite") {
|
||||||
|
config.depth = MoveList::max_size(); // max PV line length
|
||||||
|
config.wtime = 24 * 60 * 60 * 1000; // 1 day in milliseconds
|
||||||
|
config.btime = 24 * 60 * 60 * 1000; // 1 day in milliseconds
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
parseError = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (parseError) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write new global settings to Search state
|
||||||
|
Search::state.wtime = config.wtime;
|
||||||
|
Search::state.btime = config.btime;
|
||||||
|
Search::state.winc = config.winc;
|
||||||
|
Search::state.binc = config.binc;
|
||||||
|
|
||||||
|
Search::start(game, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
void print(std::vector<Board>& game, std::istream& cmd, bool& parseError) {
|
||||||
|
Board& board = game.back();
|
||||||
|
std::string word;
|
||||||
|
|
||||||
|
// Get print command (or set default)
|
||||||
|
if (!(cmd >> word)) {
|
||||||
|
word = "board";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (word == "board") { printBoard(board); }
|
||||||
|
else if (word == "game") { for (Board& pos : game) { printBoard(pos); } }
|
||||||
|
else if (word == "moves") { printMoves(board); }
|
||||||
|
else if (word == "eval") { printEval(board); }
|
||||||
|
else if (word == "captures") { printMoves(board, true); }
|
||||||
|
else if (word == "bits") { printBitBoards(board); }
|
||||||
|
else if (word == "fen") {
|
||||||
|
osyncstream(cout) << "info string fen "
|
||||||
|
<< board.fen() << std::endl; }
|
||||||
|
else { parseError = true; }
|
||||||
|
}
|
||||||
|
|
||||||
|
void stdin_listen(std::vector<Board>& game) {
|
||||||
|
// read line by line from stdin and dispatch appropriate command handler
|
||||||
|
std::string line;
|
||||||
|
while (getline(std::cin, line)) {
|
||||||
|
// a game consists of at least one position
|
||||||
|
assert(game.size());
|
||||||
|
|
||||||
|
std::istringstream cmd(line);
|
||||||
|
std::string cmdName;
|
||||||
|
|
||||||
|
// Extract command name (skip white spaces)
|
||||||
|
cmd >> std::skipws >> cmdName;
|
||||||
|
|
||||||
|
// Dispatch commands (or handle directly)
|
||||||
|
bool parseError = false; // tracks comand argument parse status
|
||||||
|
if (cmdName == "quit") { Search::stop(); break; } // stop stdin loop -> shutdown
|
||||||
|
else if (cmdName == "exit") { Search::stop(); break; } // quit alias
|
||||||
|
else if (cmdName == "stop") { Search::stop(); }
|
||||||
|
else if (cmdName == "ucinewgame") { Search::newgame(); }
|
||||||
|
else if (cmdName == "uci") {
|
||||||
|
osyncstream(cout)
|
||||||
|
<< "id name Schach Hoernchen"
|
||||||
|
"\nid author Daniel Kapla"
|
||||||
|
"\noption name Hash type spin default 32 min 1 max 1024"
|
||||||
|
"\noption name readSAN type check default false"
|
||||||
|
"\noption name writeLAN type check default false"
|
||||||
|
"\nuciok" // Ready (there are no options yet)
|
||||||
|
<< std::endl;
|
||||||
|
}
|
||||||
|
else if (cmdName == "isready") { cout << "readyok" << std::endl; }
|
||||||
|
else if (cmdName == "setoption") { setoption(cmd, parseError); }
|
||||||
|
else if (cmdName == "go") { go(game, cmd, parseError); }
|
||||||
|
else if (cmdName == "position") { position(game, cmd, parseError); }
|
||||||
|
// UCI Extention (not part of the UCI protocol)
|
||||||
|
else if (cmdName == "d") { print(game, cmd, parseError); } // print alias (as in stockfish)
|
||||||
|
else if (cmdName == "print") { print(game, cmd, parseError); }
|
||||||
|
else if (cmdName == "getoptions") {
|
||||||
|
osyncstream(cout)
|
||||||
|
<< "option name Hash value " << std::setprecision(2)
|
||||||
|
<< Search::ttSize() << " MB"
|
||||||
|
<< "\noption name readSAN value " << std::boolalpha << readSAN
|
||||||
|
<< "\noption name writeLAN value " << std::boolalpha << writeLAN
|
||||||
|
<< std::endl;
|
||||||
|
}
|
||||||
|
else if (cmdName == "help" || cmdName == "?") {
|
||||||
|
osyncstream(std::cerr)
|
||||||
|
<< "Commands:\n"
|
||||||
|
" uci\n\033[2m"
|
||||||
|
" responds with self identification/options and finishes "
|
||||||
|
"with uciok\033[0m\n"
|
||||||
|
" ucinewgame\n\033[2m"
|
||||||
|
" starts a new game; resets all internal game states\033[0m\n"
|
||||||
|
" isready\n\033[2m"
|
||||||
|
" Should be responding 'readyok' immediately\033[0m\n"
|
||||||
|
" position\n"
|
||||||
|
" position startpos [moves <move1> [<move2> ...]]\n"
|
||||||
|
" position fen <fen> [moves <move1> [<move2> ...]]\n"
|
||||||
|
" position this moves <move1> [<move2> ...]\n"
|
||||||
|
" go\n"
|
||||||
|
" go perft <depth>\n"
|
||||||
|
" go [depth <depth>] [searchmoves <move> [<move> ...]]\n"
|
||||||
|
" stop\n\033[2m"
|
||||||
|
" Stops what ever the engine is doing right now as "
|
||||||
|
"soon as possible\033[0m\n"
|
||||||
|
" print\n"
|
||||||
|
" d\n"
|
||||||
|
" print [board]\033[2m\n"
|
||||||
|
" Prity prints the board with state information\033[0m\n"
|
||||||
|
" print game\033[2m\n"
|
||||||
|
" Prity prints entire game history\033[0m\n"
|
||||||
|
" print moves\033[2m\n"
|
||||||
|
" Gives a list of all legal moves sorted by move ordering "
|
||||||
|
"heuristic\033[0m\n"
|
||||||
|
" print captures\033[2m\n"
|
||||||
|
" Same as print moves but captures only\033[0m\n"
|
||||||
|
" print bits\033[2m\n"
|
||||||
|
" Prints the internal bit-boards plus attacks, pinnned and "
|
||||||
|
"checkers bit-boards\033[0m\n"
|
||||||
|
" print fen\033[2m\n"
|
||||||
|
" Print FEN string of current internal state\033[0m\n"
|
||||||
|
" quit\n"
|
||||||
|
" exit\033[2m\n"
|
||||||
|
" Shutdown the engine as soon as possible\033[0m\n"
|
||||||
|
" setoption\n"
|
||||||
|
" setoption name Hash value <size>\n"
|
||||||
|
" setoption name readSAN value [true | false]\n"
|
||||||
|
" setoption name writeLAN value [true | false]\n"
|
||||||
|
" getoptions\n\033[2m"
|
||||||
|
" Prints set values of all options\033[0m\n"
|
||||||
|
<< std::endl; // TODO: complete help!!!
|
||||||
|
} else {
|
||||||
|
osyncstream(std::cerr)
|
||||||
|
<< "info string error Unknown command!" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parseError) {
|
||||||
|
osyncstream(std::cerr)
|
||||||
|
<< "info string error Missformed or illegal command!"
|
||||||
|
<< std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
osyncstream(cout) << "info string shutdown" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
} /* namespace UCI */
|
|
@ -0,0 +1,103 @@
|
||||||
|
#ifndef UCI_GUARD_MOVE_H
|
||||||
|
#define UCI_GUARD_MOVE_H
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "types.h"
|
||||||
|
|
||||||
|
// Forward declarations
|
||||||
|
class Board;
|
||||||
|
class Move;
|
||||||
|
class MoveList;
|
||||||
|
|
||||||
|
namespace std {
|
||||||
|
ostream& operator<<(ostream& out, const Move& move);
|
||||||
|
ostream& operator<<(ostream& out, const MoveList& moveList);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* UCI ... Universal Chess Interface
|
||||||
|
*
|
||||||
|
* Implements the UCI interface for interaction between the internal engine
|
||||||
|
* and an GUI using the UCI standard.
|
||||||
|
*/
|
||||||
|
namespace UCI {
|
||||||
|
|
||||||
|
// General UCI protocol configuration
|
||||||
|
extern bool readSAN; // if moves should be read in SAN
|
||||||
|
extern bool writeLAN; // moves written in Long Algebraic Notation
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parses square in algebraic notation as internal square index.
|
||||||
|
*
|
||||||
|
* @param str string representation of a square /[a-h][1-8]/.
|
||||||
|
* @param parseError output variable set to true if illegal square
|
||||||
|
* representation is encountered (aka. parse error occured).
|
||||||
|
*
|
||||||
|
* @returns index between 0 and 63 if legal, 64 otherwise.
|
||||||
|
*/
|
||||||
|
Index parseSquare(const std::string& str, bool& parseError);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parses a move given in pure coordinate notation
|
||||||
|
*
|
||||||
|
* @param str string representation of a move in `<from><to>[<promotion>]` format
|
||||||
|
* @param parseError output variable set to false if successfully parsed, true
|
||||||
|
* otherwise.
|
||||||
|
*/
|
||||||
|
Move parseMove(const std::string& str, bool& parseError);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Formats a move in coordinate notation (Unary Function equiv to the pipe
|
||||||
|
* operator for Move)
|
||||||
|
*
|
||||||
|
* @param move move to be formated as a string
|
||||||
|
*
|
||||||
|
* @return string representation of the move in coordinate notation
|
||||||
|
*/
|
||||||
|
const std::string formatMove(const Move move);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parses a move given in standard algebraic notation (SAN)
|
||||||
|
*
|
||||||
|
* @param str string representation of a move in SAN format
|
||||||
|
* @param pos current position, needed for move interpretation
|
||||||
|
* @param parseError output variable set to false if successfully parsed, true
|
||||||
|
* otherwise.
|
||||||
|
*/
|
||||||
|
Move parseSAN(const std::string& str, const Board& pos, bool& parseError);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parses multiple moves from an input stream into a `MoveList`
|
||||||
|
*/
|
||||||
|
void parseMoves(std::istream&, MoveList&, bool&);
|
||||||
|
|
||||||
|
char formatPiece(enum piece piece);
|
||||||
|
void printBoard(const Board&);
|
||||||
|
void printMoves(const Board&, bool = false);
|
||||||
|
void printBitBoards(const Board&);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handles `position ...` command by setting given board (including moves).
|
||||||
|
*
|
||||||
|
* @param board internal board representation to be manipulated.
|
||||||
|
* @param cmd argument `...` of the position command. For example; given the
|
||||||
|
* command `position startpos moves e2e4`, str should contain `startpos moves e2e4`.
|
||||||
|
* @param parseError output, true if a parse error occured, false otherwise.
|
||||||
|
*/
|
||||||
|
void position(std::vector<Board>&, std::istream&, bool&);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* handles `go ...` commands. Starts searches, performance tests, ...
|
||||||
|
*/
|
||||||
|
void go(std::vector<Board>&, std::istream&, bool&);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Processes/Parses input from stdin and dispatches appropriate command handler.
|
||||||
|
*/
|
||||||
|
void stdin_listen(std::vector<Board>& game);
|
||||||
|
|
||||||
|
} /* namespace UCI */
|
||||||
|
|
||||||
|
#endif /* UCI_GUARD_MOVE_H */
|
|
@ -0,0 +1,698 @@
|
||||||
|
#ifndef INCLUDE_GUARD_UTILS_H
|
||||||
|
#define INCLUDE_GUARD_UTILS_H
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdlib> // for strto* (instead of std::strto*)
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <sstream>
|
||||||
|
#include <cassert>
|
||||||
|
#include <iostream>
|
||||||
|
#include <iomanip>
|
||||||
|
|
||||||
|
#include "types.h"
|
||||||
|
|
||||||
|
// Helpfull in debugging
|
||||||
|
template <typename T>
|
||||||
|
inline std::string rbits(T mask, char sepByte = 0, char sep = 0,
|
||||||
|
Index mark = -1, char marker = 'X'
|
||||||
|
) {
|
||||||
|
std::ostringstream out;
|
||||||
|
|
||||||
|
for (unsigned i = 0; i < sizeof(T); i++) {
|
||||||
|
out << sep;
|
||||||
|
for (unsigned j = 0; j < 8; j++) {
|
||||||
|
Index index = 8 * i + j;
|
||||||
|
if (index == mark) {
|
||||||
|
out << marker;
|
||||||
|
} else {
|
||||||
|
out << ((static_cast<T>(1) << index) & mask ? '1' : '.');
|
||||||
|
}
|
||||||
|
if (sep) out << sep;
|
||||||
|
}
|
||||||
|
if (sepByte && (i < (sizeof(T) - 1))) out << sepByte;
|
||||||
|
}
|
||||||
|
|
||||||
|
return out.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline std::string bits(T mask, char sepByte = 0, char sep = 0,
|
||||||
|
Index mark = -1, char marker = 'X'
|
||||||
|
) {
|
||||||
|
std::string str(rbits<T>(mask, sepByte, sep, mark, marker));
|
||||||
|
return std::string(str.rbegin(), str.rend());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::vector<std::string> split(const std::string& str, char delim = ' ') {
|
||||||
|
std::vector<std::string> parts;
|
||||||
|
std::istringstream iss(str);
|
||||||
|
std::string part;
|
||||||
|
while (std::getline(iss, part, delim)) {
|
||||||
|
if (!part.empty()) {
|
||||||
|
parts.push_back(part);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return parts;
|
||||||
|
}
|
||||||
|
inline std::vector<std::string> parse(const std::string& str, const std::string& format) {
|
||||||
|
std::vector<std::string> parts;
|
||||||
|
std::string part;
|
||||||
|
unsigned start = 0, end = 0, len = 0;
|
||||||
|
for (const std::string& delim : split(format, '*')) {
|
||||||
|
while (end + delim.length() < str.length()) {
|
||||||
|
for (len = 0; len < delim.length(); len++) {
|
||||||
|
if (str[end + len] != delim[len]) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (len == delim.length()) {
|
||||||
|
parts.push_back(str.substr(start, end - start));
|
||||||
|
end += len;
|
||||||
|
start = end;
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
end++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parts.push_back(str.substr(end));
|
||||||
|
return parts;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string aside(std::string lhs, std::string rhs) {
|
||||||
|
std::ostringstream out;
|
||||||
|
|
||||||
|
std::vector<std::string> lLines = split(lhs, '\n');
|
||||||
|
std::vector<std::string> rLines = split(rhs, '\n');
|
||||||
|
|
||||||
|
std::size_t lMax = 0;
|
||||||
|
for (std::size_t i = 0; i < lLines.size(); i++) {
|
||||||
|
lMax = std::max(lMax, lLines[i].length());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (std::size_t i = 0, n = std::min(lLines.size(), rLines.size()); i < n; i++) {
|
||||||
|
out << lLines[i]
|
||||||
|
<< std::string(lMax - lLines[i].length() + 1, ' ')
|
||||||
|
<< rLines[i] << '\n';
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lLines.size() < rLines.size()) {
|
||||||
|
for (std::size_t i = lLines.size(); i < rLines.size(); i++) {
|
||||||
|
out << std::string(lMax + 1, ' ')
|
||||||
|
<< rLines[i] << '\n';
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (std::size_t i = rLines.size(); i < lLines.size(); i++) {
|
||||||
|
out << lLines[i] << '\n';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return out.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
inline std::string aside(std::string col1, std::string col2, Args... cols) {
|
||||||
|
return aside(
|
||||||
|
col1,
|
||||||
|
aside(col2, cols...)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline void fourBitBoards(
|
||||||
|
const std::string& title1, const u64 bb1,
|
||||||
|
const std::string& title2, const u64 bb2,
|
||||||
|
const std::string& title3, const u64 bb3,
|
||||||
|
const std::string& title4, const u64 bb4,
|
||||||
|
Index mark = -1, char marker = 'X'
|
||||||
|
) {
|
||||||
|
std::string lRanks(" 8\n 7\n 6\n 5\n 4\n 3\n 2\n 1");
|
||||||
|
std::string rRanks("8\n7\n6\n5\n4\n3\n2\n1");
|
||||||
|
|
||||||
|
cout << std::endl << " "
|
||||||
|
<< std::left << std::setw(18) << title1
|
||||||
|
<< std::left << std::setw(18) << title2
|
||||||
|
<< std::left << std::setw(18) << title3
|
||||||
|
<< std::left << std::setw(18) << title4
|
||||||
|
<< std::endl
|
||||||
|
<< aside(
|
||||||
|
lRanks,
|
||||||
|
rbits(bb1, '\n', ' ', mark, marker),
|
||||||
|
rbits(bb2, '\n', ' ', mark, marker),
|
||||||
|
rbits(bb3, '\n', ' ', mark, marker),
|
||||||
|
rbits(bb4, '\n', ' ', mark, marker),
|
||||||
|
rRanks
|
||||||
|
) << std::endl
|
||||||
|
<< " a b c d e f g h a b c d e f g h"
|
||||||
|
<< " a b c d e f g h a b c d e f g h" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline char toLower(const char c) {
|
||||||
|
if ('A' <= c && c <= 'Z') {
|
||||||
|
return c + ' ';
|
||||||
|
}
|
||||||
|
return c;
|
||||||
|
}
|
||||||
|
inline char toUpper(const char c) {
|
||||||
|
if ('a' <= c && c <= 'z') {
|
||||||
|
return c - ' ';
|
||||||
|
}
|
||||||
|
return c;
|
||||||
|
}
|
||||||
|
inline bool isUpper(const char c) {
|
||||||
|
return ('A' <= c && c <= 'Z');
|
||||||
|
}
|
||||||
|
inline bool isLower(const char c) {
|
||||||
|
return ('a' <= c && c <= 'z');
|
||||||
|
}
|
||||||
|
inline bool startsWith(const std::string& str, const std::string& prefix) {
|
||||||
|
if (str.length() < prefix.length()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (unsigned i = 0; i < prefix.length(); i++) {
|
||||||
|
if (str[i] != prefix[i]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// see: https://stackoverflow.com/questions/194465/how-to-parse-a-string-to-an-int-in-c
|
||||||
|
// see: https://en.cppreference.com/w/cpp/string/byte/strtol
|
||||||
|
// Note: T must be a unsigned type
|
||||||
|
// Only applicable for types smaller/equal to unsigned long long.
|
||||||
|
// (only used for uint8_t, uint16_t, uint32_t and uint64_t)
|
||||||
|
template <typename T>
|
||||||
|
T parseUnsigned(const std::string& str, bool& parseError) {
|
||||||
|
if (str.empty()) {
|
||||||
|
parseError = true;
|
||||||
|
return static_cast<T>(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
char* end;
|
||||||
|
errno = 0;
|
||||||
|
unsigned long long num = strtoull(str.c_str(), &end, 10);
|
||||||
|
|
||||||
|
if (errno == ERANGE) {
|
||||||
|
errno = 0;
|
||||||
|
parseError = true;
|
||||||
|
return static_cast<T>(-1);
|
||||||
|
}
|
||||||
|
if (*end != '\0') {
|
||||||
|
parseError = true;
|
||||||
|
return static_cast<T>(-1);
|
||||||
|
}
|
||||||
|
// check overflow (using complement trick for unsigned integer types)
|
||||||
|
if (num > static_cast<unsigned long long>(static_cast<T>(-1))) {
|
||||||
|
parseError = true;
|
||||||
|
return static_cast<T>(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return static_cast<T>(num);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <enum location>
|
||||||
|
constexpr u64 bitMask();
|
||||||
|
|
||||||
|
template <enum location>
|
||||||
|
constexpr u64 bitMask(Index);
|
||||||
|
|
||||||
|
template <enum location L>
|
||||||
|
inline constexpr u64 bitMask(Index file, Index rank) {
|
||||||
|
return bitMask<L>(8 * rank + file);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 bitMask<Square>(Index sq) {
|
||||||
|
return static_cast<u64>(1) << sq;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 bitMask<File>(Index sq) {
|
||||||
|
return static_cast<u64>(0x0101010101010101) << (sq & 7);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 bitMask<Rank>(Index sq) {
|
||||||
|
return static_cast<u64>(0xFF) << (sq & 56);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 bitMask<Left>(Index sq) {
|
||||||
|
return bitMask<Rank>(sq) & (bitMask<Square>(sq) - 1);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 bitMask<Right>(Index sq) {
|
||||||
|
return bitMask<Rank>(sq) & (static_cast<u64>(-2) << sq);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 bitMask<Up>(Index sq) {
|
||||||
|
return bitMask<File>(sq) & (bitMask<Square>(sq) - 1);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 bitMask<Down>(Index sq) {
|
||||||
|
return bitMask<File>(sq) & (static_cast<u64>(-2) << sq);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 bitMask<Diag>(Index sq) {
|
||||||
|
const u64 diag = static_cast<u64>(0x8040201008040201);
|
||||||
|
int offset = 8 * static_cast<int>(sq & 7) - static_cast<int>(sq & 56);
|
||||||
|
int nort = -offset & ( offset >> 31);
|
||||||
|
int sout = offset & (-offset >> 31);
|
||||||
|
return (diag >> sout) << nort;
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 bitMask<AntiDiag>(Index sq) {
|
||||||
|
const u64 diag = static_cast<u64>(0x0102040810204080);
|
||||||
|
int offset = 56 - 8 * static_cast<int>(sq & 7) - static_cast<int>(sq & 56);
|
||||||
|
int nort = -offset & ( offset >> 31);
|
||||||
|
int sout = offset & (-offset >> 31);
|
||||||
|
return (diag >> sout) << nort;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 bitMask<WhiteSquares>() {
|
||||||
|
return 0x55AA55AA55AA55AA;
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 bitMask<BlackSquares>() {
|
||||||
|
return 0xAA55AA55AA55AA55;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Shifts with respect to "off board" shifting
|
||||||
|
*
|
||||||
|
* bits shift<Left>(bits) shift<Up>(bits)
|
||||||
|
* 8 . . . 1 . . . . . . 1 . . . . . . . . . . . . . 8
|
||||||
|
* 7 . . . . . . . . . . . . . . . . . . . . . . . . 7
|
||||||
|
* 6 . . . . . . . . . . . . . . . . 1 . . . . 1 . . 6
|
||||||
|
* 5 1 . . . . 1 . . . . . . 1 . . . . . . . . . . . 5
|
||||||
|
* 4 . . . . . . . . . . . . . . . . . . 1 . . . . . 4
|
||||||
|
* 3 . . 1 . . . . . . 1 . . . . . . . . . . . . . . 3
|
||||||
|
* 2 . . . . . . . . . . . . . . . . . . . . . . . . 2
|
||||||
|
* 1 . . . . . . . . . . . . . . . . . . . . . . . . 1
|
||||||
|
* a b c d e f g h a b c d e f g h a b c d e f g h
|
||||||
|
*/
|
||||||
|
template <enum location>
|
||||||
|
constexpr u64 shift(u64);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 shift<Left>(u64 bits) {
|
||||||
|
return (bits >> 1) & ~bitMask<File>(h1);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 shift<Right>(u64 bits) {
|
||||||
|
return (bits << 1) & ~bitMask<File>(a1);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 shift<Down>(u64 bits) {
|
||||||
|
return bits << 8;
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 shift<Up>(u64 bits) {
|
||||||
|
return bits >> 8;
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 shift<RightUp>(u64 bits) {
|
||||||
|
return (bits >> 7) & ~bitMask<File>(a1);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 shift<RightDown>(u64 bits) {
|
||||||
|
return (bits << 9) & ~bitMask<File>(a1);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 shift<LeftUp>(u64 bits) {
|
||||||
|
return (bits >> 9) & ~bitMask<File>(h1);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 shift<LeftDown>(u64 bits) {
|
||||||
|
return (bits << 7) & ~bitMask<File>(h1);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fills (including) start square till the end of the board.
|
||||||
|
*/
|
||||||
|
template <enum location>
|
||||||
|
constexpr u64 fill(u64);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* bits fill<File>(bits) fill<Rank>(bits)
|
||||||
|
* 8 . . . . . . . . . . 1 . . 1 . . . . . . . . . . 8
|
||||||
|
* 7 . . . . . . . . . . 1 . . 1 . . . . . . . . . . 7
|
||||||
|
* 6 . . . . . . . . . . 1 . . 1 . . . . . . . . . . 6
|
||||||
|
* 5 . . 1 . . 1 . . . . 1 . . 1 . . 1 1 1 1 1 1 1 1 5
|
||||||
|
* 4 . . . . . . . . . . 1 . . 1 . . . . . . . . . . 4
|
||||||
|
* 3 . . 1 . . . . . . . 1 . . 1 . . 1 1 1 1 1 1 1 1 3
|
||||||
|
* 2 . . . . . . . . . . 1 . . 1 . . . . . . . . . . 2
|
||||||
|
* 1 . . . . . . . . . . 1 . . 1 . . . . . . . . . . 1
|
||||||
|
* a b c d e f g h a b c d e f g h a b c d e f g h
|
||||||
|
*/
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 fill<Up>(u64 bits) {
|
||||||
|
bits |= bits >> 8;
|
||||||
|
bits |= bits >> 16;
|
||||||
|
return bits | (bits >> 32);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 fill<Down>(u64 bits) {
|
||||||
|
bits |= bits << 8;
|
||||||
|
bits |= bits << 16;
|
||||||
|
return bits | (bits << 32);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 fill<Left>(u64 bits) {
|
||||||
|
bits |= ((bits & 0xFEFEFEFEFEFEFEFE) >> 1);
|
||||||
|
bits |= ((bits & 0xFCFCFCFCFCFCFCFC) >> 2);
|
||||||
|
return bits | ((bits & 0xF0F0F0F0F0F0F0F0) >> 4);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 fill<Right>(u64 bits) {
|
||||||
|
bits |= ((bits & 0x7F7F7F7F7F7F7F7F) << 1);
|
||||||
|
bits |= ((bits & 0x3F3F3F3F3F3F3F3F) << 2);
|
||||||
|
return bits | ((bits & 0x0F0F0F0F0F0F0F0F) << 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 fill<File>(u64 bits) {
|
||||||
|
return fill<Up>(bits) | fill<Down>(bits);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline constexpr u64 fill<Rank>(u64 bits) {
|
||||||
|
return fill<Left>(bits) | fill<Right>(bits);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An attack fill (excludes the attacker but includes blocking pieces)
|
||||||
|
* see: https://www.chessprogramming.org/Dumb7Fill
|
||||||
|
*/
|
||||||
|
template <enum location>
|
||||||
|
inline u64 fill7dump(u64, u64);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline u64 fill7dump<Left>(u64 attacker, u64 empty) {
|
||||||
|
u64 flood = attacker;
|
||||||
|
empty &= ~bitMask<File>(h1); // block fill (avoid wrap)
|
||||||
|
flood |= attacker = (attacker >> 1) & empty;
|
||||||
|
flood |= attacker = (attacker >> 1) & empty;
|
||||||
|
flood |= attacker = (attacker >> 1) & empty;
|
||||||
|
flood |= attacker = (attacker >> 1) & empty;
|
||||||
|
flood |= attacker = (attacker >> 1) & empty;
|
||||||
|
flood |= attacker = (attacker >> 1) & empty;
|
||||||
|
return (flood >> 1) & ~bitMask<File>(h1);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline u64 fill7dump<Right>(u64 attacker, u64 empty) {
|
||||||
|
u64 flood = attacker;
|
||||||
|
empty &= ~bitMask<File>(a1); // block fill (avoid wrap)
|
||||||
|
flood |= attacker = (attacker << 1) & empty;
|
||||||
|
flood |= attacker = (attacker << 1) & empty;
|
||||||
|
flood |= attacker = (attacker << 1) & empty;
|
||||||
|
flood |= attacker = (attacker << 1) & empty;
|
||||||
|
flood |= attacker = (attacker << 1) & empty;
|
||||||
|
flood |= attacker = (attacker << 1) & empty;
|
||||||
|
return (flood << 1) & ~bitMask<File>(a1);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline u64 fill7dump<Down>(u64 attacker, u64 empty) {
|
||||||
|
u64 flood = attacker;
|
||||||
|
flood |= attacker = (attacker << 8) & empty;
|
||||||
|
flood |= attacker = (attacker << 8) & empty;
|
||||||
|
flood |= attacker = (attacker << 8) & empty;
|
||||||
|
flood |= attacker = (attacker << 8) & empty;
|
||||||
|
flood |= attacker = (attacker << 8) & empty;
|
||||||
|
flood |= attacker = (attacker << 8) & empty;
|
||||||
|
return flood << 8;
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline u64 fill7dump<Up>(u64 attacker, u64 empty) {
|
||||||
|
u64 flood = attacker;
|
||||||
|
flood |= attacker = (attacker >> 8) & empty;
|
||||||
|
flood |= attacker = (attacker >> 8) & empty;
|
||||||
|
flood |= attacker = (attacker >> 8) & empty;
|
||||||
|
flood |= attacker = (attacker >> 8) & empty;
|
||||||
|
flood |= attacker = (attacker >> 8) & empty;
|
||||||
|
flood |= attacker = (attacker >> 8) & empty;
|
||||||
|
return flood >> 8;
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline u64 fill7dump<RightUp>(u64 attacker, u64 empty) {
|
||||||
|
u64 flood = attacker;
|
||||||
|
empty &= ~bitMask<File>(a1); // block fill (avoid wrap)
|
||||||
|
flood |= attacker = (attacker >> 7) & empty;
|
||||||
|
flood |= attacker = (attacker >> 7) & empty;
|
||||||
|
flood |= attacker = (attacker >> 7) & empty;
|
||||||
|
flood |= attacker = (attacker >> 7) & empty;
|
||||||
|
flood |= attacker = (attacker >> 7) & empty;
|
||||||
|
flood |= attacker = (attacker >> 7) & empty;
|
||||||
|
return (flood >> 7) & ~bitMask<File>(a1);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline u64 fill7dump<RightDown>(u64 attacker, u64 empty) {
|
||||||
|
u64 flood = attacker;
|
||||||
|
empty &= ~bitMask<File>(a1); // block fill (avoid wrap)
|
||||||
|
flood |= attacker = (attacker << 9) & empty;
|
||||||
|
flood |= attacker = (attacker << 9) & empty;
|
||||||
|
flood |= attacker = (attacker << 9) & empty;
|
||||||
|
flood |= attacker = (attacker << 9) & empty;
|
||||||
|
flood |= attacker = (attacker << 9) & empty;
|
||||||
|
flood |= attacker = (attacker << 9) & empty;
|
||||||
|
return (flood << 9) & ~bitMask<File>(a1);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline u64 fill7dump<LeftUp>(u64 attacker, u64 empty) {
|
||||||
|
u64 flood = attacker;
|
||||||
|
empty &= ~bitMask<File>(h1); // block fill (avoid wrap)
|
||||||
|
flood |= attacker = (attacker >> 9) & empty;
|
||||||
|
flood |= attacker = (attacker >> 9) & empty;
|
||||||
|
flood |= attacker = (attacker >> 9) & empty;
|
||||||
|
flood |= attacker = (attacker >> 9) & empty;
|
||||||
|
flood |= attacker = (attacker >> 9) & empty;
|
||||||
|
flood |= attacker = (attacker >> 9) & empty;
|
||||||
|
return (flood >> 9) & ~bitMask<File>(h1);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline u64 fill7dump<LeftDown>(u64 attacker, u64 empty) {
|
||||||
|
u64 flood = attacker;
|
||||||
|
empty &= ~bitMask<File>(h1); // block fill (avoid wrap)
|
||||||
|
flood |= attacker = (attacker << 7) & empty;
|
||||||
|
flood |= attacker = (attacker << 7) & empty;
|
||||||
|
flood |= attacker = (attacker << 7) & empty;
|
||||||
|
flood |= attacker = (attacker << 7) & empty;
|
||||||
|
flood |= attacker = (attacker << 7) & empty;
|
||||||
|
flood |= attacker = (attacker << 7) & empty;
|
||||||
|
return (flood << 7) & ~bitMask<File>(h1);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline u64 fill7dump<Plus>(u64 attacker, u64 empty) {
|
||||||
|
return fill7dump<Up> (attacker, empty)
|
||||||
|
| fill7dump<Down> (attacker, empty)
|
||||||
|
| fill7dump<Left> (attacker, empty)
|
||||||
|
| fill7dump<Right>(attacker, empty);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline u64 fill7dump<Cross>(u64 attacker, u64 empty) {
|
||||||
|
return fill7dump<LeftUp> (attacker, empty)
|
||||||
|
| fill7dump<LeftDown> (attacker, empty)
|
||||||
|
| fill7dump<RightUp> (attacker, empty)
|
||||||
|
| fill7dump<RightDown>(attacker, empty);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline u64 fill7dump<Star>(u64 attacker, u64 empty) {
|
||||||
|
return fill7dump<Plus> (attacker, empty)
|
||||||
|
| fill7dump<Cross>(attacker, empty);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Index fileIndex(const Index squareIndex) {
|
||||||
|
assert(squareIndex < 64);
|
||||||
|
|
||||||
|
return squareIndex & 7;
|
||||||
|
}
|
||||||
|
inline Index rankIndex(const Index squareIndex) {
|
||||||
|
assert(squareIndex < 64);
|
||||||
|
|
||||||
|
return squareIndex >> 3;
|
||||||
|
}
|
||||||
|
inline Index squareIndex(const Index fileIndex, const Index rankIndex) {
|
||||||
|
assert(fileIndex < 8);
|
||||||
|
assert(rankIndex < 8);
|
||||||
|
|
||||||
|
return 8 * rankIndex + fileIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline u64 lineMask(Index sq1, Index sq2) {
|
||||||
|
int r1 = static_cast<int>(rankIndex(sq1));
|
||||||
|
int r2 = static_cast<int>(rankIndex(sq2));
|
||||||
|
|
||||||
|
int f1 = static_cast<int>(fileIndex(sq1));
|
||||||
|
int f2 = static_cast<int>(fileIndex(sq2));
|
||||||
|
|
||||||
|
assert((-1 < r1) && (r1 < 8));
|
||||||
|
assert((-1 < r2) && (r2 < 8));
|
||||||
|
assert((-1 < f1) && (f1 < 8));
|
||||||
|
assert((-1 < f2) && (f2 < 8));
|
||||||
|
|
||||||
|
if (r1 == r2) { return bitMask<Rank>(sq1); }
|
||||||
|
if (f1 == f2) { return bitMask<File>(sq1); }
|
||||||
|
if ((r1 - r2) == (f1 - f2)) { return bitMask<Diag>(sq1); }
|
||||||
|
if ((r1 - r2) == (f2 - f1)) { return bitMask<AntiDiag>(sq1); }
|
||||||
|
|
||||||
|
assert(false);
|
||||||
|
|
||||||
|
return bitMask<Square>(sq1) | bitMask<Square>(sq2);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline u64 bitReverse(u64 x) {
|
||||||
|
x = (((x & static_cast<u64>(0xAAAAAAAAAAAAAAAA)) >> 1)
|
||||||
|
| ((x & static_cast<u64>(0x5555555555555555)) << 1));
|
||||||
|
x = (((x & static_cast<u64>(0xCCCCCCCCCCCCCCCC)) >> 2)
|
||||||
|
| ((x & static_cast<u64>(0x3333333333333333)) << 2));
|
||||||
|
x = (((x & static_cast<u64>(0xF0F0F0F0F0F0F0F0)) >> 4)
|
||||||
|
| ((x & static_cast<u64>(0x0F0F0F0F0F0F0F0F)) << 4));
|
||||||
|
#ifdef __GNUC__
|
||||||
|
return __builtin_bswap64(x);
|
||||||
|
#else
|
||||||
|
x = (((x & static_cast<u64>(0xFF00FF00FF00FF00)) >> 8)
|
||||||
|
| ((x & static_cast<u64>(0x00FF00FF00FF00FF)) << 8));
|
||||||
|
x = (((x & static_cast<u64>(0xFFFF0000FFFF0000)) >> 16)
|
||||||
|
| ((x & static_cast<u64>(0x0000FFFF0000FFFF)) << 16));
|
||||||
|
return((x >> 32) | (x << 32));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <enum location>
|
||||||
|
inline u64 bitFlip(u64 x);
|
||||||
|
|
||||||
|
// Reverses byte order, aka rank 8 <-> rank 1, rank 7 <-> rank 2, ...
|
||||||
|
template <>
|
||||||
|
inline u64 bitFlip<Rank>(u64 x) {
|
||||||
|
#ifdef __GNUC__
|
||||||
|
return __builtin_bswap64(x);
|
||||||
|
#elif
|
||||||
|
x = (((x & static_cast<u64>(0xFF00FF00FF00FF00)) >> 8)
|
||||||
|
| ((x & static_cast<u64>(0x00FF00FF00FF00FF)) << 8));
|
||||||
|
x = (((x & static_cast<u64>(0xFFFF0000FFFF0000)) >> 16)
|
||||||
|
| ((x & static_cast<u64>(0x0000FFFF0000FFFF)) << 16));
|
||||||
|
return((x >> 32) | (x << 32));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reverses bits in bytes, aka file a <-> file h, file b <-> file g, ...
|
||||||
|
template <>
|
||||||
|
inline u64 bitFlip<File> (u64 x) {
|
||||||
|
constexpr u64 a = 0x5555555555555555;
|
||||||
|
constexpr u64 b = 0x3333333333333333;
|
||||||
|
constexpr u64 c = 0x0F0F0F0F0F0F0F0F;
|
||||||
|
|
||||||
|
x = ((x >> 1) & a) | ((x & a) << 1);
|
||||||
|
x = ((x >> 2) & b) | ((x & b) << 2);
|
||||||
|
x = ((x >> 4) & c) | ((x & c) << 4);
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// see: https://chessprogramming.org/Flipping_Mirroring_and_Rotating
|
||||||
|
|
||||||
|
// Flips bits about the diagonal (transposition of the board)
|
||||||
|
template <>
|
||||||
|
inline u64 bitFlip<Diag>(u64 x) {
|
||||||
|
u64 t; // Temporary Value
|
||||||
|
constexpr u64 a = 0x5500550055005500;
|
||||||
|
constexpr u64 b = 0x3333000033330000;
|
||||||
|
constexpr u64 c = 0x0F0F0F0F00000000;
|
||||||
|
|
||||||
|
t = c & (x ^ (x << 28));
|
||||||
|
x ^= t ^ (t >> 28);
|
||||||
|
t = b & (x ^ (x << 14));
|
||||||
|
x ^= t ^ (t >> 14);
|
||||||
|
t = a & (x ^ (x << 7));
|
||||||
|
x ^= t ^ (t >> 7);
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flips bits about the anti-diagonal
|
||||||
|
template <>
|
||||||
|
inline u64 bitFlip<AntiDiag>(u64 x) {
|
||||||
|
u64 t; // Temporary Value
|
||||||
|
constexpr u64 a = 0xAA00AA00AA00AA00;
|
||||||
|
constexpr u64 b = 0xCCCC0000CCCC0000;
|
||||||
|
constexpr u64 c = 0xF0F0F0F00F0F0F0F;
|
||||||
|
|
||||||
|
t = x ^ (x << 36);
|
||||||
|
x ^= c & (t ^ (x >> 36));
|
||||||
|
t = b & (x ^ (x << 18));
|
||||||
|
x ^= t ^ (t >> 18);
|
||||||
|
t = a & (x ^ (x << 9));
|
||||||
|
x ^= t ^ (t >> 9);
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Index bitCount(u64 x) {
|
||||||
|
#ifdef __GNUC__
|
||||||
|
return __builtin_popcountll(x); // `POPulation COUNT` (Long Long)
|
||||||
|
#elif
|
||||||
|
Index count = 0; // counts set bits
|
||||||
|
|
||||||
|
// increment count until there are no bits set in x
|
||||||
|
for (; x; count++) {
|
||||||
|
x &= x - 1; // unset least significant bit
|
||||||
|
}
|
||||||
|
|
||||||
|
return count;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef __GNUC__
|
||||||
|
inline Index bitScanLS(u64 bb) {
|
||||||
|
return __builtin_ctzll(bb); // Count Trailing Zeros (Long Long)
|
||||||
|
}
|
||||||
|
#elif
|
||||||
|
/**
|
||||||
|
* `de Brujin` sequence and lookup for bitScanLS and bitScanMS.
|
||||||
|
*/
|
||||||
|
constexpr u64 debruijn64Seq = static_cast<u64>(0x03f79d71b4cb0a89);
|
||||||
|
constexpr Index debruijn64Lookup[64] = {
|
||||||
|
0, 47, 1, 56, 48, 27, 2, 60,
|
||||||
|
57, 49, 41, 37, 28, 16, 3, 61,
|
||||||
|
54, 58, 35, 52, 50, 42, 21, 44,
|
||||||
|
38, 32, 29, 23, 17, 11, 4, 62,
|
||||||
|
46, 55, 26, 59, 40, 36, 15, 53,
|
||||||
|
34, 51, 20, 43, 31, 22, 10, 45,
|
||||||
|
25, 39, 14, 33, 19, 30, 9, 24,
|
||||||
|
13, 18, 8, 12, 7, 6, 5, 63
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* Gets the least significant 1 bit `bitScanLS` and most significant
|
||||||
|
* `bitScanMS` index on a 64-bit board.
|
||||||
|
*
|
||||||
|
* Using a `de Brujin` sequence to index a 1 in a 64-bit word.
|
||||||
|
*
|
||||||
|
* @param `bb` 64-bit word (a bit board)
|
||||||
|
* @condition `bb` != 0
|
||||||
|
* @return index 0 to 63 of least significant one bit
|
||||||
|
*
|
||||||
|
* @see Original authors: `Martin Läuter (1997), Charles E. Leiserson,`
|
||||||
|
* `Harald Prokop, Keith H. Randall`
|
||||||
|
* @see https://www.chessprogramming.org/BitScan
|
||||||
|
*/
|
||||||
|
inline Index bitScanLS(u64 bb) {
|
||||||
|
assert(bb != 0);
|
||||||
|
return debruijn64Lookup[((bb ^ (bb - 1)) * debruijn64Seq) >> 58];
|
||||||
|
}
|
||||||
|
inline Index bitScanMS(u64 bb) {
|
||||||
|
assert(bb != 0);
|
||||||
|
|
||||||
|
bb |= bb >> 1;
|
||||||
|
bb |= bb >> 2;
|
||||||
|
bb |= bb >> 4;
|
||||||
|
bb |= bb >> 8;
|
||||||
|
bb |= bb >> 16;
|
||||||
|
bb |= bb >> 32;
|
||||||
|
|
||||||
|
return debruijn64Lookup[(bb * debruijn64Seq) >> 58];
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif /* INCLUDE_GUARD_UTILS_H */
|
|
@ -0,0 +1,16 @@
|
||||||
|
#include <vector>
|
||||||
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
#include "SchachHoernchen/Move.h"
|
||||||
|
#include "SchachHoernchen/Board.h"
|
||||||
|
|
||||||
|
//' Human Crafted Evaluation
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Rcpp::NumericVector HCE(const std::vector<Board>& positions) {
|
||||||
|
// Iterate all positions and call the static board evaluation
|
||||||
|
return Rcpp::NumericVector(positions.begin(), positions.end(),
|
||||||
|
[](const Board& pos) {
|
||||||
|
return (double)pos.evaluate() / 100.0;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
|
@ -0,0 +1,4 @@
|
||||||
|
PKG_CXXFLAGS += -I'../inst/include' -pthread -DRCPP_RCOUT
|
||||||
|
|
||||||
|
SOURCES = $(wildcard *.cpp) $(wildcard ../inst/include/SchachHoernchen/*.cpp)
|
||||||
|
OBJECTS = $(SOURCES:.cpp=.o)
|
|
@ -0,0 +1,280 @@
|
||||||
|
// Generated by using Rcpp::compileAttributes() -> do not edit by hand
|
||||||
|
// Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
|
||||||
|
|
||||||
|
#include "../inst/include/Rchess.h"
|
||||||
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
using namespace Rcpp;
|
||||||
|
|
||||||
|
#ifdef RCPP_USE_GLOBAL_ROSTREAM
|
||||||
|
Rcpp::Rostream<true>& Rcpp::Rcout = Rcpp::Rcpp_cout_get();
|
||||||
|
Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// HCE
|
||||||
|
Rcpp::NumericVector HCE(const std::vector<Board>& positions);
|
||||||
|
RcppExport SEXP _Rchess_HCE(SEXP positionsSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(HCE(positions));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// isWhiteTurn
|
||||||
|
Rcpp::LogicalVector isWhiteTurn(const std::vector<Board>& positions);
|
||||||
|
RcppExport SEXP _Rchess_isWhiteTurn(SEXP positionsSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(isWhiteTurn(positions));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// isCheck
|
||||||
|
Rcpp::LogicalVector isCheck(const std::vector<Board>& positions);
|
||||||
|
RcppExport SEXP _Rchess_isCheck(SEXP positionsSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(isCheck(positions));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// isQuiet
|
||||||
|
Rcpp::LogicalVector isQuiet(const std::vector<Board>& positions);
|
||||||
|
RcppExport SEXP _Rchess_isQuiet(SEXP positionsSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(isQuiet(positions));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// isTerminal
|
||||||
|
Rcpp::LogicalVector isTerminal(const std::vector<Board>& positions);
|
||||||
|
RcppExport SEXP _Rchess_isTerminal(SEXP positionsSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(isTerminal(positions));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// isInsufficient
|
||||||
|
Rcpp::LogicalVector isInsufficient(const std::vector<Board>& positions);
|
||||||
|
RcppExport SEXP _Rchess_isInsufficient(SEXP positionsSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(isInsufficient(positions));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// data_gen
|
||||||
|
Rcpp::CharacterVector data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max, const bool quiet, const int min_ply_count, const bool white_only);
|
||||||
|
RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP, SEXP quietSEXP, SEXP min_ply_countSEXP, SEXP white_onlySEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::RNGScope rcpp_rngScope_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::string& >::type file(fileSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const int >::type sample_size(sample_sizeSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const float >::type score_min(score_minSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const float >::type score_max(score_maxSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const bool >::type quiet(quietSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const int >::type min_ply_count(min_ply_countSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const bool >::type white_only(white_onlySEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max, quiet, min_ply_count, white_only));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// eval_psqt
|
||||||
|
Rcpp::NumericVector eval_psqt(const std::vector<Board>& positions, const std::vector<Rcpp::NumericMatrix>& psqt, const bool pawn_structure, const bool eval_rooks, const bool eval_king);
|
||||||
|
RcppExport SEXP _Rchess_eval_psqt(SEXP positionsSEXP, SEXP psqtSEXP, SEXP pawn_structureSEXP, SEXP eval_rooksSEXP, SEXP eval_kingSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<Board>& >::type positions(positionsSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<Rcpp::NumericMatrix>& >::type psqt(psqtSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const bool >::type pawn_structure(pawn_structureSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const bool >::type eval_rooks(eval_rooksSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const bool >::type eval_king(eval_kingSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(eval_psqt(positions, psqt, pawn_structure, eval_rooks, eval_king));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// fen2int
|
||||||
|
Rcpp::IntegerVector fen2int(const std::vector<Board>& boards);
|
||||||
|
RcppExport SEXP _Rchess_fen2int(SEXP boardsSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<Board>& >::type boards(boardsSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(fen2int(boards));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// read_cyclic
|
||||||
|
Rcpp::CharacterVector read_cyclic(const std::string& file, const int nrows, const int skip, const int start, const int line_len);
|
||||||
|
RcppExport SEXP _Rchess_read_cyclic(SEXP fileSEXP, SEXP nrowsSEXP, SEXP skipSEXP, SEXP startSEXP, SEXP line_lenSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::string& >::type file(fileSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const int >::type nrows(nrowsSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const int >::type skip(skipSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const int >::type start(startSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const int >::type line_len(line_lenSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(read_cyclic(file, nrows, skip, start, line_len));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// sample_move
|
||||||
|
Move sample_move(const Board& pos);
|
||||||
|
RcppExport SEXP _Rchess_sample_move(SEXP posSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::RNGScope rcpp_rngScope_gen;
|
||||||
|
Rcpp::traits::input_parameter< const Board& >::type pos(posSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(sample_move(pos));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// sample_fen
|
||||||
|
std::vector<Board> sample_fen(const unsigned nr, const unsigned min_depth, const unsigned max_depth);
|
||||||
|
RcppExport SEXP _Rchess_sample_fen(SEXP nrSEXP, SEXP min_depthSEXP, SEXP max_depthSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::RNGScope rcpp_rngScope_gen;
|
||||||
|
Rcpp::traits::input_parameter< const unsigned >::type nr(nrSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const unsigned >::type min_depth(min_depthSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const unsigned >::type max_depth(max_depthSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(sample_fen(nr, min_depth, max_depth));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// board
|
||||||
|
Board board(const std::string& fen);
|
||||||
|
RcppExport SEXP _Rchess_board(SEXP fenSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::string& >::type fen(fenSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(board(fen));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// print_board
|
||||||
|
void print_board(const std::string& fen);
|
||||||
|
RcppExport SEXP _Rchess_print_board(SEXP fenSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::traits::input_parameter< const std::string& >::type fen(fenSEXP);
|
||||||
|
print_board(fen);
|
||||||
|
return R_NilValue;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// print_moves
|
||||||
|
void print_moves(const std::string& fen);
|
||||||
|
RcppExport SEXP _Rchess_print_moves(SEXP fenSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::traits::input_parameter< const std::string& >::type fen(fenSEXP);
|
||||||
|
print_moves(fen);
|
||||||
|
return R_NilValue;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// print_bitboards
|
||||||
|
void print_bitboards(const std::string& fen);
|
||||||
|
RcppExport SEXP _Rchess_print_bitboards(SEXP fenSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::traits::input_parameter< const std::string& >::type fen(fenSEXP);
|
||||||
|
print_bitboards(fen);
|
||||||
|
return R_NilValue;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// position
|
||||||
|
Board position(const Board& pos, const std::vector<std::string>& moves, const bool san);
|
||||||
|
RcppExport SEXP _Rchess_position(SEXP posSEXP, SEXP movesSEXP, SEXP sanSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const Board& >::type pos(posSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const std::vector<std::string>& >::type moves(movesSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const bool >::type san(sanSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(position(pos, moves, san));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// perft
|
||||||
|
void perft(const int depth);
|
||||||
|
RcppExport SEXP _Rchess_perft(SEXP depthSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::traits::input_parameter< const int >::type depth(depthSEXP);
|
||||||
|
perft(depth);
|
||||||
|
return R_NilValue;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// go
|
||||||
|
Move go(const int depth);
|
||||||
|
RcppExport SEXP _Rchess_go(SEXP depthSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RObject rcpp_result_gen;
|
||||||
|
Rcpp::traits::input_parameter< const int >::type depth(depthSEXP);
|
||||||
|
rcpp_result_gen = Rcpp::wrap(go(depth));
|
||||||
|
return rcpp_result_gen;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// ucinewgame
|
||||||
|
void ucinewgame();
|
||||||
|
RcppExport SEXP _Rchess_ucinewgame() {
|
||||||
|
BEGIN_RCPP
|
||||||
|
ucinewgame();
|
||||||
|
return R_NilValue;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// onLoad
|
||||||
|
void onLoad(const std::string& libname, const std::string& pkgname);
|
||||||
|
RcppExport SEXP _Rchess_onLoad(SEXP libnameSEXP, SEXP pkgnameSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RNGScope rcpp_rngScope_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::string& >::type libname(libnameSEXP);
|
||||||
|
Rcpp::traits::input_parameter< const std::string& >::type pkgname(pkgnameSEXP);
|
||||||
|
onLoad(libname, pkgname);
|
||||||
|
return R_NilValue;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
// onUnload
|
||||||
|
void onUnload(const std::string& libpath);
|
||||||
|
RcppExport SEXP _Rchess_onUnload(SEXP libpathSEXP) {
|
||||||
|
BEGIN_RCPP
|
||||||
|
Rcpp::RNGScope rcpp_rngScope_gen;
|
||||||
|
Rcpp::traits::input_parameter< const std::string& >::type libpath(libpathSEXP);
|
||||||
|
onUnload(libpath);
|
||||||
|
return R_NilValue;
|
||||||
|
END_RCPP
|
||||||
|
}
|
||||||
|
|
||||||
|
static const R_CallMethodDef CallEntries[] = {
|
||||||
|
{"_Rchess_HCE", (DL_FUNC) &_Rchess_HCE, 1},
|
||||||
|
{"_Rchess_isWhiteTurn", (DL_FUNC) &_Rchess_isWhiteTurn, 1},
|
||||||
|
{"_Rchess_isCheck", (DL_FUNC) &_Rchess_isCheck, 1},
|
||||||
|
{"_Rchess_isQuiet", (DL_FUNC) &_Rchess_isQuiet, 1},
|
||||||
|
{"_Rchess_isTerminal", (DL_FUNC) &_Rchess_isTerminal, 1},
|
||||||
|
{"_Rchess_isInsufficient", (DL_FUNC) &_Rchess_isInsufficient, 1},
|
||||||
|
{"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 7},
|
||||||
|
{"_Rchess_eval_psqt", (DL_FUNC) &_Rchess_eval_psqt, 5},
|
||||||
|
{"_Rchess_fen2int", (DL_FUNC) &_Rchess_fen2int, 1},
|
||||||
|
{"_Rchess_read_cyclic", (DL_FUNC) &_Rchess_read_cyclic, 5},
|
||||||
|
{"_Rchess_sample_move", (DL_FUNC) &_Rchess_sample_move, 1},
|
||||||
|
{"_Rchess_sample_fen", (DL_FUNC) &_Rchess_sample_fen, 3},
|
||||||
|
{"_Rchess_board", (DL_FUNC) &_Rchess_board, 1},
|
||||||
|
{"_Rchess_print_board", (DL_FUNC) &_Rchess_print_board, 1},
|
||||||
|
{"_Rchess_print_moves", (DL_FUNC) &_Rchess_print_moves, 1},
|
||||||
|
{"_Rchess_print_bitboards", (DL_FUNC) &_Rchess_print_bitboards, 1},
|
||||||
|
{"_Rchess_position", (DL_FUNC) &_Rchess_position, 3},
|
||||||
|
{"_Rchess_perft", (DL_FUNC) &_Rchess_perft, 1},
|
||||||
|
{"_Rchess_go", (DL_FUNC) &_Rchess_go, 1},
|
||||||
|
{"_Rchess_ucinewgame", (DL_FUNC) &_Rchess_ucinewgame, 0},
|
||||||
|
{"_Rchess_onLoad", (DL_FUNC) &_Rchess_onLoad, 2},
|
||||||
|
{"_Rchess_onUnload", (DL_FUNC) &_Rchess_onUnload, 1},
|
||||||
|
{NULL, NULL, 0}
|
||||||
|
};
|
||||||
|
|
||||||
|
RcppExport void R_init_Rchess(DllInfo *dll) {
|
||||||
|
R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
|
||||||
|
R_useDynamicSymbols(dll, FALSE);
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
#include <vector>
|
||||||
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
#include "SchachHoernchen/Board.h"
|
||||||
|
|
||||||
|
//' Given a FEN (position) determines if its whites turn
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Rcpp::LogicalVector isWhiteTurn(const std::vector<Board>& positions) {
|
||||||
|
// Iterate all positions and call the static board evaluation
|
||||||
|
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||||
|
[](const Board& pos) { return pos.isWhiteTurn(); }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
//' Check if current side to move is in check
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Rcpp::LogicalVector isCheck(const std::vector<Board>& positions) {
|
||||||
|
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||||
|
[](const Board& pos) { return pos.isCheck(); }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
//' Check if the current position is a quiet position (no piece is attacked)
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Rcpp::LogicalVector isQuiet(const std::vector<Board>& positions) {
|
||||||
|
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||||
|
[](const Board& pos) { return pos.isQuiet(); }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
//' Check if position is terminal
|
||||||
|
//'
|
||||||
|
//' Checks if the position is a terminal position, meaning if the game ended
|
||||||
|
//' by mate, stale mate or the 50 modes rule. Three-Fold repetition is NOT
|
||||||
|
//' checked, therefore a seperate game history is required which the board
|
||||||
|
//' does NOT track.
|
||||||
|
//'
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Rcpp::LogicalVector isTerminal(const std::vector<Board>& positions) {
|
||||||
|
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||||
|
[](const Board& pos) { return pos.isTerminal(); }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
//' Check if checkmate is possible by material on the board
|
||||||
|
//'
|
||||||
|
//' Checks if there is sufficient mating material on the board, meaning if it
|
||||||
|
//' possible for any side to deliver a check mate. More specifically, it
|
||||||
|
//' checks if the pieces on the board are KK, KNK or KBK.
|
||||||
|
//'
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Rcpp::LogicalVector isInsufficient(const std::vector<Board>& positions) {
|
||||||
|
return Rcpp::LogicalVector(positions.begin(), positions.end(),
|
||||||
|
[](const Board& pos) { return pos.isInsufficient(); }
|
||||||
|
);
|
||||||
|
}
|
|
@ -0,0 +1,137 @@
|
||||||
|
#include <iostream>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <string>
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
#include "SchachHoernchen/Board.h"
|
||||||
|
|
||||||
|
//' Specialized version of `read_cyclic.cpp` taylored to work in conjunction with
|
||||||
|
//' `gmlm_chess()` as data generator to provide random draws from a FEN data set
|
||||||
|
//' with scores filtered to be in in the range `score_min` to `score_max`.
|
||||||
|
//'
|
||||||
|
// [[Rcpp::export(name = "data.gen", rng = true)]]
|
||||||
|
Rcpp::CharacterVector data_gen(
|
||||||
|
const std::string& file,
|
||||||
|
const int sample_size,
|
||||||
|
const float score_min = -5.0,
|
||||||
|
const float score_max = +5.0,
|
||||||
|
const bool quiet = false,
|
||||||
|
const int min_ply_count = 10,
|
||||||
|
const bool white_only = true
|
||||||
|
) {
|
||||||
|
// Check parames
|
||||||
|
if (sample_size < 1) {
|
||||||
|
Rcpp::stop("`sample_size` must be positive");
|
||||||
|
}
|
||||||
|
if (score_min >= score_max) {
|
||||||
|
Rcpp::stop("`score_min` must be strictly smaller than `score_max`");
|
||||||
|
}
|
||||||
|
|
||||||
|
// open FEN data set file
|
||||||
|
std::ifstream input(file);
|
||||||
|
if (!input) {
|
||||||
|
Rcpp::stop("Opening file '%s' failed", file);
|
||||||
|
}
|
||||||
|
|
||||||
|
// set the read from stream position to a random line
|
||||||
|
input.seekg(0, std::ios::end);
|
||||||
|
unsigned long seek = unif_rand() * input.tellg();
|
||||||
|
input.seekg(seek);
|
||||||
|
// from random position set stream position to line start (if not over shot)
|
||||||
|
if (!input.eof()) {
|
||||||
|
input.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
|
||||||
|
}
|
||||||
|
// Ensure (in any case) we are at a legal position (recycle)
|
||||||
|
if (input.eof()) {
|
||||||
|
input.seekg(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate output sample
|
||||||
|
Rcpp::CharacterVector _sample(sample_size);
|
||||||
|
Rcpp::NumericVector _scores(sample_size);
|
||||||
|
|
||||||
|
// Read and filter lines from FEN data base file
|
||||||
|
std::string line, fen;
|
||||||
|
float score;
|
||||||
|
Board pos;
|
||||||
|
int sample_count = 0, retry_count = 0, reject_count = 0;
|
||||||
|
while (sample_count < sample_size) {
|
||||||
|
// Check for user interupt (that is, allows from `R` to interupt execution)
|
||||||
|
R_CheckUserInterrupt();
|
||||||
|
|
||||||
|
// Avoid infinite loop
|
||||||
|
if (reject_count > 1000 * sample_size) {
|
||||||
|
Rcpp::stop("Too many rejections, stop to avoid infinite loop");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read line, in case of failure retry from start of file (recycling)
|
||||||
|
if (!std::getline(input, line)) {
|
||||||
|
input.clear();
|
||||||
|
input.seekg(0);
|
||||||
|
if (!std::getline(input, line)) {
|
||||||
|
// another failur is fatal
|
||||||
|
Rcpp::stop("Recycline lines in file '%s' failed", file);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for empty line, treated as a partial error which we retry a few times
|
||||||
|
if (line.empty()) {
|
||||||
|
if (++retry_count > 10) {
|
||||||
|
Rcpp::stop("Retry count exceeded after reading empty line in '%s'", file);
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split candidat line into FEN and score
|
||||||
|
std::stringstream candidat(line);
|
||||||
|
std::getline(candidat, fen, ';');
|
||||||
|
candidat >> score;
|
||||||
|
if (candidat.fail()) {
|
||||||
|
// If this failes, the FEN data base is ill formed!
|
||||||
|
Rcpp::stop("Ill formated FEN data base file '%s'", file);
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse FEN to filter only positions with white to move
|
||||||
|
bool parseError = false;
|
||||||
|
pos.init(fen, parseError);
|
||||||
|
if (parseError) {
|
||||||
|
Rcpp::stop("Retry count exceeded after illegal FEN '%s'", fen);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reject / Filter samples
|
||||||
|
if (((int)pos.plyCount() < min_ply_count) // early positions
|
||||||
|
|| (white_only && (pos.sideToMove() == piece::black)) // white to move positions
|
||||||
|
|| (score < score_min || score_max <= score) // scores out of slice
|
||||||
|
|| (quiet && !pos.isQuiet())) // quiet positions
|
||||||
|
{
|
||||||
|
reject_count++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Everythings succeeded and ge got an appropriate sample in requested range
|
||||||
|
_sample[sample_count] = fen;
|
||||||
|
_scores[sample_count] = score;
|
||||||
|
++sample_count;
|
||||||
|
|
||||||
|
// skip lines (ensures independent draws based on games being independent)
|
||||||
|
if (input.eof()) {
|
||||||
|
input.seekg(0);
|
||||||
|
}
|
||||||
|
for (int s = 0; s < 256; ++s) {
|
||||||
|
input.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
|
||||||
|
if (input.eof()) {
|
||||||
|
input.seekg(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set scores as attribute to position sample
|
||||||
|
_sample.attr("scores") = _scores;
|
||||||
|
|
||||||
|
return _sample;
|
||||||
|
}
|
|
@ -0,0 +1,78 @@
|
||||||
|
#include <vector>
|
||||||
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
#include "SchachHoernchen/Move.h"
|
||||||
|
#include "SchachHoernchen/Board.h"
|
||||||
|
|
||||||
|
//' Human Crafted Evaluation
|
||||||
|
// [[Rcpp::export(name = "eval.psqt", rng = false)]]
|
||||||
|
Rcpp::NumericVector eval_psqt(
|
||||||
|
const std::vector<Board>& positions,
|
||||||
|
const std::vector<Rcpp::NumericMatrix>& psqt,
|
||||||
|
const bool pawn_structure = false,
|
||||||
|
const bool eval_rooks = false,
|
||||||
|
const bool eval_king = false
|
||||||
|
) {
|
||||||
|
// validate Piece Square Table count and sizes
|
||||||
|
if (psqt.size() != 6) {
|
||||||
|
Rcpp::stop("Expected exactly 6 PSQTs");
|
||||||
|
}
|
||||||
|
for (const auto table : psqt) {
|
||||||
|
if (table.nrow() != 8 || table.ncol() != 8) {
|
||||||
|
Rcpp::stop("PSQT table missmatch, all expected to be `8 x 8`");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// create numeric vector by evaluating all positions
|
||||||
|
return Rcpp::NumericVector(positions.begin(), positions.end(),
|
||||||
|
[&psqt, pawn_structure, eval_rooks, eval_king](
|
||||||
|
const Board& pos
|
||||||
|
) {
|
||||||
|
// Index to color/piece mapping (more robust)
|
||||||
|
enum piece colorLoopup[2] = { white, black };
|
||||||
|
enum piece pieceLookup[6] = { pawn, knight, bishop, rook, queen, king };
|
||||||
|
|
||||||
|
// Score is the "inner product" of the "one-hot encoded" position
|
||||||
|
// and the piece square tables (PSQT)
|
||||||
|
double whiteMaterial = 0.0, blackMaterial = 0.0;
|
||||||
|
for (int piece = 0; piece < 6; ++piece) {
|
||||||
|
u64 piece_bb = pos.bb(pieceLookup[piece]);
|
||||||
|
// First the White (positive) pieces
|
||||||
|
for (u64 bb = pos.bb(piece::white) & piece_bb; bb; bb &= bb - 1) {
|
||||||
|
// Get piece on bitboard index (Least Significant Bit)
|
||||||
|
int index = bitScanLS(bb);
|
||||||
|
// Transpose to align with PSQT memory layout
|
||||||
|
index = ((index & 7) << 3) | ((index & 56) >> 3);
|
||||||
|
whiteMaterial += psqt[piece][index];
|
||||||
|
}
|
||||||
|
// Second the black (negative) pieces (with flipped Ranks)
|
||||||
|
for (u64 bb = pos.bb(piece::black) & piece_bb; bb; bb &= bb - 1) {
|
||||||
|
// Get fliped board index
|
||||||
|
int index = bitScanLS(bb);
|
||||||
|
// Transpose to align with PSQT memory layout and flip ranks
|
||||||
|
// convert from whites perspective to blacks persepective
|
||||||
|
index = ((index & 7) << 3) | (7 - ((index & 56) >> 3));
|
||||||
|
blackMaterial += psqt[piece][index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return (whiteMaterial - blackMaterial) / 100.0;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
devtools::load_all()
|
||||||
|
save_point <- sort(list.files(
|
||||||
|
"~/Work/tensorPredictors/dataAnalysis/chess/",
|
||||||
|
pattern = "save_point.*\\.Rdata",
|
||||||
|
full.names = TRUE
|
||||||
|
), decreasing = TRUE)[[1]]
|
||||||
|
load(save_point)
|
||||||
|
|
||||||
|
psqt <- Map(function(parts) matrix(rowSums(kronecker(parts[[2]], parts[[1]])), 8, 8), betas)
|
||||||
|
psqt <- Map(`-`, psqt[1:6], Map(function(table) table[8:1, ], psqt[7:12]))
|
||||||
|
|
||||||
|
eval.psqt("startpos", psqt)
|
||||||
|
|
||||||
|
*/
|
|
@ -0,0 +1,58 @@
|
||||||
|
#include <vector>
|
||||||
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
#include "SchachHoernchen/types.h"
|
||||||
|
#include "SchachHoernchen/utils.h"
|
||||||
|
#include "SchachHoernchen/Board.h"
|
||||||
|
|
||||||
|
//' Convert a legal FEN string to a 3D binary (integer with 0-1 entries) array
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Rcpp::IntegerVector fen2int(const std::vector<Board>& boards) {
|
||||||
|
// Initialize empty chess board as a 3D binary tensor
|
||||||
|
Rcpp::IntegerVector bitboards(8 * 8 * 12 * (int)boards.size());
|
||||||
|
// Set dimension and dimension names (required this way since `Rcpp::Dimension`
|
||||||
|
// does _not_ support 4D arrays)
|
||||||
|
auto dims = Rcpp::IntegerVector({ 8, 8, 12, (int)boards.size() });
|
||||||
|
bitboards.attr("dim") = dims;
|
||||||
|
bitboards.attr("dimnames") = Rcpp::List::create(
|
||||||
|
Rcpp::Named("rank") = Rcpp::CharacterVector::create(
|
||||||
|
"8", "7", "6", "5", "4", "3", "2", "1"
|
||||||
|
),
|
||||||
|
Rcpp::Named("file") = Rcpp::CharacterVector::create(
|
||||||
|
"a", "b", "c", "d", "e", "f", "g", "h"
|
||||||
|
),
|
||||||
|
Rcpp::Named("piece") = Rcpp::CharacterVector::create(
|
||||||
|
"P", "N", "B", "R", "Q", "K", // White Pieces (Upper Case)
|
||||||
|
"p", "n", "b", "r", "q", "k" // Black Pieces (Lower Case)
|
||||||
|
),
|
||||||
|
R_NilValue
|
||||||
|
);
|
||||||
|
|
||||||
|
// Index to color/piece mapping (more robust)
|
||||||
|
enum piece colorLoopup[2] = { white, black };
|
||||||
|
enum piece pieceLookup[6] = { pawn, knight, bishop, rook, queen, king };
|
||||||
|
|
||||||
|
// Set for every piece the corresponding pit positions, note the
|
||||||
|
// "transposition" of the indexing from SchachHoernchen's binary indexing
|
||||||
|
// scheme to the 3D array indexing of ranks/files.
|
||||||
|
for (int i = 0; i < boards.size(); ++i) {
|
||||||
|
const Board& pos = boards[i];
|
||||||
|
for (int color = 0; color < 2; ++color) {
|
||||||
|
for (int piece = 0; piece < 6; ++piece) {
|
||||||
|
int slice = 6 * color + piece;
|
||||||
|
u64 bb = pos.bb(colorLoopup[color]) & pos.bb(pieceLookup[piece]);
|
||||||
|
for (; bb; bb &= bb - 1) {
|
||||||
|
// Get BitBoard index
|
||||||
|
int index = bitScanLS(bb);
|
||||||
|
// Transpose to align with printing as a Chess Board
|
||||||
|
index = ((index & 7) << 3) | ((index & 56) >> 3);
|
||||||
|
// Flip black to move positions to whites point of view
|
||||||
|
index ^= pos.isWhiteTurn() ? 0 : 7;
|
||||||
|
bitboards[768 * i + 64 * slice + index] = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bitboards;
|
||||||
|
}
|
|
@ -0,0 +1,94 @@
|
||||||
|
#include <iostream>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <string>
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
//' Reads lines from a text file with recycling.
|
||||||
|
//'
|
||||||
|
// [[Rcpp::export(name = "read.cyclic", rng = false)]]
|
||||||
|
Rcpp::CharacterVector read_cyclic(
|
||||||
|
const std::string& file,
|
||||||
|
const int nrows = 1000,
|
||||||
|
const int skip = 100,
|
||||||
|
const int start = 1,
|
||||||
|
const int line_len = 64
|
||||||
|
) {
|
||||||
|
// `unsigned` is not "properly" checked by `Rcpp` to _not_ be negative.
|
||||||
|
// An implicit cast to unsigned makes negative integers gigantic, not wanted!
|
||||||
|
if (skip < 0) {
|
||||||
|
Rcpp::stop("`skip` must be non-negative.");
|
||||||
|
}
|
||||||
|
if (nrows < 1 || start < 1 || line_len < 1) {
|
||||||
|
Rcpp::stop("`start`, `nrows` and `line_len` must be positive.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open input file
|
||||||
|
std::ifstream input(file);
|
||||||
|
if (!input) {
|
||||||
|
Rcpp::stop("Opening file '%s' failed", file);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle different versions of start positions, if the start position is
|
||||||
|
// large, we guess a value to skip via a simple heuristic and go with it
|
||||||
|
if (1000 < start) {
|
||||||
|
// Skip (approx) `start` lines
|
||||||
|
input.seekg(0, std::ios::end); // get to end of file
|
||||||
|
unsigned long size = static_cast<unsigned long>(input.tellg());
|
||||||
|
unsigned long seek = static_cast<unsigned long>(line_len) * (
|
||||||
|
static_cast<unsigned long>(start) - 1UL
|
||||||
|
);
|
||||||
|
seek = seek % size;
|
||||||
|
// Now seek to the approx line nr. with recycling
|
||||||
|
input.seekg(seek);
|
||||||
|
// read till end of line to have a proper start of line position
|
||||||
|
if (!input.eof()) {
|
||||||
|
input.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
|
||||||
|
}
|
||||||
|
// in the occastional case of ending at a last line (recycle)
|
||||||
|
if (input.eof()) { input.seekg(0); }
|
||||||
|
} else {
|
||||||
|
// Skip (exactly) `start` lines
|
||||||
|
for (int line_nr = 1; line_nr < start; ++line_nr) {
|
||||||
|
input.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
|
||||||
|
// In case of reaching the end of file, restart (recycle)
|
||||||
|
if (input.eof()) { input.seekg(0); }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create character vector for output lines
|
||||||
|
Rcpp::CharacterVector lines(nrows);
|
||||||
|
|
||||||
|
// Read one line and skip multiple till `nrows` lines are included
|
||||||
|
std::string line;
|
||||||
|
for (int i = 0; i < nrows; ++i) {
|
||||||
|
// Read one line
|
||||||
|
if (std::getline(input, line)) {
|
||||||
|
lines[i] = line;
|
||||||
|
} else {
|
||||||
|
// retry from start of file, may occure in last empty line of file
|
||||||
|
input.clear();
|
||||||
|
input.seekg(0);
|
||||||
|
if (std::getline(input, line)) {
|
||||||
|
lines[i] = line;
|
||||||
|
} else {
|
||||||
|
// Another failur is fatal
|
||||||
|
Rcpp::stop("Recycling lines in file '%s' failed", file);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// recycle if at end of file (always check)
|
||||||
|
if (input.eof()) { input.seekg(0); }
|
||||||
|
|
||||||
|
// skip lines (with recycling)
|
||||||
|
for (int s = 0; s < skip; ++s) {
|
||||||
|
input.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
|
||||||
|
if (input.eof()) { input.seekg(0); }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return lines;
|
||||||
|
}
|
|
@ -0,0 +1,92 @@
|
||||||
|
#include <vector>
|
||||||
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
#include "SchachHoernchen/Move.h"
|
||||||
|
#include "SchachHoernchen/Board.h"
|
||||||
|
|
||||||
|
//' Samples a legal move from a given position
|
||||||
|
// [[Rcpp::export("sample.move", rng = true)]]
|
||||||
|
Move sample_move(const Board& pos) {
|
||||||
|
// RNG for continuous uniform X ~ U[0, 1]
|
||||||
|
auto runif = Rcpp::stats::UnifGenerator(0, 1);
|
||||||
|
// RNG for discrete uniform X ~ U[0; max - 1]
|
||||||
|
auto rindex = [&runif](const std::size_t max) {
|
||||||
|
// sample random index untill in range [0; max - 1]
|
||||||
|
unsigned index;
|
||||||
|
do {
|
||||||
|
index = static_cast<unsigned>(runif() * static_cast<double>(max));
|
||||||
|
} while (index >= max);
|
||||||
|
return index;
|
||||||
|
};
|
||||||
|
|
||||||
|
// generate all legal moves
|
||||||
|
MoveList moves;
|
||||||
|
pos.moves(moves);
|
||||||
|
|
||||||
|
// check if there are any moves to sample from
|
||||||
|
if (moves.empty()) {
|
||||||
|
Rcpp::stop("Attempt to same legal move from terminal node");
|
||||||
|
}
|
||||||
|
|
||||||
|
return moves[rindex(moves.size())];
|
||||||
|
}
|
||||||
|
|
||||||
|
//' Samples a random FEN (position) by applying `ply` random moves to the start
|
||||||
|
//' position.
|
||||||
|
//'
|
||||||
|
//' @param nr number of positions to sample
|
||||||
|
//' @param min_depth minimum number of random ply's to generate random positions
|
||||||
|
//' @param max_depth maximum number of random ply's to generate random positions
|
||||||
|
// [[Rcpp::export("sample.fen", rng = true)]]
|
||||||
|
std::vector<Board> sample_fen(const unsigned nr,
|
||||||
|
const unsigned min_depth = 4, const unsigned max_depth = 20
|
||||||
|
) {
|
||||||
|
// Parameter validation
|
||||||
|
if (min_depth > max_depth) {
|
||||||
|
Rcpp::stop("max_depth must be bigger equal than min_depth");
|
||||||
|
}
|
||||||
|
if (128 < max_depth) {
|
||||||
|
Rcpp::stop("max_depth exceeded maximum value 128");
|
||||||
|
}
|
||||||
|
|
||||||
|
// RNG for continuous uniform X ~ U[0, 1]
|
||||||
|
auto runif = Rcpp::stats::UnifGenerator(0, 1);
|
||||||
|
// RNG for discrete uniform X ~ U[0; max - 1]
|
||||||
|
auto rindex = [&runif](const std::size_t max) {
|
||||||
|
// Check if max is bigger than zero, cause we can not sample from the
|
||||||
|
// empty set
|
||||||
|
if (max == 0) {
|
||||||
|
Rcpp::stop("Attempt to sample random index < 0.");
|
||||||
|
}
|
||||||
|
// sample random index untill in range [0; max - 1]
|
||||||
|
unsigned index = max;
|
||||||
|
while (index >= max) {
|
||||||
|
index = static_cast<unsigned>(runif() * static_cast<double>(max));
|
||||||
|
}
|
||||||
|
return index;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Setup response vector
|
||||||
|
std::vector<Board> fens;
|
||||||
|
fens.reserve(nr);
|
||||||
|
|
||||||
|
// Sample FENs
|
||||||
|
MoveList moves;
|
||||||
|
for (unsigned i = 0; i < nr; ++i) {
|
||||||
|
Board pos; // start position
|
||||||
|
unsigned depth = static_cast<unsigned>(runif() * (max_depth - min_depth));
|
||||||
|
depth += min_depth;
|
||||||
|
for (unsigned ply = 0; ply < depth; ++ply) {
|
||||||
|
moves.clear();
|
||||||
|
pos.moves(moves);
|
||||||
|
if (moves.size()) {
|
||||||
|
pos.make(moves[rindex(moves.size())]);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fens.push_back(pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
return fens;
|
||||||
|
}
|
|
@ -0,0 +1,146 @@
|
||||||
|
// UCI (Univeral Chess Interface) `R` binding to the `SchachHoernchen` engine
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <sstream>
|
||||||
|
#include <iterator>
|
||||||
|
|
||||||
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
#include "SchachHoernchen/Move.h"
|
||||||
|
#include "SchachHoernchen/Board.h"
|
||||||
|
#include "SchachHoernchen/uci.h"
|
||||||
|
#include "SchachHoernchen/search.h"
|
||||||
|
|
||||||
|
namespace UCI_State {
|
||||||
|
std::vector<Board> game(1);
|
||||||
|
};
|
||||||
|
|
||||||
|
//' Converts a FEN string to a Board (position) or return the current internal state
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Board board(const std::string& fen = "") {
|
||||||
|
if (fen == "") {
|
||||||
|
return UCI_State::game.back();
|
||||||
|
} else if (fen == "startpos") {
|
||||||
|
return Board();
|
||||||
|
}
|
||||||
|
Board pos;
|
||||||
|
bool parseError = false;
|
||||||
|
pos.init(fen, parseError);
|
||||||
|
if (parseError) {
|
||||||
|
Rcpp::stop("Parse parsing FEN");
|
||||||
|
}
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
// [[Rcpp::export(name = "print.board", rng = false)]]
|
||||||
|
void print_board(const std::string& fen = "") {
|
||||||
|
UCI::printBoard(board(fen));
|
||||||
|
}
|
||||||
|
|
||||||
|
// [[Rcpp::export(name = "print.moves", rng = false)]]
|
||||||
|
void print_moves(const std::string& fen = "") {
|
||||||
|
UCI::printMoves(board(fen));
|
||||||
|
}
|
||||||
|
|
||||||
|
// [[Rcpp::export(name = "print.bitboards", rng = false)]]
|
||||||
|
void print_bitboards(const std::string& fen = "") {
|
||||||
|
UCI::printBitBoards(board(fen));
|
||||||
|
}
|
||||||
|
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Board position(
|
||||||
|
const Board& pos,
|
||||||
|
const std::vector<std::string>& moves,
|
||||||
|
const bool san = false
|
||||||
|
) {
|
||||||
|
// Build UCI command (without "position")
|
||||||
|
std::stringstream cmd;
|
||||||
|
cmd << "fen " << pos.fen() << " ";
|
||||||
|
if (moves.size()) {
|
||||||
|
cmd << "moves ";
|
||||||
|
std::copy(moves.begin(), moves.end(), std::ostream_iterator<std::string>(cmd, " "));
|
||||||
|
}
|
||||||
|
|
||||||
|
// set UCI internal flag to interprate move input in from-to format or SAN
|
||||||
|
UCI::readSAN = san;
|
||||||
|
|
||||||
|
// and invoke UCI position command handler on the internal game state
|
||||||
|
bool parseError = false;
|
||||||
|
UCI::position(UCI_State::game, cmd, parseError);
|
||||||
|
if (parseError) {
|
||||||
|
Rcpp::stop("Parse Error");
|
||||||
|
}
|
||||||
|
|
||||||
|
return UCI_State::game.back();
|
||||||
|
}
|
||||||
|
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
void perft(const int depth = 6) {
|
||||||
|
// Enforce a depth limit, this is very restrictive but we only want to
|
||||||
|
// use it as a toolbox and _not_ as a strong chess engine
|
||||||
|
if (8 < depth) {
|
||||||
|
Rcpp::stop("In `R` search is limited to depth 8");
|
||||||
|
} else if (depth <= 0) {
|
||||||
|
cout << "Nodes searched: 0" << std::endl;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get current set position
|
||||||
|
Board pos(UCI_State::game.back());
|
||||||
|
|
||||||
|
// Get all legal moves
|
||||||
|
MoveList moves;
|
||||||
|
pos.moves(moves);
|
||||||
|
|
||||||
|
// Setup counter for total nr. of moves
|
||||||
|
Index totalCount = 0;
|
||||||
|
|
||||||
|
Board copy(pos);
|
||||||
|
for (Move move : moves) {
|
||||||
|
// continue traversing
|
||||||
|
pos.make(move);
|
||||||
|
Index nodeCount = Search::perft_subroutine(pos, depth - 1);
|
||||||
|
totalCount += nodeCount;
|
||||||
|
pos = copy; // unmake move
|
||||||
|
|
||||||
|
// report moves of node
|
||||||
|
cout << move << ": " << nodeCount << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
cout << std::endl << "Nodes searched: " << totalCount << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
Move go(
|
||||||
|
const int depth = 6
|
||||||
|
) {
|
||||||
|
// Enforce a depth limit, this is very restrictive but we only want to
|
||||||
|
// use it as a toolbox and _not_ as a strong chess engine
|
||||||
|
if (8 < depth) {
|
||||||
|
Rcpp::stop("In `R` search is limited to depth 8");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup search configuration
|
||||||
|
Search::State config;
|
||||||
|
config.depth = depth < 1 ? 1 : depth;
|
||||||
|
|
||||||
|
// sets worker thread stop condition to false (before dispatch) which in this
|
||||||
|
// context is the main thread. Only need to ensure its running.
|
||||||
|
Search::isRunning.store(true, std::memory_order_release);
|
||||||
|
|
||||||
|
// Construct a search object
|
||||||
|
Search::PVS<Board> search(UCI_State::game, config);
|
||||||
|
|
||||||
|
// and start the search
|
||||||
|
search();
|
||||||
|
|
||||||
|
// return best move
|
||||||
|
return search.bestMove();
|
||||||
|
}
|
||||||
|
|
||||||
|
// [[Rcpp::export(rng = false)]]
|
||||||
|
void ucinewgame() {
|
||||||
|
Search::newgame();
|
||||||
|
UCI_State::game.clear();
|
||||||
|
UCI_State::game.emplace_back();
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
// Startup and unload routines for initialization and cleanup, see: `R/zzz.R`
|
||||||
|
|
||||||
|
#include <Rcpp.h>
|
||||||
|
|
||||||
|
#include "SchachHoernchen/search.h"
|
||||||
|
|
||||||
|
// [[Rcpp::export(".onLoad")]]
|
||||||
|
void onLoad(const std::string& libname, const std::string& pkgname) {
|
||||||
|
// Initialize search (Transposition Table)
|
||||||
|
Search::init();
|
||||||
|
// Report search initialization
|
||||||
|
Rcpp::Rcout << "info string search initialized" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// [[Rcpp::export(".onUnload")]]
|
||||||
|
void onUnload(const std::string& libpath) {
|
||||||
|
// Cleanup any outstanding or running/finished worker tasks
|
||||||
|
Search::stop();
|
||||||
|
// Report shutdown (stopping any remaining searches)
|
||||||
|
Rcpp::Rcout << "info string shutdown" << std::endl;
|
||||||
|
}
|
|
@ -0,0 +1,175 @@
|
||||||
|
library(tensorPredictors)
|
||||||
|
library(Rchess)
|
||||||
|
library(mgcv) # for `gam()` (Generalized Additive Model)
|
||||||
|
|
||||||
|
source("./gmlm_chess.R")
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
### Fitting the GMLM mixture model ###
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
# Data set file name of chess positions with Stockfish [https://stockfishchess.org]
|
||||||
|
# evaluation scores (downloaded and processed by `./preprocessing.sh` from the
|
||||||
|
# lichess data base [https://database.lichess.org/])
|
||||||
|
data_set <- "lichess_db_standard_rated_2023-11.fen"
|
||||||
|
|
||||||
|
# Function to draw samples `X` form the chess position `data_set` conditioned on
|
||||||
|
# `Y` (position scores) to be in the interval `score_min` to `score_max`.
|
||||||
|
data_gen <- function(batch_size, score_min, score_max) {
|
||||||
|
Rchess::fen2int(Rchess::data.gen(data_set, batch_size, score_min, score_max, quiet = TRUE))
|
||||||
|
}
|
||||||
|
|
||||||
|
# Invoke specialized GMLM optimization routine for chess data
|
||||||
|
fit.gmlm <- gmlm_chess(data_gen)
|
||||||
|
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
### Reduction Interpretation and Validation ###
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
# load last save point (includes reduction as `betas`)
|
||||||
|
save_point <- sort(list.files(
|
||||||
|
".",
|
||||||
|
pattern = "save_point_[0-9]*\\.Rdata",
|
||||||
|
full.names = TRUE
|
||||||
|
), decreasing = TRUE)[[1]]
|
||||||
|
load(save_point)
|
||||||
|
|
||||||
|
|
||||||
|
### Construct PSQT (Piece SQuare Tables) from reduction `betas`
|
||||||
|
sample_size <- 100000
|
||||||
|
# Sample a new position data set for fitting a linear model to conbine different
|
||||||
|
# reduction directions into a per piece PSQT matrix
|
||||||
|
fens <- Rchess::data.gen(data_set, sample_size, -20, 20, quiet = TRUE)
|
||||||
|
# extract stockfish (non-static) position evaluation
|
||||||
|
y <- attr(fens, "scores")
|
||||||
|
|
||||||
|
# Convert position into "One-Hot Encoded" / "Bit Board" tensor
|
||||||
|
X <- Rchess::fen2int(fens)
|
||||||
|
# Compute reduction
|
||||||
|
reducedX <- Reduce(rbind, Map(function(piece) {
|
||||||
|
# "condition" on piece, that is to extract the current mixture component
|
||||||
|
X <- X[, , piece, ]
|
||||||
|
# reduce mixture component
|
||||||
|
mlm(X - as.vector(rowMeans(X, dims = 2)), betas[[piece]], transposed = TRUE)
|
||||||
|
}, 1:12))
|
||||||
|
# Convert memory layout to contain vectorized observations in rows
|
||||||
|
reducedX <- t(`dim<-`(reducedX, c(48, sample_size)))
|
||||||
|
# set names for coefficient extraction from linear fit
|
||||||
|
colnames(reducedX) <- as.vector(outer(
|
||||||
|
unlist(strsplit("PNBRQKpnbrqk", "")), c(1, "yl", "yu", "y.2"), paste, sep = "."
|
||||||
|
))
|
||||||
|
|
||||||
|
# Estimate PSQT linear combination weights from reduced sample (exclude dead
|
||||||
|
# draw positions, that is "score = 0". This are approx 5% of all positions)
|
||||||
|
fit <- lm(y ~ ., data = data.frame(y = y, reducedX), subset = y != 0.0)
|
||||||
|
summary(fit)
|
||||||
|
# Translate reduction with weighting estimate into PSQTs
|
||||||
|
psqt <- Map(function(piece) {
|
||||||
|
# reduction column names corresponding to the current white piece (upper case)
|
||||||
|
piece <- toupper(piece)
|
||||||
|
col_names <- paste(piece, c(1, "yl", "yu", "y.2"), sep = ".")
|
||||||
|
# Whites PSQT
|
||||||
|
psqt_white <- do.call(kronecker, rev(betas[[piece]])) %*% coef(fit)[col_names]
|
||||||
|
dim(psqt_white) <- c(8, 8)
|
||||||
|
# the same for black
|
||||||
|
piece <- tolower(piece)
|
||||||
|
col_names <- paste(piece, c(1, "yl", "yu", "y.2"), sep = ".")
|
||||||
|
psqt_black <- do.call(kronecker, rev(betas[[piece]])) %*% coef(fit)[col_names]
|
||||||
|
dim(psqt_black) <- c(8, 8)
|
||||||
|
# Combine into shared PSQT from whites point of view
|
||||||
|
psqt_white - psqt_black[8:1, ]
|
||||||
|
}, c("P", "N", "B", "R", "Q", "K"))
|
||||||
|
# finish by enforcing the pawn constraint (irrelevant for validation, the
|
||||||
|
# corresponding values in an encoded position is always zero)
|
||||||
|
psqt[["P"]][c(1, 8), ] <- 0
|
||||||
|
|
||||||
|
### Validation by GAM fitted on reduced data
|
||||||
|
formula <- as.formula(paste("y ~ ", paste("s(", colnames(reducedX), ")", collapse = "+")))
|
||||||
|
fit.gam <- mgcv::gam(formula, data = data.frame(y = y, reducedX), subset = y != 0.0)
|
||||||
|
summary(fit.gam)
|
||||||
|
|
||||||
|
# compair estimates with mean as baseline and static human crafted evaluation (HCE)
|
||||||
|
rmse.base <- sqrt(mean((mean(y) - y)^2))
|
||||||
|
y.hce <- Rchess::HCE(fens)
|
||||||
|
rmse.hce <- sqrt(mean((y.hce - y)^2))
|
||||||
|
y.hat <- predict(fit.gam, newdata = data.frame(reducedX))
|
||||||
|
rmse.hat <- sqrt(mean((y.hat - y)^2))
|
||||||
|
|
||||||
|
# Also extract R^2 (eval by hand or get from models)
|
||||||
|
(r.sq.lm <- summary(fit)$r.squared)
|
||||||
|
(r.sq.gam <- summary(fit.gam)$r.sq)
|
||||||
|
(r.sq.hce <- 1 - (rmse.hce / rmse.base)^2)
|
||||||
|
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
### Generate LaTeX PSQT plot ###
|
||||||
|
################################################################################
|
||||||
|
if (FALSE) {
|
||||||
|
|
||||||
|
sink("psqt.tex")
|
||||||
|
|
||||||
|
cat("% Authomatically generated by `dataAnalysis/chess.R`
|
||||||
|
\\documentclass{standalone}
|
||||||
|
|
||||||
|
\\usepackage[LSB, T1]{fontenc}
|
||||||
|
\\usepackage{chessboard}
|
||||||
|
\\usepackage{skak}
|
||||||
|
\\usepackage{tikz}
|
||||||
|
\\usepackage{amsmath}
|
||||||
|
\\usepackage{xcolor}
|
||||||
|
|
||||||
|
\\setboardfontencoding{LSB}
|
||||||
|
|
||||||
|
\\setchessboard{linewidth = 0.1em, showmover = false, smallboard}
|
||||||
|
|
||||||
|
")
|
||||||
|
|
||||||
|
cat(paste0("\\definecolor{col", 1:128, "}{HTML}{",
|
||||||
|
mapply(`[`, strsplit(hcl.colors(128, "Blue-Red 3", rev = TRUE), "#"), 2),
|
||||||
|
"}"
|
||||||
|
))
|
||||||
|
|
||||||
|
cat("
|
||||||
|
|
||||||
|
\\begin{document}
|
||||||
|
\\begin{tikzpicture}
|
||||||
|
|
||||||
|
\\coordinate (pawn) at (0, 0);
|
||||||
|
\\coordinate (knight) at (5, 0);
|
||||||
|
\\coordinate (bishop) at (10, 0);
|
||||||
|
\\coordinate (rook) at (0, -5.2);
|
||||||
|
\\coordinate (queen) at (5, -5.2);
|
||||||
|
\\coordinate (king) at (10, -5.2);
|
||||||
|
|
||||||
|
")
|
||||||
|
|
||||||
|
local({
|
||||||
|
zlim <- c(-1, 1) * max(abs(unlist(psqt, use.names = FALSE)))
|
||||||
|
breaks <- seq(zlim[1], zlim[2], len = 129)
|
||||||
|
|
||||||
|
pieces <- c("pawn", "knight", "bishop", "rook", "queen", "king")
|
||||||
|
for (i in seq_along(psqt)) {
|
||||||
|
cat(paste0("\\node (", pieces[i], ") at (", pieces[i], ") {\\chessboard[", paste0(
|
||||||
|
"color=col", as.integer(cut(psqt[[i]], breaks)),
|
||||||
|
",colorbackfield=", outer(8:1, letters[1:8], function(r, f) paste0(f, r)),
|
||||||
|
collapse=","
|
||||||
|
), "]};\n"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
cat("
|
||||||
|
\\node[anchor = north, yshift = -0.4em] at (pawn.north) {Pawn};
|
||||||
|
\\node[anchor = north, yshift = -0.4em] at (knight.north) {Knight};
|
||||||
|
\\node[anchor = north, yshift = -0.4em] at (bishop.north) {Bishop};
|
||||||
|
\\node[anchor = north, yshift = -0.4em] at (rook.north) {Rook};
|
||||||
|
\\node[anchor = north, yshift = -0.4em] at (queen.north) {Queen};
|
||||||
|
\\node[anchor = north, yshift = -0.4em] at (king.north) {King};
|
||||||
|
|
||||||
|
\\end{tikzpicture}
|
||||||
|
\\end{document}
|
||||||
|
")
|
||||||
|
|
||||||
|
sink()
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,247 @@
|
||||||
|
#' Specialized version of `gmlm_ising()`.
|
||||||
|
#'
|
||||||
|
#' Theroetically, equivalent to `gmlm_ising()` except the it uses a stochastic
|
||||||
|
#' gradient descent version of RMSprop instead of classic gradient descent.
|
||||||
|
#' Other differences are puerly of technical nature.
|
||||||
|
#'
|
||||||
|
#' @param data_gen data generator, samples from the data set conditioned on a
|
||||||
|
#' slice value `y.min` to `y.max`. Function signature
|
||||||
|
#' `function(batch.size, y.min, y.max)` with return value `X`, a
|
||||||
|
#' `8 x 8 x 12 x batch.size` 4D array.
|
||||||
|
#' @param fun_y known functions of scalar `y`, returning a 3D/4D tensor
|
||||||
|
#' @param score_breaks numeric vector of two or more unique cut points, the cut
|
||||||
|
#' points are the interval bounds specifying the slices of `y`.
|
||||||
|
#' @param nr_threads integer, nr. of threads used by `ising_m2()`
|
||||||
|
#' @param mcmc_samples integer, nr. of Monte-Carlo Chains passed to `ising_m2()`
|
||||||
|
#' @param slice_size integer, size of sub-samples generated by `data_gen` for
|
||||||
|
#' every slice. The batch size of the for every iteration is then equal to
|
||||||
|
#' `slice_size * (length(score_breaks) - 1L)`.
|
||||||
|
#' @param max_iter maximum number of iterations for gradient optimization
|
||||||
|
#' @param patience integer, break condition parameter. If the approximated loss
|
||||||
|
#' doesn't improve over `patience` iterations, then stop.
|
||||||
|
#' @param step_size numeric, meta parameter for RMSprop for gradient scaling
|
||||||
|
#' @param eps numeric, meta parameter for RMSprop avoiding divition by zero in
|
||||||
|
#' the parameter update rule of RMSprop
|
||||||
|
#' @param save_point character, file name pattern for storing and retrieving
|
||||||
|
#' optimization save points. Those save points allow to stop the method and
|
||||||
|
#' resume optimization later from the last save point.
|
||||||
|
#'
|
||||||
|
gmlm_chess <- function(
|
||||||
|
data_gen,
|
||||||
|
fun_y = function(y) { `dim<-`(t(outer(y, c(0, 1, 1, 2), `^`)), c(2, 2, length(y))) },
|
||||||
|
score_breaks = c(-5.0, -3.0, -2.0, -1.0, -0.5, -0.2, 0.2, 0.5, 1.0, 2.0, 3.0, 5.0),
|
||||||
|
nr_threads = 8L,
|
||||||
|
mcmc_samples = 10000L,
|
||||||
|
slice_size = 512L,
|
||||||
|
max_iter = 10000L,
|
||||||
|
patience = 10L,
|
||||||
|
step_size = 1e-2,
|
||||||
|
eps = sqrt(.Machine$double.eps),
|
||||||
|
save_point = "save_point_%s.Rdata"
|
||||||
|
) {
|
||||||
|
# build intervals from score break points
|
||||||
|
score_breaks <- sort(score_breaks)
|
||||||
|
score_min <- head(score_breaks, -1)
|
||||||
|
score_max <- tail(score_breaks, -1)
|
||||||
|
score_means <- (score_min + score_max) / 2
|
||||||
|
|
||||||
|
# Piece index lookup "table" by piece symbol
|
||||||
|
pieces <- `names<-`(1:12, unlist(strsplit("PNBRQKpnbrqk", "")))
|
||||||
|
|
||||||
|
# Build constraints for every mixture component, that is, for every piece
|
||||||
|
pawn_const <- which(as.logical(tcrossprod(.row(c(8, 8)) %in% c(1, 8))))
|
||||||
|
# King and Queen constraints (by queens its just an approx)
|
||||||
|
KQ_const <- which(!diag(64))
|
||||||
|
|
||||||
|
# Check if there is a save point (load from save)
|
||||||
|
load_point <- if (is.character(save_point)) {
|
||||||
|
sort(list.files(pattern = sprintf(save_point, ".*")), decreasing = TRUE)
|
||||||
|
} else {
|
||||||
|
character(0)
|
||||||
|
}
|
||||||
|
# It a load point is found, resume from save point, otherwise initialize
|
||||||
|
if (length(load_point)) {
|
||||||
|
load_point <- load_point[[1]]
|
||||||
|
cat(sprintf("Resuming from save point '%s'\n", load_point),
|
||||||
|
"(to restart delete/rename the save points)\n")
|
||||||
|
load(load_point)
|
||||||
|
# Fix `iter`, saved after increment
|
||||||
|
iter <- iter - 1L
|
||||||
|
} else {
|
||||||
|
# draw initial sample to be passed to the normal GMLM estimator for initial `betas`
|
||||||
|
X <- Reduce(c, Map(data_gen, slice_size, score_min, score_max))
|
||||||
|
dim(X) <- c(8L, 8L, 12L, slice_size * length(score_means))
|
||||||
|
F <- fun_y(rep(score_means, each = slice_size))
|
||||||
|
|
||||||
|
# set object dimensions (`dimX` is constant, `dimF` depends on `fun_y` arg)
|
||||||
|
dimX <- c(8L, 8L) # for every mixture component
|
||||||
|
dimF <- dim(F)[1:2] # also per mixture component
|
||||||
|
|
||||||
|
# Initialize `betas` for every mixture component
|
||||||
|
betas <- Map(function(piece) {
|
||||||
|
gmlm_tensor_normal(X[, , piece, ], F)$betas
|
||||||
|
}, pieces)
|
||||||
|
|
||||||
|
# and initial values for `Omegas`, based on the same first "big" sample
|
||||||
|
Omegas <- Map(function(piece) {
|
||||||
|
X <- X[, , piece, ]
|
||||||
|
Map(function(mode) {
|
||||||
|
n <- prod(dim(X)[-mode])
|
||||||
|
prob2 <- mcrossprod(X, mode = mode) / n
|
||||||
|
prob2[prob2 == 0] <- 1 / n
|
||||||
|
prob2[prob2 == 1] <- (n - 1) / n
|
||||||
|
prob1 <- diag(prob2)
|
||||||
|
`prob1^2` <- outer(prob1, prob1)
|
||||||
|
`diag<-`(log(((1 - `prob1^2`) / `prob1^2`) * prob2 / (1 - prob2)), 0)
|
||||||
|
}, 1:2)
|
||||||
|
}, pieces)
|
||||||
|
Omegas[[pieces["P"]]][[1]][c(1, 8), ] <- 0
|
||||||
|
Omegas[[pieces["p"]]][[1]][c(1, 8), ] <- 0
|
||||||
|
|
||||||
|
# Initial sample `(X, F)` no longer needed, remove them
|
||||||
|
rm(X, F)
|
||||||
|
|
||||||
|
# Initialize gradients and aggregated mean squared gradients
|
||||||
|
grad2_betas <- Map(function(params) Map(array, 0, Map(dim, params)), betas)
|
||||||
|
grad2_Omegas <- Map(function(params) Map(array, 0, Map(dim, params)), Omegas)
|
||||||
|
|
||||||
|
# initialize optimization tracker for break condition
|
||||||
|
last_loss <- best_loss <- Inf
|
||||||
|
non_improving <- 0L
|
||||||
|
iter <- 0L
|
||||||
|
}
|
||||||
|
|
||||||
|
# main optimization loop
|
||||||
|
while ((iter <- iter + 1L) <= max_iter) {
|
||||||
|
|
||||||
|
# At beginning of every iteration, store current state in a save point.
|
||||||
|
# This allows to resume optimization from the last save point.
|
||||||
|
if (is.character(save_point)) {
|
||||||
|
suspendInterrupts(save(
|
||||||
|
dimX, dimF,
|
||||||
|
betas, Omegas,
|
||||||
|
grad2_betas, grad2_Omegas,
|
||||||
|
last_loss, best_loss, non_improving, iter,
|
||||||
|
file = sprintf(save_point, sprintf("%06d", iter - 1L))))
|
||||||
|
}
|
||||||
|
|
||||||
|
# start timing for this iteration (this is precise enough)
|
||||||
|
start_time <- proc.time()[["elapsed"]]
|
||||||
|
|
||||||
|
# Full Omega(s) for every piece mixture component
|
||||||
|
Omega <- Map(function(Omegas) {
|
||||||
|
kronecker(Omegas[[2]], Omegas[[1]])
|
||||||
|
}, Omegas)
|
||||||
|
Omega[[pieces["P"]]][pawn_const] <- 0
|
||||||
|
Omega[[pieces["p"]]][pawn_const] <- 0
|
||||||
|
Omega[[pieces["K"]]][KQ_const] <- 0
|
||||||
|
Omega[[pieces["k"]]][KQ_const] <- 0
|
||||||
|
Omega[[pieces["Q"]]][KQ_const] <- 0
|
||||||
|
Omega[[pieces["q"]]][KQ_const] <- 0
|
||||||
|
|
||||||
|
# Gradient and negative log-likelihood approximation
|
||||||
|
loss <- 0 # neg. log-likelihood
|
||||||
|
grad_betas <- Map(function(piece) Map(matrix, 0, dimX, dimF), pieces) # grads for betas
|
||||||
|
R2 <- Map(function(piece) array(0, dim = c(dimX, dimX)), pieces) # residuals
|
||||||
|
|
||||||
|
# for every score slice
|
||||||
|
for (slice in seq_along(score_means)) {
|
||||||
|
# function of `y` being the score slice mean (only 3D, same for all obs.)
|
||||||
|
F <- `dim<-`(fun_y(score_means[slice]), dimF)
|
||||||
|
|
||||||
|
# compute parameters of (slice) conditional Ising model
|
||||||
|
params <- Map(function(betas, Omega) {
|
||||||
|
`diag<-`(Omega, as.vector(mlm(F, betas)))
|
||||||
|
}, betas, Omega)
|
||||||
|
|
||||||
|
# second moment of `X_{,,piece} | Y = score_means[slice]` for every piece
|
||||||
|
m2 <- Map(function(param) {
|
||||||
|
ising_m2(param, use_MC = TRUE, nr_threads = nr_threads, nr_samples = mcmc_samples)
|
||||||
|
}, params)
|
||||||
|
|
||||||
|
# Draw a new sample
|
||||||
|
X <- data_gen(slice_size, score_min[slice], score_max[slice])
|
||||||
|
|
||||||
|
# Split into matricized mixture parts
|
||||||
|
matX <- Map(function(piece) {
|
||||||
|
`dim<-`(X[, , piece, ], c(64, slice_size))
|
||||||
|
}, pieces)
|
||||||
|
|
||||||
|
# accumulated loss over all piece mixtures
|
||||||
|
loss <- loss - Reduce(`+`, Map(function(matX, param, m2) {
|
||||||
|
sum(matX * (param %*% matX)) + slice_size * attr(m2, "log_prob_0")
|
||||||
|
}, matX, params, m2))
|
||||||
|
|
||||||
|
# Slice residuals (second order `resid2` and actual residuals `resid1`)
|
||||||
|
resid2 <- Map(function(matX, m2) {
|
||||||
|
tcrossprod(matX) - slice_size * m2
|
||||||
|
}, matX, m2)
|
||||||
|
|
||||||
|
# accumulate residuals
|
||||||
|
R2 <- Map(function(R2, resid2) { R2 + as.vector(resid2) }, R2, resid2)
|
||||||
|
|
||||||
|
# and the beta gradients
|
||||||
|
grad_betas <- Map(function(grad_betas, resid2, betas) {
|
||||||
|
resid1 <- `dim<-`(diag(resid2), dimX)
|
||||||
|
Map(`+`, grad_betas, Map(function(mode) {
|
||||||
|
mcrossprod(resid1, mlm(slice_size * F, betas[-mode], (1:2)[-mode]), mode)
|
||||||
|
}, 1:2))
|
||||||
|
}, grad_betas, resid2, betas)
|
||||||
|
}
|
||||||
|
|
||||||
|
# finaly, finish gradient computation with gradients for `Omegas`
|
||||||
|
grad_Omegas <- Map(function(R2, Omegas) {
|
||||||
|
Map(function(mode) {
|
||||||
|
grad <- mlm(kronperm(R2), Map(as.vector, Omegas[-mode]), (1:2)[-mode], transposed = TRUE)
|
||||||
|
`dim<-`(grad, dim(Omegas[[mode]]))
|
||||||
|
}, 1:2)
|
||||||
|
}, R2, Omegas)
|
||||||
|
|
||||||
|
# Update tracker for break condition
|
||||||
|
non_improving <- if (best_loss < loss) non_improving + 1L else 0L
|
||||||
|
last_loss <- loss
|
||||||
|
best_loss <- min(best_loss, loss)
|
||||||
|
|
||||||
|
# check break condition
|
||||||
|
if (non_improving > patience) { break }
|
||||||
|
|
||||||
|
# accumulate root mean squared gradients
|
||||||
|
grad2_betas <- Map(function(grad2_betas, grad_betas) {
|
||||||
|
Map(function(g2, g) 0.9 * g2 + 0.1 * (g * g), grad2_betas, grad_betas)
|
||||||
|
}, grad2_betas, grad_betas)
|
||||||
|
grad2_Omegas <- Map(function(grad2_Omegas, grad_Omegas) {
|
||||||
|
Map(function(g2, g) 0.9 * g2 + 0.1 * (g * g), grad2_Omegas, grad_Omegas)
|
||||||
|
}, grad2_Omegas, grad_Omegas)
|
||||||
|
|
||||||
|
# Update Parameters
|
||||||
|
betas <- Map(function(betas, grad_betas, grad2_betas) {
|
||||||
|
Map(function(beta, grad, M2) {
|
||||||
|
beta + (step_size / (sqrt(M2) + eps)) * grad
|
||||||
|
}, betas, grad_betas, grad2_betas)
|
||||||
|
}, betas, grad_betas, grad2_betas)
|
||||||
|
Omegas <- Map(function(Omegas, grad_Omegas, grad2_Omegas) {
|
||||||
|
Map(function(Omega, grad, M2) {
|
||||||
|
Omega + (step_size / (sqrt(M2) + eps)) * grad
|
||||||
|
}, Omegas, grad_Omegas, grad2_Omegas)
|
||||||
|
}, Omegas, grad_Omegas, grad2_Omegas)
|
||||||
|
|
||||||
|
# Log progress
|
||||||
|
cat(sprintf("iter: %4d, time for iter: %d [s], loss: %f (best: %f, non-improving: %d)\n",
|
||||||
|
iter, round(proc.time()[["elapsed"]] - start_time), loss, best_loss, non_improving))
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save a final (terminal) save point
|
||||||
|
if (is.character(save_point)) {
|
||||||
|
suspendInterrupts(save(
|
||||||
|
dimX, dimF,
|
||||||
|
betas, Omegas,
|
||||||
|
grad2_betas, grad2_Omegas,
|
||||||
|
last_loss, best_loss, non_improving, iter,
|
||||||
|
file = sprintf(save_point, "final")))
|
||||||
|
}
|
||||||
|
|
||||||
|
structure(
|
||||||
|
list(betas = betas, Omegas = Omegas),
|
||||||
|
iter = iter, loss = loss
|
||||||
|
)
|
||||||
|
}
|
|
@ -0,0 +1,144 @@
|
||||||
|
#include <iostream>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <string>
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "utils.h"
|
||||||
|
#include "Board.h"
|
||||||
|
#include "Move.h"
|
||||||
|
#include "search.h"
|
||||||
|
#include "uci.h"
|
||||||
|
|
||||||
|
static const std::string usage{"usage: pgn2fen [--scored] [<input>]"};
|
||||||
|
|
||||||
|
// Convert PGN (Portable Game Notation) input stream to single FENs
|
||||||
|
// streamed to stdout
|
||||||
|
void pgn2fen(std::istream& input, const bool only_scored) {
|
||||||
|
|
||||||
|
// Instantiate Boards, the start of every game as well as the current state
|
||||||
|
// of the Board while processing a PGN game
|
||||||
|
Board startpos, pos;
|
||||||
|
|
||||||
|
// Read input line by line
|
||||||
|
std::string line;
|
||||||
|
while (std::getline(input, line)) {
|
||||||
|
// Skip empty and metadata lines (every PGN game starts with "<nr>.")
|
||||||
|
if (line.empty() || line.front() == '[') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset position to the start position, every game starts here!
|
||||||
|
pos = startpos;
|
||||||
|
|
||||||
|
// Read game content (assuming one line is the entire game)
|
||||||
|
std::istringstream game(line);
|
||||||
|
std::string count, san, token, eval;
|
||||||
|
while (game >> count >> san >> token) {
|
||||||
|
// Consume/Parse PGN comments
|
||||||
|
if (only_scored) {
|
||||||
|
// consume the comment and search for an evaluation
|
||||||
|
bool has_score = false;
|
||||||
|
while (game >> token) {
|
||||||
|
// Search for evaluation token (position score _after_ the move)
|
||||||
|
if (token == "[%eval") {
|
||||||
|
game >> eval;
|
||||||
|
eval.pop_back(); // delete trailing ']'
|
||||||
|
has_score = true;
|
||||||
|
// Consume the remainder of the comment (ignore it)
|
||||||
|
std::getline(game, token, '}');
|
||||||
|
break;
|
||||||
|
} else if (token == "}") {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// In case of not finding an evaluation, skip the game (_not_ an error)
|
||||||
|
if (!has_score) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Consume the remainder of the comment (ignore it)
|
||||||
|
std::getline(game, token, '}');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform move
|
||||||
|
bool parseError = false;
|
||||||
|
Move move = UCI::parseSAN(san, pos, parseError);
|
||||||
|
if (parseError) {
|
||||||
|
std::cerr << "[ Error ] Parsing '" << san << "' at position '"
|
||||||
|
<< pos.fen() << "' failed." << std::endl;
|
||||||
|
}
|
||||||
|
move = pos.isLegal(move); // validate legality and extend move info
|
||||||
|
if (move) {
|
||||||
|
pos.make(move);
|
||||||
|
} else {
|
||||||
|
std::cerr << "[ Error ] Encountered illegal move '" << san
|
||||||
|
<< " (" << move
|
||||||
|
<< ") ' at position '" << pos.fen() << "'." << std::endl;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write positions
|
||||||
|
if (only_scored) {
|
||||||
|
// Ingore "check mate in" scores (not relevant for eval training)
|
||||||
|
// Do this after "make move" in situations where the check mate
|
||||||
|
// was overlooked, leading to new positions
|
||||||
|
if (eval.length() && eval[0] == '#') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Otherwise, classic eval score to be parsed in centipawns
|
||||||
|
std::cout << pos.fen() << "; " << eval << '\n';
|
||||||
|
} else {
|
||||||
|
// Write only the position FEN
|
||||||
|
std::cout << pos.fen() << '\n';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
int main(int argn, char* argv[]) {
|
||||||
|
// Setup control variables
|
||||||
|
bool only_scored = false;
|
||||||
|
std::string file = "";
|
||||||
|
|
||||||
|
// Parse command arguments
|
||||||
|
switch (argn) {
|
||||||
|
case 1:
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
if (std::string("--scored") == argv[1]) {
|
||||||
|
only_scored = true;
|
||||||
|
} else {
|
||||||
|
file = argv[1];
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
if (std::string("--scored") != argv[1]) {
|
||||||
|
std::cout << usage << std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
only_scored = true;
|
||||||
|
file = argv[2];
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
std::cout << usage << std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invoke converter ether with file input or stdin
|
||||||
|
if (file == "") {
|
||||||
|
pgn2fen(std::cin, only_scored);
|
||||||
|
} else {
|
||||||
|
// Open input file
|
||||||
|
std::ifstream input(file);
|
||||||
|
if (!input) {
|
||||||
|
std::cerr << "Error opening '" << file << "'" << std::endl;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
pgn2fen(input, only_scored);
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
|
@ -0,0 +1,31 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Data set name: Chess games from the Lichess Data Base for standard rated games
|
||||||
|
# in November 2023
|
||||||
|
data=lichess_db_standard_rated_2023-11
|
||||||
|
|
||||||
|
# Check if file exists and download iff not
|
||||||
|
if [ -f "${data}.fen" ]; then
|
||||||
|
echo "File '${data}.fen' already exists, assuming job already done."
|
||||||
|
echo "To rerun delete (rename) the files '${data}.pgn.zst' and/or '${data}.fen'"
|
||||||
|
else
|
||||||
|
# First, compile `png2fen`
|
||||||
|
make pgn2fen
|
||||||
|
|
||||||
|
# Download the PGN data base via `wegt` if not found.
|
||||||
|
# The flag `-q` suppresses `wget`s own output and `-O-` tells `wget` to
|
||||||
|
# stream the downloaded file to `stdout`.
|
||||||
|
# Otherwise, use the file on disk directly.
|
||||||
|
# Decompress the stream with `zstdcat` (no temporary files)
|
||||||
|
# The uncompressed PGN data is then piped into `pgn2fen` which converts
|
||||||
|
# the PGN data base into a list of FEN strings while filtering only
|
||||||
|
# positions with evaluation. The `--scored` parameter specifies to extract
|
||||||
|
# a position evaluation from the PGN and ONLY write positions with scores.
|
||||||
|
# That is, positions without a score are removed!
|
||||||
|
if [ -f "${data}.pgn.zst" ]; then
|
||||||
|
zstdcat ${data}.pgn.zst | ./pgn2fen --scored > ${data}.fen
|
||||||
|
else
|
||||||
|
wget -qO- https://database.lichess.org/standard/${data}.pgn.zst \
|
||||||
|
| zstdcat | ./pgn2fen --scored > ${data}.fen
|
||||||
|
fi
|
||||||
|
fi
|
|
@ -0,0 +1,119 @@
|
||||||
|
library(tensorPredictors)
|
||||||
|
|
||||||
|
# Load as 3D predictors `X` and flat response `y` and `F = y` with per person dim. 1 x 1
|
||||||
|
c(X, F, y) %<-% local({
|
||||||
|
# Load from file
|
||||||
|
ds <- readRDS("eeg_data.rds")
|
||||||
|
|
||||||
|
# Dimension values
|
||||||
|
n <- nrow(ds) # sample size (nr. of people)
|
||||||
|
p <- 64L # nr. of predictors (count of sensorce)
|
||||||
|
t <- 256L # nr. of time points (measurements)
|
||||||
|
|
||||||
|
# Extract dimension names
|
||||||
|
nNames <- ds$PersonID
|
||||||
|
tNames <- as.character(seq(t))
|
||||||
|
pNames <- unlist(strsplit(colnames(ds)[2 + t * seq(p)], "_"))[c(TRUE, FALSE)]
|
||||||
|
|
||||||
|
# Split into predictors (with proper dims and names) and response
|
||||||
|
X <- array(as.matrix(ds[, -(1:2)]),
|
||||||
|
dim = c(person = n, time = t, sensor = p),
|
||||||
|
dimnames = list(person = nNames, time = tNames, sensor = pNames)
|
||||||
|
)
|
||||||
|
y <- ds$Case_Control
|
||||||
|
|
||||||
|
list(X, array(y, c(n, 1L, 1L)), y)
|
||||||
|
})
|
||||||
|
|
||||||
|
# fit a tensor normal model to the data sample axis 1 indexes persons)
|
||||||
|
fit.gmlm <- gmlm_tensor_normal(X, F, sample.axis = 1L)
|
||||||
|
|
||||||
|
# plot the fitted mode wise reductions (for time and sensor axis)
|
||||||
|
with(fit.gmlm, {
|
||||||
|
par.reset <- par(mfrow = c(2, 1))
|
||||||
|
plot(seq(0, 1, len = 256), betas[[1]], main = "Time", xlab = "Time [s]", ylab = expression(beta[1]))
|
||||||
|
plot(betas[[2]], main = "Sensors", xlab = "Sensor Index", ylab = expression(beta[2]))
|
||||||
|
par(par.reset)
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
#' (2D)^2 PCA preprocessing
|
||||||
|
#'
|
||||||
|
#' @param tpc Number of "t"ime "p"rincipal "c"omponents.
|
||||||
|
#' @param ppc Number of "p"redictor "p"rincipal "c"omponents.
|
||||||
|
preprocess <- function(X, tpc, ppc) {
|
||||||
|
# Mode covariances (for predictor and time point modes)
|
||||||
|
c(Sigma_t, Sigma_p) %<-% mcov(X, sample.axis = 1L)
|
||||||
|
|
||||||
|
# "predictor" (sensor) and time point principal components
|
||||||
|
V_t <- svd(Sigma_t, tpc, 0L)$u
|
||||||
|
V_p <- svd(Sigma_p, ppc, 0L)$u
|
||||||
|
|
||||||
|
# reduce with mode wise PCs
|
||||||
|
mlm(X, list(V_t, V_p), modes = 2:3, transposed = TRUE)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#' Leave-one-out prediction
|
||||||
|
#'
|
||||||
|
#' @param X 3D EEG data (preprocessed or not)
|
||||||
|
#' @param F binary responce `y` as a 3D tensor, every obs. is a 1 x 1 matrix
|
||||||
|
loo.predict <- function(X, F) {
|
||||||
|
sapply(seq_len(dim(X)[1L]), function(i) {
|
||||||
|
# Fit with i'th observation removes
|
||||||
|
fit <- gmlm_tensor_normal(X[-i, , ], F[-i, , , drop = FALSE], sample.axis = 1L)
|
||||||
|
|
||||||
|
# Reduce the entire data set
|
||||||
|
r <- as.vector(mlm(X, fit$betas, modes = 2:3, transpose = TRUE))
|
||||||
|
# Fit a logit model on reduced data with i'th observation removed
|
||||||
|
logit <- glm(y ~ r, family = binomial(link = "logit"),
|
||||||
|
data = data.frame(y = y[-i], r = r[-i])
|
||||||
|
)
|
||||||
|
# predict i'th response given i'th reduced observation
|
||||||
|
y.hat <- predict(logit, newdata = data.frame(r = r[i]), type = "response")
|
||||||
|
# report progress
|
||||||
|
cat(sprintf("dim: (%d, %d) - %3d/%d\n", dim(X)[2L], dim(X)[3L], i, dim(X)[1L]))
|
||||||
|
|
||||||
|
y.hat
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
### Classification performance measures
|
||||||
|
# acc: Accuracy. P(Yhat = Y). Estimated as: (TP+TN)/(P+N).
|
||||||
|
acc <- function(y.true, y.pred) mean(round(y.pred) == y.true)
|
||||||
|
# err: Error rate. P(Yhat != Y). Estimated as: (FP+FN)/(P+N).
|
||||||
|
err <- function(y.true, y.pred) mean(round(y.pred) != y.true)
|
||||||
|
# fpr: False positive rate. P(Yhat = + | Y = -). aliases: Fallout.
|
||||||
|
fpr <- function(y.true, y.pred) mean((round(y.pred) == 1)[y.true == 0])
|
||||||
|
# tpr: True positive rate. P(Yhat = + | Y = +). aliases: Sensitivity, Recall.
|
||||||
|
tpr <- function(y.true, y.pred) mean((round(y.pred) == 1)[y.true == 1])
|
||||||
|
# fnr: False negative rate. P(Yhat = - | Y = +). aliases: Miss.
|
||||||
|
fnr <- function(y.true, y.pred) mean((round(y.pred) == 0)[y.true == 1])
|
||||||
|
# tnr: True negative rate. P(Yhat = - | Y = -).
|
||||||
|
tnr <- function(y.true, y.pred) mean((round(y.pred) == 0)[y.true == 0])
|
||||||
|
# auc: Area Under the Curve
|
||||||
|
auc <- function(y.true, y.pred) as.numeric(pROC::roc(y.true, y.pred, quiet = TRUE)$auc)
|
||||||
|
auc.sd <- function(y.true, y.pred) sqrt(pROC::var(pROC::roc(y.true, y.pred, quiet = TRUE)))
|
||||||
|
|
||||||
|
|
||||||
|
# perform preprocessed (reduced) and raw (not reduced) leave-one-out prediction
|
||||||
|
y.hat.3.4 <- loo.predict(preprocess(X, 3, 4), F)
|
||||||
|
y.hat.15.15 <- loo.predict(preprocess(X, 15, 15), F)
|
||||||
|
y.hat.20.30 <- loo.predict(preprocess(X, 20, 30), F)
|
||||||
|
y.hat <- loo.predict(X, F)
|
||||||
|
|
||||||
|
# classification performance measures table by leave-one-out cross-validation
|
||||||
|
(loo.cv <- apply(cbind(y.hat.3.4, y.hat.15.15, y.hat.20.30, y.hat), 2, function(y.pred) {
|
||||||
|
sapply(c("acc", "err", "fpr", "tpr", "fnr", "tnr", "auc", "auc.sd"),
|
||||||
|
function(FUN) { match.fun(FUN)(y, y.pred) })
|
||||||
|
}))
|
||||||
|
#> y.hat.3.4 y.hat.15.15 y.hat.20.30 y.hat
|
||||||
|
#> acc 0.79508197 0.78688525 0.78688525 0.78688525
|
||||||
|
#> err 0.20491803 0.21311475 0.21311475 0.21311475
|
||||||
|
#> fpr 0.35555556 0.40000000 0.40000000 0.40000000
|
||||||
|
#> tpr 0.88311688 0.89610390 0.89610390 0.89610390
|
||||||
|
#> fnr 0.11688312 0.10389610 0.10389610 0.10389610
|
||||||
|
#> tnr 0.64444444 0.60000000 0.60000000 0.60000000
|
||||||
|
#> auc 0.85108225 0.83838384 0.83924964 0.83896104
|
||||||
|
#> auc.sd 0.03584791 0.03760531 0.03751307 0.03754553
|
134
sim/sim-tsir.R
134
sim/sim-tsir.R
|
@ -1,35 +1,49 @@
|
||||||
library(tensorPredictors)
|
library(tensorPredictors)
|
||||||
|
suppressPackageStartupMessages(library(Rdimtools))
|
||||||
|
|
||||||
|
# Source utility function used in most simulations (extracted for convenience)
|
||||||
|
setwd("~/Work/tensorPredictors/sim/")
|
||||||
|
source("./sim_utils.R")
|
||||||
|
|
||||||
# Data set sample size in every simulation
|
# Data set sample size in every simulation
|
||||||
sample.size <- 500L
|
sample.size <- 500L
|
||||||
# Nr. of per simulation replications
|
# Nr. of per simulation replications
|
||||||
reps <- 100L
|
reps <- 10L
|
||||||
# number of observation/response axes (order of the tensors)
|
# number of observation/response axes (order of the tensors)
|
||||||
orders <- c(2L, 3L, 4L)
|
orders <- c(2L, 3L, 4L)
|
||||||
# auto correlation coefficient for the mode-wise auto scatter matrices `Omegas`
|
# auto correlation coefficient for the mode-wise auto scatter matrices `Omegas`
|
||||||
rhos <- seq(0, 0.8, by = 0.1)
|
rhos <- seq(0, 0.8, by = 0.2)
|
||||||
|
|
||||||
|
|
||||||
setwd("~/Work/tensorPredictors/sim/")
|
setwd("~/Work/tensorPredictors/sim/")
|
||||||
base.name <- format(Sys.time(), "failure_of_tsir-%Y%m%dT%H%M")
|
base.name <- format(Sys.time(), "sim-tsir-%Y%m%dT%H%M")
|
||||||
|
|
||||||
# data sampling routine
|
# data sampling routine
|
||||||
sample.data <- function(sample.size, betas, Omegas) {
|
sample.data <- function(sample.size, betas, Omegas) {
|
||||||
dimF <- mapply(ncol, betas)
|
dimF <- mapply(ncol, betas)
|
||||||
|
|
||||||
# responce is a standard normal variable
|
# responce is a standard normal variable
|
||||||
y <- rnorm(sample.size)
|
y <- sort(rnorm(sample.size))
|
||||||
y.pow <- Reduce(function(a, b) outer(a, b, `+`), Map(seq, 0L, len = dimF))
|
y.pow <- Reduce(function(a, b) outer(a, b, `+`), Map(seq, 0L, len = dimF))
|
||||||
F <- t(outer(y, as.vector(y.pow), `^`))
|
F <- t(outer(y, as.vector(y.pow), `^`)) / as.vector(factorial(y.pow))
|
||||||
dim(F) <- c(dimF, sample.size)
|
dim(F) <- c(dimF, sample.size)
|
||||||
|
|
||||||
|
matplot(mat(F, length(dim(F))), type = "l")
|
||||||
|
abline(h = 0, lty = "dashed")
|
||||||
|
|
||||||
|
matplot(y, scale(mat(F, length(dim(F))), scale = FALSE), type = "l")
|
||||||
|
abline(h = 0, lty = "dashed")
|
||||||
|
|
||||||
|
|
||||||
# sample predictors from tensor normal X | Y = y (last axis is sample axis)
|
# sample predictors from tensor normal X | Y = y (last axis is sample axis)
|
||||||
sample.axis <- length(betas) + 1L
|
sample.axis <- length(betas) + 1L
|
||||||
Sigmas <- Map(solve, Omegas)
|
Sigmas <- Map(solve, Omegas)
|
||||||
mu_y <- mlm(F, Map(`%*%`, Sigmas, betas))
|
mu_y <- mlm(F, Map(`%*%`, Sigmas, betas))
|
||||||
X <- mu_y + rtensornorm(sample.size, 0, Sigmas, sample.axis)
|
X <- mu_y + rtensornorm(sample.size, 0, Sigmas, sample.axis)
|
||||||
|
|
||||||
list(X = X, F = F, y = y, sample.axis = sample.axis)
|
# Make `y` into a `Y` tensor with variable axis all of dim 1
|
||||||
|
Y <- array(y, dim = c(rep(1L, length(dimF)), sample.size))
|
||||||
|
|
||||||
|
list(X = X, F = F, Y = Y, sample.axis = sample.axis)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create a CSV logger to write simulation results to
|
# Create a CSV logger to write simulation results to
|
||||||
|
@ -37,7 +51,7 @@ log.file <- paste(base.name, "csv", sep = ".")
|
||||||
logger <- CSV.logger(
|
logger <- CSV.logger(
|
||||||
file.name = log.file,
|
file.name = log.file,
|
||||||
header = c("rho", "order", "sample.size", "rep", "beta.version", outer(
|
header = c("rho", "order", "sample.size", "rep", "beta.version", outer(
|
||||||
"dist.subspace", c("gmlm", "tsir", "sir"),
|
"dist.subspace", c("gmlm", "gmlm.1d", "tsir", "sir"),
|
||||||
paste, sep = "."
|
paste, sep = "."
|
||||||
))
|
))
|
||||||
)
|
)
|
||||||
|
@ -66,22 +80,25 @@ for (order in orders) {
|
||||||
# Version 1: repeated simulations
|
# Version 1: repeated simulations
|
||||||
for (rep in seq_len(reps)) {
|
for (rep in seq_len(reps)) {
|
||||||
# Sample training data
|
# Sample training data
|
||||||
c(X, F, y, sample.axis) %<-% sample.data(sample.size, betas, Omegas)
|
c(X, F, Y, sample.axis) %<-% sample.data(sample.size, betas, Omegas)
|
||||||
|
|
||||||
# Fit models to provided data
|
# Fit models to provided data
|
||||||
fit.gmlm <- gmlm_tensor_normal(X, F, sample.axis = sample.axis, proj.betas = proj.betas)
|
fit.gmlm <- gmlm_tensor_normal(X, F, sample.axis = sample.axis, proj.betas = proj.betas)
|
||||||
fit.tsir <- TSIR(X, y, d = rep(1L, order), sample.axis = sample.axis)
|
fit.gmlm.y <- gmlm_tensor_normal(X, Y, sample.axis = sample.axis)
|
||||||
fit.sir <- SIR(mat(X, sample.axis), y, d = 1L)
|
fit.tsir <- TSIR(X, drop(Y), d = rep(1L, order), sample.axis = sample.axis)
|
||||||
|
fit.sir <- do.sir(mat(X, sample.axis), drop(Y), ndim = 1L)
|
||||||
|
|
||||||
# Extract minimal reduction matrices from fitted models
|
# Extract minimal reduction matrices from fitted models
|
||||||
B.gmlm <- qr.Q(qr(Reduce(kronecker, rev(fit.gmlm$betas))))[, 1L, drop = FALSE]
|
B.gmlm <- qr.Q(qr(Reduce(kronecker, rev(fit.gmlm$betas))))[, 1L, drop = FALSE]
|
||||||
B.tsir <- Reduce(kronecker, rev(fit.tsir))
|
B.gmlm.y <- Reduce(kronecker, rev(fit.gmlm.y$betas))
|
||||||
B.sir <- fit.sir
|
B.tsir <- Reduce(kronecker, rev(fit.tsir))
|
||||||
|
B.sir <- fit.sir$projection
|
||||||
|
|
||||||
# Compute estimation to true minimal `B` distance
|
# Compute estimation to true minimal `B` distance
|
||||||
dist.subspace.gmlm <- dist.subspace(B.min.true, B.gmlm, normalize = TRUE)
|
dist.subspace.gmlm <- dist.subspace(B.min.true, B.gmlm, normalize = TRUE)
|
||||||
dist.subspace.tsir <- dist.subspace(B.min.true, B.tsir, normalize = TRUE)
|
dist.subspace.gmlm.y <- dist.subspace(B.min.true, B.gmlm.y, normalize = TRUE)
|
||||||
dist.subspace.sir <- dist.subspace(B.min.true, B.sir, normalize = TRUE)
|
dist.subspace.tsir <- dist.subspace(B.min.true, B.tsir, normalize = TRUE)
|
||||||
|
dist.subspace.sir <- dist.subspace(B.min.true, B.sir, normalize = TRUE)
|
||||||
|
|
||||||
# Write to simulation log file (CSV file)
|
# Write to simulation log file (CSV file)
|
||||||
logger()
|
logger()
|
||||||
|
@ -104,22 +121,25 @@ for (order in orders) {
|
||||||
# Version 2: repeated simulations (identical to Version 1)
|
# Version 2: repeated simulations (identical to Version 1)
|
||||||
for (rep in seq_len(reps)) {
|
for (rep in seq_len(reps)) {
|
||||||
# Sample training data
|
# Sample training data
|
||||||
c(X, F, y, sample.axis) %<-% sample.data(sample.size, betas, Omegas)
|
c(X, F, Y, sample.axis) %<-% sample.data(sample.size, betas, Omegas)
|
||||||
|
|
||||||
# Fit models to provided data
|
# Fit models to provided data
|
||||||
fit.gmlm <- gmlm_tensor_normal(X, F, sample.axis = sample.axis, proj.betas = proj.betas)
|
fit.gmlm <- gmlm_tensor_normal(X, F, sample.axis = sample.axis, proj.betas = proj.betas)
|
||||||
fit.tsir <- TSIR(X, y, d = rep(1L, order), sample.axis = sample.axis)
|
fit.gmlm.y <- gmlm_tensor_normal(X, Y, sample.axis = sample.axis)
|
||||||
fit.sir <- SIR(mat(X, sample.axis), y, d = 1L)
|
fit.tsir <- TSIR(X, drop(Y), d = rep(1L, order), sample.axis = sample.axis)
|
||||||
|
fit.sir <- do.sir(mat(X, sample.axis), drop(Y), ndim = 1L)
|
||||||
|
|
||||||
# Extract minimal reduction matrices from fitted models
|
# Extract minimal reduction matrices from fitted models
|
||||||
B.gmlm <- qr.Q(qr(Reduce(kronecker, rev(fit.gmlm$betas))))[, 1L, drop = FALSE]
|
B.gmlm <- qr.Q(qr(Reduce(kronecker, rev(fit.gmlm$betas))))[, 1L, drop = FALSE]
|
||||||
B.tsir <- Reduce(kronecker, rev(fit.tsir))
|
B.gmlm.y <- Reduce(kronecker, rev(fit.gmlm.y$betas))
|
||||||
B.sir <- fit.sir
|
B.tsir <- Reduce(kronecker, rev(fit.tsir))
|
||||||
|
B.sir <- fit.sir$projection
|
||||||
|
|
||||||
# Compute estimation to true minimal `B` distance
|
# Compute estimation to true minimal `B` distance
|
||||||
dist.subspace.gmlm <- dist.subspace(B.min.true, B.gmlm, normalize = TRUE)
|
dist.subspace.gmlm <- dist.subspace(B.min.true, B.gmlm, normalize = TRUE)
|
||||||
dist.subspace.tsir <- dist.subspace(B.min.true, B.tsir, normalize = TRUE)
|
dist.subspace.gmlm.y <- dist.subspace(B.min.true, B.gmlm.y, normalize = TRUE)
|
||||||
dist.subspace.sir <- dist.subspace(B.min.true, B.sir, normalize = TRUE)
|
dist.subspace.tsir <- dist.subspace(B.min.true, B.tsir, normalize = TRUE)
|
||||||
|
dist.subspace.sir <- dist.subspace(B.min.true, B.sir, normalize = TRUE)
|
||||||
|
|
||||||
# Write to simulation log file (CSV file)
|
# Write to simulation log file (CSV file)
|
||||||
logger()
|
logger()
|
||||||
|
@ -155,6 +175,13 @@ layout(rbind(
|
||||||
2 * length(orders) + 1
|
2 * length(orders) + 1
|
||||||
), heights = c(rep(6L, length(orders)), 1L))
|
), heights = c(rep(6L, length(orders)), 1L))
|
||||||
|
|
||||||
|
col.methods <- c(
|
||||||
|
gmlm = "#000000",
|
||||||
|
gmlm.y = "#FF0000",
|
||||||
|
tsir = "#009E73",
|
||||||
|
sir = "#999999"
|
||||||
|
)
|
||||||
|
|
||||||
for (group in split(aggr, aggr[c("order", "beta.version")])) {
|
for (group in split(aggr, aggr[c("order", "beta.version")])) {
|
||||||
order <- group$order[[1]]
|
order <- group$order[[1]]
|
||||||
beta.version <- group$beta.version[[1]]
|
beta.version <- group$beta.version[[1]]
|
||||||
|
@ -166,9 +193,10 @@ for (group in split(aggr, aggr[c("order", "beta.version")])) {
|
||||||
axis(1, at = rho)
|
axis(1, at = rho)
|
||||||
axis(2, at = seq(0, 1, by = 0.2))
|
axis(2, at = seq(0, 1, by = 0.2))
|
||||||
with(group, {
|
with(group, {
|
||||||
lines(rho, dist.subspace.gmlm, col = col.methods["gmlm"], lwd = 3, type = "b", pch = 16)
|
lines(rho, dist.subspace.gmlm, col = col.methods["gmlm"], lwd = 3, type = "b", pch = 16)
|
||||||
lines(rho, dist.subspace.tsir, col = col.methods["tsir"], lwd = 2, type = "b", pch = 16)
|
lines(rho, dist.subspace.gmlm.y, col = col.methods["gmlm.y"], lwd = 3, type = "b", pch = 16)
|
||||||
lines(rho, dist.subspace.sir, col = col.methods["sir"], lwd = 2, type = "b", pch = 16)
|
lines(rho, dist.subspace.tsir, col = col.methods["tsir"], lwd = 2, type = "b", pch = 16)
|
||||||
|
lines(rho, dist.subspace.sir, col = col.methods["sir"], lwd = 2, type = "b", pch = 16)
|
||||||
})
|
})
|
||||||
if (order == 3L && beta.version == 2L) {
|
if (order == 3L && beta.version == 2L) {
|
||||||
abline(v = 0.5, lty = "dotted", col = "black")
|
abline(v = 0.5, lty = "dotted", col = "black")
|
||||||
|
@ -176,49 +204,7 @@ for (group in split(aggr, aggr[c("order", "beta.version")])) {
|
||||||
lty = "dotted", col = "black")
|
lty = "dotted", col = "black")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
methods <- c("GMLM", "TSIR", "SIR")
|
methods <- c("GMLM", "GMLM.y", "TSIR", "SIR")
|
||||||
restor.par <- par(
|
|
||||||
fig = c(0, 1, 0, 1),
|
|
||||||
oma = c(0, 0, 0, 0),
|
|
||||||
mar = c(1, 0, 0, 0),
|
|
||||||
new = TRUE
|
|
||||||
)
|
|
||||||
plot(0, type = "n", bty = "n", axes = FALSE, xlab = "", ylab = "")
|
|
||||||
legend("bottom", col = col.methods[tolower(methods)], legend = methods,
|
|
||||||
horiz = TRUE, lty = 1, bty = "n", lwd = c(3, 2, 2), pch = 16)
|
|
||||||
par(restor.par)
|
|
||||||
|
|
||||||
|
|
||||||
# new grouping for the aggregates
|
|
||||||
layout(rbind(
|
|
||||||
matrix(seq_len(2 * 3), ncol = 2),
|
|
||||||
2 * 3 + 1
|
|
||||||
), heights = c(rep(6L, 3), 1L))
|
|
||||||
|
|
||||||
for (group in split(aggr, aggr[c("rho", "beta.version")])) {
|
|
||||||
rho <- group$rho[[1]]
|
|
||||||
beta.version <- group$beta.version[[1]]
|
|
||||||
|
|
||||||
if (!(rho %in% c(0, .5, .8))) { next }
|
|
||||||
|
|
||||||
order <- group$order
|
|
||||||
|
|
||||||
plot(range(order), 0:1, main = sprintf("V%d, rho %.1f", beta.version, rho),
|
|
||||||
type = "n", bty = "n", axes = FALSE, xlab = expression(order), ylab = "Subspace Distance")
|
|
||||||
axis(1, at = order)
|
|
||||||
axis(2, at = seq(0, 1, by = 0.2))
|
|
||||||
with(group, {
|
|
||||||
lines(order, dist.subspace.gmlm, col = col.methods["gmlm"], lwd = 3, type = "b", pch = 16)
|
|
||||||
lines(order, dist.subspace.tsir, col = col.methods["tsir"], lwd = 2, type = "b", pch = 16)
|
|
||||||
lines(order, dist.subspace.sir, col = col.methods["sir"], lwd = 2, type = "b", pch = 16)
|
|
||||||
})
|
|
||||||
if (rho == 0.5 && beta.version == 2L) {
|
|
||||||
abline(v = 0.5, lty = "dotted", col = "black")
|
|
||||||
abline(h = group$dist.subspace.tsir[which(order == 3L)],
|
|
||||||
lty = "dotted", col = "black")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
methods <- c("GMLM", "TSIR", "SIR")
|
|
||||||
restor.par <- par(
|
restor.par <- par(
|
||||||
fig = c(0, 1, 0, 1),
|
fig = c(0, 1, 0, 1),
|
||||||
oma = c(0, 0, 0, 0),
|
oma = c(0, 0, 0, 0),
|
||||||
|
|
|
@ -0,0 +1,197 @@
|
||||||
|
library(tensorPredictors)
|
||||||
|
library(logisticPCA)
|
||||||
|
# library(RGCCA)
|
||||||
|
# Use modified version of `RGCCA`
|
||||||
|
# Reasons (on Ubuntu 22.04 LTS):
|
||||||
|
# - compatible with `Rscript`
|
||||||
|
# - about 4 times faster for small problems
|
||||||
|
# Changes:
|
||||||
|
# - Run in main thread, avoid `parallel::makeCluster` if `n_cores == 1`
|
||||||
|
# (file "R/mgccak.R" lines 81:103)
|
||||||
|
# - added `Encoding: UTF-8`
|
||||||
|
# (file "DESCRIPTION")
|
||||||
|
suppressWarnings({
|
||||||
|
devtools::load_all("~/Work/tensorPredictors/References/Software/TGCCA-modified", export_all = FALSE)
|
||||||
|
})
|
||||||
|
|
||||||
|
setwd("~/Work/tensorPredictors/sim/")
|
||||||
|
base.name <- format(Sys.time(), "sim_2b_ising-%Y%m%dT%H%M")
|
||||||
|
|
||||||
|
# Source utility function used in most simulations (extracted for convenience)
|
||||||
|
source("./sim_utils.R")
|
||||||
|
|
||||||
|
# Set PRNG seed for reproducability
|
||||||
|
# Note: `0x` is the HEX number prefix and the trailing `L` stands for "long"
|
||||||
|
# which is `R`s way if indicating an integer.
|
||||||
|
set.seed(seed <- 0x2bL, "Mersenne-Twister", "Inversion", "Rejection")
|
||||||
|
|
||||||
|
|
||||||
|
reps <- 100 # number of simulation replications
|
||||||
|
sample.sizes <- c(100, 200, 300, 500, 750) # sample sizes `n`
|
||||||
|
dimX <- c(2, 3) # predictor `X` dimension
|
||||||
|
dimF <- c(2, 2) # "function" `F(y)` of responce `y` dimension
|
||||||
|
|
||||||
|
betas <- Map(diag, 1, dimX, dimF)
|
||||||
|
Omegas <- list(toeplitz(c(0, -2)), toeplitz(seq(1, 0, by = -0.5)))
|
||||||
|
|
||||||
|
# compute true (full) model parameters to compair estimates against
|
||||||
|
B.true <- Reduce(`%x%`, rev(betas))
|
||||||
|
|
||||||
|
# data sampling routine
|
||||||
|
sample.data <- function(sample.size, betas, Omegas) {
|
||||||
|
dimX <- mapply(nrow, betas)
|
||||||
|
dimF <- mapply(ncol, betas)
|
||||||
|
|
||||||
|
# generate response (sample axis is last axis)
|
||||||
|
y <- runif(sample.size, -1, 1)
|
||||||
|
F <- aperm(array(c(
|
||||||
|
sinpi(y), -cospi(y),
|
||||||
|
cospi(y), sinpi(y)
|
||||||
|
), dim = c(length(y), 2, 2)), c(2, 3, 1))
|
||||||
|
|
||||||
|
Omega <- Reduce(kronecker, rev(Omegas))
|
||||||
|
|
||||||
|
X <- apply(F, 3, function(Fi) {
|
||||||
|
dim(Fi) <- dimF
|
||||||
|
params <- diag(as.vector(mlm(Fi, betas))) + Omega
|
||||||
|
tensorPredictors::ising_sample(1, params)
|
||||||
|
})
|
||||||
|
dim(X) <- c(dimX, sample.size)
|
||||||
|
|
||||||
|
list(X = X, F = F, y = y, sample.axis = 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
# # has been run once with initial seed
|
||||||
|
# lpca.hyper.param <- local({
|
||||||
|
# c(X, F, y, sample.axis) %<-% sample.data(1e3, betas, Omegas)
|
||||||
|
# vecX <- mat(X, sample.axis)
|
||||||
|
# CV <- cv.lpca(vecX, ks = prod(dimF), ms = seq(1, 30, by = 1))
|
||||||
|
# # plot(CV)
|
||||||
|
# as.numeric(colnames(CV))[which.min(CV)]
|
||||||
|
# })
|
||||||
|
# set.seed(seed, "Mersenne-Twister", "Inversion", "Rejection")
|
||||||
|
lpca.hyper.param <- 23
|
||||||
|
|
||||||
|
|
||||||
|
# Create a CSV logger to write simulation results to
|
||||||
|
log.file <- paste(base.name, "csv", sep = ".")
|
||||||
|
logger <- CSV.logger(
|
||||||
|
file.name = log.file,
|
||||||
|
header = c("sample.size", "rep", outer(
|
||||||
|
c("time", "dist.subspace", "dist.projection"), # < measures, v methods
|
||||||
|
c("gmlm", "tnormal", "pca", "hopca", "lpca", "clpca", "tsir", "mgcca"),
|
||||||
|
paste, sep = "."
|
||||||
|
))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
### for each sample size
|
||||||
|
for (sample.size in sample.sizes) {
|
||||||
|
# repeate every simulation
|
||||||
|
for (rep in seq_len(reps)) {
|
||||||
|
# Sample training data
|
||||||
|
c(X, F, y, sample.axis) %<-% sample.data(sample.size, betas, Omegas)
|
||||||
|
|
||||||
|
# start timing for reporting
|
||||||
|
start.timer()
|
||||||
|
|
||||||
|
# fit different models
|
||||||
|
# Wrapped in try-catch clock to ensure the simulation continues,
|
||||||
|
# if an error occures continue with nest resplication and log an error message
|
||||||
|
try.catch.block <- tryCatch({
|
||||||
|
time.gmlm <- system.time(
|
||||||
|
fit.gmlm <- gmlm_ising(X, F, y, sample.axis = sample.axis)
|
||||||
|
)["user.self"]
|
||||||
|
time.tnormal <- -1 # part of Ising gmlm (not relevent here)
|
||||||
|
time.pca <- system.time(
|
||||||
|
fit.pca <- prcomp(mat(X, sample.axis), rank. = prod(dimF))
|
||||||
|
)["user.self"]
|
||||||
|
time.hopca <- system.time(
|
||||||
|
fit.hopca <- HOPCA(X, npc = dimF, sample.axis = sample.axis)
|
||||||
|
)["user.self"]
|
||||||
|
time.lpca <- system.time(
|
||||||
|
fit.lpca <- logisticPCA(mat(X, sample.axis), k = prod(dimF),
|
||||||
|
m = lpca.hyper.param)
|
||||||
|
)["user.self"]
|
||||||
|
time.clpca <- system.time(
|
||||||
|
fit.clpca <- convexLogisticPCA(mat(X, sample.axis), k = prod(dimF),
|
||||||
|
m = lpca.hyper.param)
|
||||||
|
)["user.self"]
|
||||||
|
time.tsir <- system.time(
|
||||||
|
fit.tsir <- TSIR(X, y, dimF, sample.axis = sample.axis)
|
||||||
|
)["user.self"]
|
||||||
|
# `mgcca` expects the first axis to be the sample axis
|
||||||
|
X1 <- aperm(X, c(sample.axis, seq_along(dim(X))[-sample.axis]))
|
||||||
|
F1 <- cbind(sinpi(y), cospi(y))
|
||||||
|
time.mgcca <- system.time(
|
||||||
|
fit.mgcca <- mgcca(list(X1, F1), ncomp = c(prod(dimF), 1),
|
||||||
|
quiet = TRUE, scheme = "factorial")
|
||||||
|
)["user.self"]
|
||||||
|
}, error = print)
|
||||||
|
|
||||||
|
# Get elapsed time from last timer start and the accumulated total time
|
||||||
|
# (_not_ a precide timing, only to get an idea)
|
||||||
|
c(elapsed, total.time) %<-% end.timer()
|
||||||
|
|
||||||
|
# Drop comparison in case any error (in any fitting routine)
|
||||||
|
if (inherits(try.catch.block, "error")) { next }
|
||||||
|
|
||||||
|
# Compute true reduction matrix
|
||||||
|
B.gmlm <- with(fit.gmlm, Reduce(`%x%`, rev(betas)))
|
||||||
|
B.tnormal <- with(attr(fit.gmlm, "tensor_normal"), Reduce(`%x%`, rev(betas)))
|
||||||
|
B.pca <- fit.pca$rotation
|
||||||
|
B.hopca <- Reduce(`%x%`, rev(fit.hopca))
|
||||||
|
B.lpca <- fit.lpca$U
|
||||||
|
B.clpca <- fit.clpca$U
|
||||||
|
B.tsir <- Reduce(`%x%`, rev(fit.tsir))
|
||||||
|
B.mgcca <- fit.mgcca$astar[[1]]
|
||||||
|
|
||||||
|
# Subspace Distances: Normalized `|| P_A - P_B ||_F` where
|
||||||
|
# `P_A = A (A' A)^-1 A'` and the normalization means that with
|
||||||
|
# respect to the dimensions of `A, B` the subspace distance is in the
|
||||||
|
# range `[0, 1]`.
|
||||||
|
dist.subspace.gmlm <- dist.subspace(B.true, B.gmlm, normalize = TRUE)
|
||||||
|
dist.subspace.tnormal <- dist.subspace(B.true, B.tnormal, normalize = TRUE)
|
||||||
|
dist.subspace.pca <- dist.subspace(B.true, B.pca, normalize = TRUE)
|
||||||
|
dist.subspace.hopca <- dist.subspace(B.true, B.hopca, normalize = TRUE)
|
||||||
|
dist.subspace.lpca <- dist.subspace(B.true, B.lpca, normalize = TRUE)
|
||||||
|
dist.subspace.clpca <- dist.subspace(B.true, B.clpca, normalize = TRUE)
|
||||||
|
dist.subspace.tsir <- dist.subspace(B.true, B.tsir, normalize = TRUE)
|
||||||
|
dist.subspace.mgcca <- dist.subspace(B.true, B.mgcca, normalize = TRUE)
|
||||||
|
|
||||||
|
# # Projection Distances: Spectral norm (2-norm) `|| P_A - P_B ||_2`.
|
||||||
|
dist.projection.gmlm <- dist.projection(B.true, B.gmlm)
|
||||||
|
dist.projection.tnormal <- dist.projection(B.true, B.tnormal)
|
||||||
|
dist.projection.pca <- dist.projection(B.true, B.pca)
|
||||||
|
dist.projection.hopca <- dist.projection(B.true, B.hopca)
|
||||||
|
dist.projection.lpca <- dist.projection(B.true, B.lpca)
|
||||||
|
dist.projection.clpca <- dist.projection(B.true, B.clpca)
|
||||||
|
dist.projection.tsir <- dist.projection(B.true, B.tsir)
|
||||||
|
dist.projection.mgcca <- dist.projection(B.true, B.mgcca)
|
||||||
|
|
||||||
|
# Call CSV logger writing results to file
|
||||||
|
logger()
|
||||||
|
|
||||||
|
# print progress
|
||||||
|
cat(sprintf("sample size (%d): %d/%d - rep: %d/%d - elapsed: %.1f [s], total: %.0f [s]\n",
|
||||||
|
sample.size, which(sample.size == sample.sizes),
|
||||||
|
length(sample.sizes), rep, reps, elapsed, total.time))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
### read simulation results generate plots
|
||||||
|
if (!interactive()) { pdf(file = paste(base.name, "pdf", sep = ".")) }
|
||||||
|
|
||||||
|
sim <- read.csv(log.file)
|
||||||
|
|
||||||
|
|
||||||
|
plot.sim(sim, "dist.subspace", main = "Subspace Distance",
|
||||||
|
xlab = "Sample Size", ylab = "Distance")
|
||||||
|
|
||||||
|
plot.sim(sim, "dist.projection", main = "Projection Distance",
|
||||||
|
xlab = "Sample Size", ylab = "Distance")
|
||||||
|
|
||||||
|
plot.sim(sim, "time", main = "Runtime",
|
||||||
|
xlab = "Sample Size", ylab = "Time [s]",
|
||||||
|
ylim = c(0, max(sim[startsWith(names(sim), "time")])))
|
|
@ -0,0 +1,195 @@
|
||||||
|
library(tensorPredictors)
|
||||||
|
library(logisticPCA)
|
||||||
|
# library(RGCCA)
|
||||||
|
# Use modified version of `RGCCA`
|
||||||
|
# Reasons (on Ubuntu 22.04 LTS):
|
||||||
|
# - compatible with `Rscript`
|
||||||
|
# - about 4 times faster for small problems
|
||||||
|
# Changes:
|
||||||
|
# - Run in main thread, avoid `parallel::makeCluster` if `n_cores == 1`
|
||||||
|
# (file "R/mgccak.R" lines 81:103)
|
||||||
|
# - added `Encoding: UTF-8`
|
||||||
|
# (file "DESCRIPTION")
|
||||||
|
suppressWarnings({
|
||||||
|
devtools::load_all("~/Work/tensorPredictors/References/Software/TGCCA-modified", export_all = FALSE)
|
||||||
|
})
|
||||||
|
|
||||||
|
setwd("~/Work/tensorPredictors/sim/")
|
||||||
|
base.name <- format(Sys.time(), "sim_2c_ising-%Y%m%dT%H%M")
|
||||||
|
|
||||||
|
# Source utility function used in most simulations (extracted for convenience)
|
||||||
|
source("./sim_utils.R")
|
||||||
|
|
||||||
|
# Set PRNG seed for reproducability
|
||||||
|
# Note: `0x` is the HEX number prefix and the trailing `L` stands for "long"
|
||||||
|
# which is `R`s way if indicating an integer.
|
||||||
|
set.seed(seed <- 0x2cL, "Mersenne-Twister", "Inversion", "Rejection")
|
||||||
|
|
||||||
|
|
||||||
|
reps <- 100 # number of simulation replications
|
||||||
|
sample.sizes <- c(100, 200, 300, 500, 750) # sample sizes `n`
|
||||||
|
dimX <- c(2, 3) # predictor `X` dimension
|
||||||
|
dimF <- c(2, 2) # "function" `F(y)` of responce `y` dimension
|
||||||
|
|
||||||
|
betas <- list(
|
||||||
|
`[<-`(matrix(0, dimX[1], dimF[1]), 1, , c(1, 1)),
|
||||||
|
`[<-`(matrix(0, dimX[2], dimF[2]), 2, , c(1, -1))
|
||||||
|
)
|
||||||
|
Omegas <- list(toeplitz(c(0, -2)), toeplitz(seq(1, 0, by = -0.5)))
|
||||||
|
|
||||||
|
# compute true (full) model parameters to compair estimates against
|
||||||
|
B.true <- as.matrix(as.numeric((1:6) == 3))
|
||||||
|
|
||||||
|
# define projections onto rank 1 matrices for betas
|
||||||
|
proj.betas <- Map(.projRank, rep(1L, length(betas)))
|
||||||
|
|
||||||
|
|
||||||
|
# data sampling routine
|
||||||
|
sample.data <- function(sample.size, betas, Omegas) {
|
||||||
|
dimX <- mapply(nrow, betas)
|
||||||
|
dimF <- mapply(ncol, betas)
|
||||||
|
|
||||||
|
# generate response (sample axis is last axis)
|
||||||
|
y <- runif(sample.size, -1, 1)
|
||||||
|
F <- aperm(array(c(
|
||||||
|
sinpi(y), -cospi(y),
|
||||||
|
cospi(y), sinpi(y)
|
||||||
|
), dim = c(length(y), 2, 2)), c(2, 3, 1))
|
||||||
|
|
||||||
|
Omega <- Reduce(kronecker, rev(Omegas))
|
||||||
|
|
||||||
|
X <- apply(F, 3, function(Fi) {
|
||||||
|
dim(Fi) <- dimF
|
||||||
|
params <- diag(as.vector(mlm(Fi, betas))) + Omega
|
||||||
|
tensorPredictors::ising_sample(1, params)
|
||||||
|
})
|
||||||
|
dim(X) <- c(dimX, sample.size)
|
||||||
|
|
||||||
|
list(X = X, F = F, y = y, sample.axis = 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
# # has been run once with initial seed
|
||||||
|
# lpca.hyper.param <- local({
|
||||||
|
# c(X, F, y, sample.axis) %<-% sample.data(1e3, betas, Omegas)
|
||||||
|
# vecX <- mat(X, sample.axis)
|
||||||
|
# CV <- cv.lpca(vecX, ks = prod(dimF), ms = seq(1, 30, by = 1))
|
||||||
|
# # plot(CV)
|
||||||
|
# as.numeric(colnames(CV))[which.min(CV)]
|
||||||
|
# })
|
||||||
|
# set.seed(seed, "Mersenne-Twister", "Inversion", "Rejection")
|
||||||
|
lpca.hyper.param <- 26
|
||||||
|
|
||||||
|
|
||||||
|
# Create a CSV logger to write simulation results to
|
||||||
|
log.file <- paste(base.name, "csv", sep = ".")
|
||||||
|
logger <- CSV.logger(
|
||||||
|
file.name = log.file,
|
||||||
|
header = c("sample.size", "rep", outer(
|
||||||
|
c("time", "dist.subspace"), # < measures, v methods
|
||||||
|
c("gmlm", "tnormal", "pca", "hopca", "lpca", "clpca", "tsir", "mgcca"),
|
||||||
|
paste, sep = "."
|
||||||
|
))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
### for each sample size
|
||||||
|
for (sample.size in sample.sizes) {
|
||||||
|
# repeate every simulation
|
||||||
|
for (rep in seq_len(reps)) {
|
||||||
|
# Sample training data
|
||||||
|
c(X, F, y, sample.axis) %<-% sample.data(sample.size, betas, Omegas)
|
||||||
|
|
||||||
|
# start timing for reporting
|
||||||
|
start.timer()
|
||||||
|
|
||||||
|
# fit different models
|
||||||
|
# Wrapped in try-catch clock to ensure the simulation continues,
|
||||||
|
# if an error occures continue with nest resplication and log an error message
|
||||||
|
try.catch.block <- tryCatch({
|
||||||
|
time.gmlm <- system.time(
|
||||||
|
fit.gmlm <- gmlm_ising(X, F, y, sample.axis = sample.axis,
|
||||||
|
proj.betas = proj.betas)
|
||||||
|
)["user.self"]
|
||||||
|
time.tnormal <- -1 # part of Ising gmlm (not relevent here)
|
||||||
|
time.pca <- system.time(
|
||||||
|
fit.pca <- prcomp(mat(X, sample.axis), rank. = 1L)
|
||||||
|
)["user.self"]
|
||||||
|
time.hopca <- system.time(
|
||||||
|
fit.hopca <- HOPCA(X, npc = c(1L, 1L), sample.axis = sample.axis)
|
||||||
|
)["user.self"]
|
||||||
|
time.lpca <- system.time(
|
||||||
|
fit.lpca <- logisticPCA(mat(X, sample.axis), k = 1L,
|
||||||
|
m = lpca.hyper.param)
|
||||||
|
)["user.self"]
|
||||||
|
time.clpca <- system.time(
|
||||||
|
fit.clpca <- convexLogisticPCA(mat(X, sample.axis), k = 1L,
|
||||||
|
m = lpca.hyper.param)
|
||||||
|
)["user.self"]
|
||||||
|
time.tsir <- system.time(
|
||||||
|
fit.tsir <- TSIR(X, y, d = c(1L, 1L), sample.axis = sample.axis)
|
||||||
|
)["user.self"]
|
||||||
|
# `mgcca` expects the first axis to be the sample axis
|
||||||
|
X1 <- aperm(X, c(sample.axis, seq_along(dim(X))[-sample.axis]))
|
||||||
|
F1 <- cbind(sinpi(y), cospi(y))
|
||||||
|
time.mgcca <- system.time(
|
||||||
|
fit.mgcca <- mgcca(list(X1, F1), ncomp = c(1L, 1L),
|
||||||
|
quiet = TRUE, scheme = "factorial")
|
||||||
|
)["user.self"]
|
||||||
|
}, error = print)
|
||||||
|
|
||||||
|
# Get elapsed time from last timer start and the accumulated total time
|
||||||
|
# (_not_ a precide timing, only to get an idea)
|
||||||
|
c(elapsed, total.time) %<-% end.timer()
|
||||||
|
|
||||||
|
# Drop comparison in case any error (in any fitting routine)
|
||||||
|
if (inherits(try.catch.block, "error")) { next }
|
||||||
|
|
||||||
|
# Compute true reduction matrix
|
||||||
|
B.gmlm <- qr.Q(qr(with(fit.gmlm, Reduce(`%x%`, rev(betas)))))[, 1L, drop = FALSE]
|
||||||
|
B.tnormal <- qr.Q(qr(with(attr(fit.gmlm, "tensor_normal"), Reduce(`%x%`, rev(betas)))))[, 1L, drop = FALSE]
|
||||||
|
B.pca <- fit.pca$rotation
|
||||||
|
B.hopca <- Reduce(`%x%`, rev(fit.hopca))
|
||||||
|
B.lpca <- fit.lpca$U
|
||||||
|
B.clpca <- fit.clpca$U
|
||||||
|
B.tsir <- Reduce(`%x%`, rev(fit.tsir))
|
||||||
|
B.mgcca <- fit.mgcca$astar[[1]]
|
||||||
|
|
||||||
|
# Subspace Distances: Normalized `|| P_A - P_B ||_F` where
|
||||||
|
# `P_A = A (A' A)^-1 A'` and the normalization means that with
|
||||||
|
# respect to the dimensions of `A, B` the subspace distance is in the
|
||||||
|
# range `[0, 1]`.
|
||||||
|
dist.subspace.gmlm <- dist.subspace(B.true, B.gmlm, normalize = TRUE)
|
||||||
|
dist.subspace.tnormal <- dist.subspace(B.true, B.tnormal, normalize = TRUE)
|
||||||
|
dist.subspace.pca <- dist.subspace(B.true, B.pca, normalize = TRUE)
|
||||||
|
dist.subspace.hopca <- dist.subspace(B.true, B.hopca, normalize = TRUE)
|
||||||
|
dist.subspace.lpca <- dist.subspace(B.true, B.lpca, normalize = TRUE)
|
||||||
|
dist.subspace.clpca <- dist.subspace(B.true, B.clpca, normalize = TRUE)
|
||||||
|
dist.subspace.tsir <- dist.subspace(B.true, B.tsir, normalize = TRUE)
|
||||||
|
dist.subspace.mgcca <- dist.subspace(B.true, B.mgcca, normalize = TRUE)
|
||||||
|
|
||||||
|
# Call CSV logger writing results to file
|
||||||
|
logger()
|
||||||
|
|
||||||
|
# print progress
|
||||||
|
cat(sprintf("sample size (%d): %d/%d - rep: %d/%d - elapsed: %.1f [s], total: %.0f [s]\n",
|
||||||
|
sample.size, which(sample.size == sample.sizes),
|
||||||
|
length(sample.sizes), rep, reps, elapsed, total.time))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
### read simulation results generate plots
|
||||||
|
if (!interactive()) { pdf(file = paste(base.name, "pdf", sep = ".")) }
|
||||||
|
|
||||||
|
sim <- read.csv(log.file)
|
||||||
|
|
||||||
|
|
||||||
|
plot.sim(sim, "dist.subspace", main = "Subspace Distance",
|
||||||
|
xlab = "Sample Size", ylab = "Distance")
|
||||||
|
|
||||||
|
# plot.sim(sim, "dist.projection", main = "Projection Distance",
|
||||||
|
# xlab = "Sample Size", ylab = "Distance")
|
||||||
|
|
||||||
|
plot.sim(sim, "time", main = "Runtime",
|
||||||
|
xlab = "Sample Size", ylab = "Time [s]",
|
||||||
|
ylim = c(0, max(sim[startsWith(names(sim), "time")])))
|
|
@ -0,0 +1,213 @@
|
||||||
|
library(tensorPredictors)
|
||||||
|
library(logisticPCA)
|
||||||
|
# library(RGCCA)
|
||||||
|
# Use modified version of `RGCCA`
|
||||||
|
# Reasons (on Ubuntu 22.04 LTS):
|
||||||
|
# - compatible with `Rscript`
|
||||||
|
# - about 4 times faster for small problems
|
||||||
|
# Changes:
|
||||||
|
# - Run in main thread, avoid `parallel::makeCluster` if `n_cores == 1`
|
||||||
|
# (file "R/mgccak.R" lines 81:103)
|
||||||
|
# - added `Encoding: UTF-8`
|
||||||
|
# (file "DESCRIPTION")
|
||||||
|
suppressWarnings({
|
||||||
|
devtools::load_all("~/Work/tensorPredictors/References/Software/TGCCA-modified", export_all = FALSE)
|
||||||
|
})
|
||||||
|
|
||||||
|
setwd("~/Work/tensorPredictors/sim/")
|
||||||
|
base.name <- format(Sys.time(), "sim_2d_ising-%Y%m%dT%H%M")
|
||||||
|
|
||||||
|
# Source utility function used in most simulations (extracted for convenience)
|
||||||
|
source("./sim_utils.R")
|
||||||
|
|
||||||
|
# Set PRNG seed for reproducability
|
||||||
|
# Note: `0x` is the HEX number prefix and the trailing `L` stands for "long"
|
||||||
|
# which is `R`s way if indicating an integer.
|
||||||
|
set.seed(seed <- 0x2dL, "Mersenne-Twister", "Inversion", "Rejection")
|
||||||
|
|
||||||
|
|
||||||
|
reps <- 100 # number of simulation replications
|
||||||
|
sample.sizes <- c(100, 200, 300, 500, 750) # sample sizes `n`
|
||||||
|
dimX <- c(2, 3) # predictor `X` dimension
|
||||||
|
dimF <- c(2, 2) # "function" `F(y)` of responce `y` dimension
|
||||||
|
|
||||||
|
betas <- Map(diag, 1, dimX, dimF)
|
||||||
|
# # All identical couplings with log odds of 1, that is approx
|
||||||
|
# # `P(X_i = 1, X_j = 1 | X_-i,-j = 0) ~ 3 / 4`
|
||||||
|
# Omegas <- Map(function(dim) `diag<-`(matrix(1, dim, dim), 0), dimX)
|
||||||
|
Omegas <- list(
|
||||||
|
`diag<-`(matrix(0.5, dimX[1], dimX[1]), 0),
|
||||||
|
diag(dimX[2])
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute true (full) model parameters to compair estimates against
|
||||||
|
B.true <- Reduce(`%x%`, rev(betas))
|
||||||
|
|
||||||
|
# Build projections onto `all elements are equal except diagonal is zero` matrices
|
||||||
|
# proj.Omegas <- Map(function(Omega) {
|
||||||
|
# proj <- as.vector(Omega) %*% pinv(as.vector(Omega))
|
||||||
|
# function(Omega) {
|
||||||
|
# matrix(proj %*% as.vector(Omega), nrow = nrow(Omega))
|
||||||
|
# }
|
||||||
|
# }, Omegas)
|
||||||
|
proj.Omegas <- Map(.projMaskedMean, Map(as.logical, Omegas))
|
||||||
|
|
||||||
|
# data sampling routine
|
||||||
|
sample.data <- function(sample.size, betas, Omegas) {
|
||||||
|
dimX <- mapply(nrow, betas)
|
||||||
|
dimF <- mapply(ncol, betas)
|
||||||
|
|
||||||
|
# generate response (sample axis is last axis)
|
||||||
|
y <- runif(sample.size, -1, 1)
|
||||||
|
F <- aperm(array(c(
|
||||||
|
sinpi(y), -cospi(y),
|
||||||
|
cospi(y), sinpi(y)
|
||||||
|
), dim = c(length(y), 2, 2)), c(2, 3, 1))
|
||||||
|
|
||||||
|
Omega <- Reduce(kronecker, rev(Omegas))
|
||||||
|
|
||||||
|
X <- apply(F, 3, function(Fi) {
|
||||||
|
dim(Fi) <- dimF
|
||||||
|
params <- diag(as.vector(mlm(Fi, betas))) + Omega
|
||||||
|
tensorPredictors::ising_sample(1, params)
|
||||||
|
})
|
||||||
|
dim(X) <- c(dimX, sample.size)
|
||||||
|
|
||||||
|
list(X = X, F = F, y = y, sample.axis = 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
# # has been run once with initial seed
|
||||||
|
# lpca.hyper.param <- local({
|
||||||
|
# c(X, F, y, sample.axis) %<-% sample.data(1e3, betas, Omegas)
|
||||||
|
# vecX <- mat(X, sample.axis)
|
||||||
|
# CV <- cv.lpca(vecX, ks = prod(dimF), ms = seq(1, 30, by = 1))
|
||||||
|
# # plot(CV)
|
||||||
|
# as.numeric(colnames(CV))[which.min(CV)]
|
||||||
|
# })
|
||||||
|
# set.seed(seed, "Mersenne-Twister", "Inversion", "Rejection")
|
||||||
|
lpca.hyper.param <- 10
|
||||||
|
|
||||||
|
|
||||||
|
# Create a CSV logger to write simulation results to
|
||||||
|
log.file <- paste(base.name, "csv", sep = ".")
|
||||||
|
logger <- CSV.logger(
|
||||||
|
file.name = log.file,
|
||||||
|
header = c("sample.size", "rep", outer(
|
||||||
|
c("time", "dist.subspace", "dist.projection"), # < measures, v methods
|
||||||
|
c("gmlm", "tnormal", "pca", "hopca", "lpca", "clpca", "tsir", "mgcca"),
|
||||||
|
paste, sep = "."
|
||||||
|
))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
### for each sample size
|
||||||
|
for (sample.size in sample.sizes) {
|
||||||
|
# repeate every simulation
|
||||||
|
for (rep in seq_len(reps)) {
|
||||||
|
# Sample training data
|
||||||
|
c(X, F, y, sample.axis) %<-% sample.data(sample.size, betas, Omegas)
|
||||||
|
|
||||||
|
# start timing for reporting
|
||||||
|
start.timer()
|
||||||
|
|
||||||
|
# fit different models
|
||||||
|
# Wrapped in try-catch clock to ensure the simulation continues,
|
||||||
|
# if an error occures continue with nest resplication and log an error message
|
||||||
|
try.catch.block <- tryCatch({
|
||||||
|
time.gmlm <- system.time(
|
||||||
|
fit.gmlm <- gmlm_ising(X, F, y, sample.axis = sample.axis,
|
||||||
|
proj.Omegas = proj.Omegas)
|
||||||
|
)["user.self"]
|
||||||
|
time.tnormal <- -1 # part of Ising gmlm (not relevent here)
|
||||||
|
time.pca <- system.time(
|
||||||
|
fit.pca <- prcomp(mat(X, sample.axis), rank. = prod(dimF))
|
||||||
|
)["user.self"]
|
||||||
|
time.hopca <- system.time(
|
||||||
|
fit.hopca <- HOPCA(X, npc = dimF, sample.axis = sample.axis)
|
||||||
|
)["user.self"]
|
||||||
|
time.lpca <- system.time(
|
||||||
|
fit.lpca <- logisticPCA(mat(X, sample.axis), k = prod(dimF),
|
||||||
|
m = lpca.hyper.param)
|
||||||
|
)["user.self"]
|
||||||
|
time.clpca <- system.time(
|
||||||
|
fit.clpca <- convexLogisticPCA(mat(X, sample.axis), k = prod(dimF),
|
||||||
|
m = lpca.hyper.param)
|
||||||
|
)["user.self"]
|
||||||
|
time.tsir <- system.time(
|
||||||
|
fit.tsir <- TSIR(X, y, d = dimF, sample.axis = sample.axis)
|
||||||
|
)["user.self"]
|
||||||
|
# `mgcca` expects the first axis to be the sample axis
|
||||||
|
X1 <- aperm(X, c(sample.axis, seq_along(dim(X))[-sample.axis]))
|
||||||
|
F1 <- cbind(sinpi(y), cospi(y))
|
||||||
|
time.mgcca <- system.time(
|
||||||
|
fit.mgcca <- mgcca(list(X1, F1), ncomp = c(prod(dimF), 1L),
|
||||||
|
quiet = TRUE, scheme = "factorial")
|
||||||
|
)["user.self"]
|
||||||
|
}, error = print)
|
||||||
|
|
||||||
|
# Get elapsed time from last timer start and the accumulated total time
|
||||||
|
# (_not_ a precide timing, only to get an idea)
|
||||||
|
c(elapsed, total.time) %<-% end.timer()
|
||||||
|
|
||||||
|
# Drop comparison in case any error (in any fitting routine)
|
||||||
|
if (inherits(try.catch.block, "error")) { next }
|
||||||
|
|
||||||
|
# Compute true reduction matrix
|
||||||
|
B.gmlm <- with(fit.gmlm, Reduce(`%x%`, rev(betas)))
|
||||||
|
B.tnormal <- with(attr(fit.gmlm, "tensor_normal"), Reduce(`%x%`, rev(betas)))
|
||||||
|
B.pca <- fit.pca$rotation
|
||||||
|
B.hopca <- Reduce(`%x%`, rev(fit.hopca))
|
||||||
|
B.lpca <- fit.lpca$U
|
||||||
|
B.clpca <- fit.clpca$U
|
||||||
|
B.tsir <- Reduce(`%x%`, rev(fit.tsir))
|
||||||
|
B.mgcca <- fit.mgcca$astar[[1]]
|
||||||
|
|
||||||
|
# Subspace Distances: Normalized `|| P_A - P_B ||_F` where
|
||||||
|
# `P_A = A (A' A)^-1 A'` and the normalization means that with
|
||||||
|
# respect to the dimensions of `A, B` the subspace distance is in the
|
||||||
|
# range `[0, 1]`.
|
||||||
|
dist.subspace.gmlm <- dist.subspace(B.true, B.gmlm, normalize = TRUE)
|
||||||
|
dist.subspace.tnormal <- dist.subspace(B.true, B.tnormal, normalize = TRUE)
|
||||||
|
dist.subspace.pca <- dist.subspace(B.true, B.pca, normalize = TRUE)
|
||||||
|
dist.subspace.hopca <- dist.subspace(B.true, B.hopca, normalize = TRUE)
|
||||||
|
dist.subspace.lpca <- dist.subspace(B.true, B.lpca, normalize = TRUE)
|
||||||
|
dist.subspace.clpca <- dist.subspace(B.true, B.clpca, normalize = TRUE)
|
||||||
|
dist.subspace.tsir <- dist.subspace(B.true, B.tsir, normalize = TRUE)
|
||||||
|
dist.subspace.mgcca <- dist.subspace(B.true, B.mgcca, normalize = TRUE)
|
||||||
|
|
||||||
|
# Projection Distances: Spectral norm (2-norm) `|| P_A - P_B ||_2`.
|
||||||
|
dist.projection.gmlm <- dist.projection(B.true, B.gmlm)
|
||||||
|
dist.projection.tnormal <- dist.projection(B.true, B.tnormal)
|
||||||
|
dist.projection.pca <- dist.projection(B.true, B.pca)
|
||||||
|
dist.projection.hopca <- dist.projection(B.true, B.hopca)
|
||||||
|
dist.projection.lpca <- dist.projection(B.true, B.lpca)
|
||||||
|
dist.projection.clpca <- dist.projection(B.true, B.clpca)
|
||||||
|
dist.projection.tsir <- dist.projection(B.true, B.tsir)
|
||||||
|
dist.projection.mgcca <- dist.projection(B.true, B.mgcca)
|
||||||
|
|
||||||
|
# Call CSV logger writing results to file
|
||||||
|
logger()
|
||||||
|
|
||||||
|
# print progress
|
||||||
|
cat(sprintf("sample size (%d): %d/%d - rep: %d/%d - elapsed: %.1f [s], total: %.0f [s]\n",
|
||||||
|
sample.size, which(sample.size == sample.sizes),
|
||||||
|
length(sample.sizes), rep, reps, elapsed, total.time))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
### read simulation results generate plots
|
||||||
|
if (!interactive()) { pdf(file = paste(base.name, "pdf", sep = ".")) }
|
||||||
|
|
||||||
|
sim <- read.csv(log.file)
|
||||||
|
|
||||||
|
|
||||||
|
plot.sim(sim, "dist.subspace", main = "Subspace Distance",
|
||||||
|
xlab = "Sample Size", ylab = "Distance")
|
||||||
|
|
||||||
|
# plot.sim(sim, "dist.projection", main = "Projection Distance",
|
||||||
|
# xlab = "Sample Size", ylab = "Distance")
|
||||||
|
|
||||||
|
plot.sim(sim, "time", main = "Runtime",
|
||||||
|
xlab = "Sample Size", ylab = "Time [s]",
|
||||||
|
ylim = c(0, max(sim[startsWith(names(sim), "time")])))
|
|
@ -0,0 +1,244 @@
|
||||||
|
library(tensorPredictors)
|
||||||
|
# library(logisticPCA)
|
||||||
|
# # library(RGCCA)
|
||||||
|
# # Use modified version of `RGCCA`
|
||||||
|
# # Reasons (on Ubuntu 22.04 LTS):
|
||||||
|
# # - compatible with `Rscript`
|
||||||
|
# # - about 4 times faster for small problems
|
||||||
|
# # Changes:
|
||||||
|
# # - Run in main thread, avoid `parallel::makeCluster` if `n_cores == 1`
|
||||||
|
# # (file "R/mgccak.R" lines 81:103)
|
||||||
|
# # - added `Encoding: UTF-8`
|
||||||
|
# # (file "DESCRIPTION")
|
||||||
|
# suppressWarnings({
|
||||||
|
# devtools::load_all("~/Work/tensorPredictors/References/Software/TGCCA-modified", export_all = FALSE)
|
||||||
|
# })
|
||||||
|
|
||||||
|
# setwd("~/Work/tensorPredictors/sim/")
|
||||||
|
# base.name <- format(Sys.time(), "sim_2e_ising-%Y%m%dT%H%M")
|
||||||
|
|
||||||
|
# # Source utility function used in most simulations (extracted for convenience)
|
||||||
|
# source("./sim_utils.R")
|
||||||
|
|
||||||
|
# # Set PRNG seed for reproducability
|
||||||
|
# # Note: `0x` is the HEX number prefix and the trailing `L` stands for "long"
|
||||||
|
# # which is `R`s way if indicating an integer.
|
||||||
|
# set.seed(seed <- 0x2eL, "Mersenne-Twister", "Inversion", "Rejection")
|
||||||
|
|
||||||
|
|
||||||
|
reps <- 100 # number of simulation replications
|
||||||
|
sample.sizes <- c(100, 200, 300, 500, 750) # sample sizes `n`
|
||||||
|
dimX <- c(5, 5, 5) # predictor `X` dimension
|
||||||
|
dimF <- c(2, 2, 2) # "function" `F(y)` of responce `y` dimension
|
||||||
|
|
||||||
|
betas <- Map(matrix, 1, dimX, dimF)
|
||||||
|
|
||||||
|
Omegas <- Map(function(p) `diag<-`(matrix(0.5, p, p), 0), dimX)
|
||||||
|
|
||||||
|
|
||||||
|
# data sampling routine
|
||||||
|
sample.data <- function(sample.size, betas, Omegas) {
|
||||||
|
dimX <- mapply(nrow, betas)
|
||||||
|
dimF <- mapply(ncol, betas)
|
||||||
|
|
||||||
|
# generate response (sample axis is last axis)
|
||||||
|
y <- runif(sample.size, -1, 1)
|
||||||
|
F <- aperm(array(c(
|
||||||
|
+sinpi(y), +sinpi(2 * y),
|
||||||
|
+cospi(y), +cospi(2 * y),
|
||||||
|
-cospi(y), -cospi(2 * y),
|
||||||
|
+sinpi(y), +sinpi(2 * y)
|
||||||
|
), dim = c(length(y), 2, 2, 2)), c(2, 3, 4, 1))
|
||||||
|
|
||||||
|
Omega <- Reduce(kronecker, rev(Omegas))
|
||||||
|
|
||||||
|
X <- apply(F, length(dim(F)), function(Fi) {
|
||||||
|
dim(Fi) <- dimF
|
||||||
|
params <- diag(as.vector(mlm(Fi, betas))) + Omega
|
||||||
|
tensorPredictors::ising_sample(1, params)
|
||||||
|
})
|
||||||
|
dim(X) <- c(dimX, sample.size)
|
||||||
|
|
||||||
|
list(X = X, F = F, y = y, sample.axis = 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
sample.size <- 100L
|
||||||
|
|
||||||
|
# Sample training data
|
||||||
|
c(X, F, y, sample.axis) %<-% sample.data(sample.size, betas, Omegas)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# compute true (full) model parameters to compair estimates against
|
||||||
|
B.true <- Reduce(`%x%`, rev(betas))
|
||||||
|
|
||||||
|
# Build projections onto `all elements are equal except diagonal is zero` matrices
|
||||||
|
# proj.Omegas <- Map(function(Omega) {
|
||||||
|
# proj <- as.vector(Omega) %*% pinv(as.vector(Omega))
|
||||||
|
# function(Omega) {
|
||||||
|
# matrix(proj %*% as.vector(Omega), nrow = nrow(Omega))
|
||||||
|
# }
|
||||||
|
# }, Omegas)
|
||||||
|
proj.Omegas <- Map(.projMaskedMean, Map(as.logical, Omegas))
|
||||||
|
|
||||||
|
# data sampling routine
|
||||||
|
sample.data <- function(sample.size, betas, Omegas) {
|
||||||
|
dimX <- mapply(nrow, betas)
|
||||||
|
dimF <- mapply(ncol, betas)
|
||||||
|
|
||||||
|
# generate response (sample axis is last axis)
|
||||||
|
y <- runif(sample.size, -1, 1)
|
||||||
|
F <- aperm(array(c(
|
||||||
|
sinpi(y), -cospi(y),
|
||||||
|
cospi(y), sinpi(y)
|
||||||
|
), dim = c(length(y), 2, 2)), c(2, 3, 1))
|
||||||
|
|
||||||
|
Omega <- Reduce(kronecker, rev(Omegas))
|
||||||
|
|
||||||
|
X <- apply(F, 3, function(Fi) {
|
||||||
|
dim(Fi) <- dimF
|
||||||
|
params <- diag(as.vector(mlm(Fi, betas))) + Omega
|
||||||
|
tensorPredictors::ising_sample(1, params)
|
||||||
|
})
|
||||||
|
dim(X) <- c(dimX, sample.size)
|
||||||
|
|
||||||
|
list(X = X, F = F, y = y, sample.axis = 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
# # has been run once with initial seed
|
||||||
|
# lpca.hyper.param <- local({
|
||||||
|
# c(X, F, y, sample.axis) %<-% sample.data(1e3, betas, Omegas)
|
||||||
|
# vecX <- mat(X, sample.axis)
|
||||||
|
# CV <- cv.lpca(vecX, ks = prod(dimF), ms = seq(1, 30, by = 1))
|
||||||
|
# # plot(CV)
|
||||||
|
# as.numeric(colnames(CV))[which.min(CV)]
|
||||||
|
# })
|
||||||
|
# set.seed(seed, "Mersenne-Twister", "Inversion", "Rejection")
|
||||||
|
lpca.hyper.param <- 10
|
||||||
|
|
||||||
|
|
||||||
|
# Create a CSV logger to write simulation results to
|
||||||
|
log.file <- paste(base.name, "csv", sep = ".")
|
||||||
|
logger <- CSV.logger(
|
||||||
|
file.name = log.file,
|
||||||
|
header = c("sample.size", "rep", outer(
|
||||||
|
c("time", "dist.subspace", "dist.projection"), # < measures, v methods
|
||||||
|
c("gmlm", "tnormal", "pca", "hopca", "lpca", "clpca", "tsir", "mgcca"),
|
||||||
|
paste, sep = "."
|
||||||
|
))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
### for each sample size
|
||||||
|
for (sample.size in sample.sizes) {
|
||||||
|
# repeate every simulation
|
||||||
|
for (rep in seq_len(reps)) {
|
||||||
|
# Sample training data
|
||||||
|
c(X, F, y, sample.axis) %<-% sample.data(sample.size, betas, Omegas)
|
||||||
|
|
||||||
|
# start timing for reporting
|
||||||
|
start.timer()
|
||||||
|
|
||||||
|
# fit different models
|
||||||
|
# Wrapped in try-catch clock to ensure the simulation continues,
|
||||||
|
# if an error occures continue with nest resplication and log an error message
|
||||||
|
try.catch.block <- tryCatch({
|
||||||
|
time.gmlm <- system.time(
|
||||||
|
fit.gmlm <- gmlm_ising(X, F, y, sample.axis = sample.axis,
|
||||||
|
proj.Omegas = proj.Omegas)
|
||||||
|
)["user.self"]
|
||||||
|
time.tnormal <- -1 # part of Ising gmlm (not relevent here)
|
||||||
|
time.pca <- system.time(
|
||||||
|
fit.pca <- prcomp(mat(X, sample.axis), rank. = prod(dimF))
|
||||||
|
)["user.self"]
|
||||||
|
time.hopca <- system.time(
|
||||||
|
fit.hopca <- HOPCA(X, npc = dimF, sample.axis = sample.axis)
|
||||||
|
)["user.self"]
|
||||||
|
time.lpca <- system.time(
|
||||||
|
fit.lpca <- logisticPCA(mat(X, sample.axis), k = prod(dimF),
|
||||||
|
m = lpca.hyper.param)
|
||||||
|
)["user.self"]
|
||||||
|
time.clpca <- system.time(
|
||||||
|
fit.clpca <- convexLogisticPCA(mat(X, sample.axis), k = prod(dimF),
|
||||||
|
m = lpca.hyper.param)
|
||||||
|
)["user.self"]
|
||||||
|
time.tsir <- system.time(
|
||||||
|
fit.tsir <- TSIR(X, y, d = dimF, sample.axis = sample.axis)
|
||||||
|
)["user.self"]
|
||||||
|
# `mgcca` expects the first axis to be the sample axis
|
||||||
|
X1 <- aperm(X, c(sample.axis, seq_along(dim(X))[-sample.axis]))
|
||||||
|
F1 <- cbind(sinpi(y), cospi(y))
|
||||||
|
time.mgcca <- system.time(
|
||||||
|
fit.mgcca <- mgcca(list(X1, F1), ncomp = c(prod(dimF), 1L),
|
||||||
|
quiet = TRUE, scheme = "factorial")
|
||||||
|
)["user.self"]
|
||||||
|
}, error = print)
|
||||||
|
|
||||||
|
# Get elapsed time from last timer start and the accumulated total time
|
||||||
|
# (_not_ a precide timing, only to get an idea)
|
||||||
|
c(elapsed, total.time) %<-% end.timer()
|
||||||
|
|
||||||
|
# Drop comparison in case any error (in any fitting routine)
|
||||||
|
if (inherits(try.catch.block, "error")) { next }
|
||||||
|
|
||||||
|
# Compute true reduction matrix
|
||||||
|
B.gmlm <- with(fit.gmlm, Reduce(`%x%`, rev(betas)))
|
||||||
|
B.tnormal <- with(attr(fit.gmlm, "tensor_normal"), Reduce(`%x%`, rev(betas)))
|
||||||
|
B.pca <- fit.pca$rotation
|
||||||
|
B.hopca <- Reduce(`%x%`, rev(fit.hopca))
|
||||||
|
B.lpca <- fit.lpca$U
|
||||||
|
B.clpca <- fit.clpca$U
|
||||||
|
B.tsir <- Reduce(`%x%`, rev(fit.tsir))
|
||||||
|
B.mgcca <- fit.mgcca$astar[[1]]
|
||||||
|
|
||||||
|
# Subspace Distances: Normalized `|| P_A - P_B ||_F` where
|
||||||
|
# `P_A = A (A' A)^-1 A'` and the normalization means that with
|
||||||
|
# respect to the dimensions of `A, B` the subspace distance is in the
|
||||||
|
# range `[0, 1]`.
|
||||||
|
dist.subspace.gmlm <- dist.subspace(B.true, B.gmlm, normalize = TRUE)
|
||||||
|
dist.subspace.tnormal <- dist.subspace(B.true, B.tnormal, normalize = TRUE)
|
||||||
|
dist.subspace.pca <- dist.subspace(B.true, B.pca, normalize = TRUE)
|
||||||
|
dist.subspace.hopca <- dist.subspace(B.true, B.hopca, normalize = TRUE)
|
||||||
|
dist.subspace.lpca <- dist.subspace(B.true, B.lpca, normalize = TRUE)
|
||||||
|
dist.subspace.clpca <- dist.subspace(B.true, B.clpca, normalize = TRUE)
|
||||||
|
dist.subspace.tsir <- dist.subspace(B.true, B.tsir, normalize = TRUE)
|
||||||
|
dist.subspace.mgcca <- dist.subspace(B.true, B.mgcca, normalize = TRUE)
|
||||||
|
|
||||||
|
# Projection Distances: Spectral norm (2-norm) `|| P_A - P_B ||_2`.
|
||||||
|
dist.projection.gmlm <- dist.projection(B.true, B.gmlm)
|
||||||
|
dist.projection.tnormal <- dist.projection(B.true, B.tnormal)
|
||||||
|
dist.projection.pca <- dist.projection(B.true, B.pca)
|
||||||
|
dist.projection.hopca <- dist.projection(B.true, B.hopca)
|
||||||
|
dist.projection.lpca <- dist.projection(B.true, B.lpca)
|
||||||
|
dist.projection.clpca <- dist.projection(B.true, B.clpca)
|
||||||
|
dist.projection.tsir <- dist.projection(B.true, B.tsir)
|
||||||
|
dist.projection.mgcca <- dist.projection(B.true, B.mgcca)
|
||||||
|
|
||||||
|
# Call CSV logger writing results to file
|
||||||
|
logger()
|
||||||
|
|
||||||
|
# print progress
|
||||||
|
cat(sprintf("sample size (%d): %d/%d - rep: %d/%d - elapsed: %.1f [s], total: %.0f [s]\n",
|
||||||
|
sample.size, which(sample.size == sample.sizes),
|
||||||
|
length(sample.sizes), rep, reps, elapsed, total.time))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
### read simulation results generate plots
|
||||||
|
if (!interactive()) { pdf(file = paste(base.name, "pdf", sep = ".")) }
|
||||||
|
|
||||||
|
sim <- read.csv(log.file)
|
||||||
|
|
||||||
|
|
||||||
|
plot.sim(sim, "dist.subspace", main = "Subspace Distance",
|
||||||
|
xlab = "Sample Size", ylab = "Distance")
|
||||||
|
|
||||||
|
# plot.sim(sim, "dist.projection", main = "Projection Distance",
|
||||||
|
# xlab = "Sample Size", ylab = "Distance")
|
||||||
|
|
||||||
|
plot.sim(sim, "time", main = "Runtime",
|
||||||
|
xlab = "Sample Size", ylab = "Time [s]",
|
||||||
|
ylim = c(0, max(sim[startsWith(names(sim), "time")])))
|
|
@ -0,0 +1,610 @@
|
||||||
|
library(microbenchmark)
|
||||||
|
library(tensorPredictors)
|
||||||
|
|
||||||
|
setwd("~/Work/tensorPredictors/sim/")
|
||||||
|
base.name <- "sim_ising_perft"
|
||||||
|
|
||||||
|
# Number of replications, sufficient for the performance test
|
||||||
|
reps <- 5
|
||||||
|
|
||||||
|
# Sets the dimensions to be tested for runtime per method
|
||||||
|
configs <- list(
|
||||||
|
exact = list( # Exact method
|
||||||
|
dim = 1:24,
|
||||||
|
use_MC = FALSE,
|
||||||
|
nr_threads = 1L # ignored in this case, but no special case neded
|
||||||
|
),
|
||||||
|
MC = list( # Monte-Carlo Estimate
|
||||||
|
dim = c(1:20, (3:13) * 10),
|
||||||
|
use_MC = TRUE,
|
||||||
|
nr_threads = 1L # default nr. of threads
|
||||||
|
),
|
||||||
|
MC8 = list( # Monte-Carlo Estimate using 8 threads
|
||||||
|
dim = c(1:20, (3:13) * 10),
|
||||||
|
use_MC = TRUE,
|
||||||
|
nr_threads = 8L # my machines nr of (virtual) cores
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Simple function creating a parameter vector to be passed to `ising_m2`, the values
|
||||||
|
# are irrelevant while its own execution time is (more or less) neglectable
|
||||||
|
params <- function(dim) double(dim * (dim + 1L) / 2L)
|
||||||
|
|
||||||
|
# Build expressions to be past to `microbenchmark` for performance testing
|
||||||
|
expressions <- Reduce(c, Map(function(method) {
|
||||||
|
config <- configs[[method]]
|
||||||
|
|
||||||
|
Map(function(dim) {
|
||||||
|
as.call(list(quote(ising_m2), params = substitute(params(dim), list(dim = dim)),
|
||||||
|
use_MC = config$use_MC, nr_threads = config$nr_threads))
|
||||||
|
}, config$dim)
|
||||||
|
}, names(configs)))
|
||||||
|
|
||||||
|
# Performance tests
|
||||||
|
perft.results <- microbenchmark(list = expressions, times = reps)
|
||||||
|
|
||||||
|
# Convert performance test results to data frame for further processing
|
||||||
|
(perft <- summary(perft.results))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Ploting the performance simulation
|
||||||
|
################################################################################
|
||||||
|
### TODO: Fix plotting ###
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
if (FALSE) {
|
||||||
|
|
||||||
|
with(sim, {
|
||||||
|
par(mfrow = c(2, 2), mar = c(5, 4, 4, 4) + 0.1)
|
||||||
|
|
||||||
|
# Effect of Nr. of samples
|
||||||
|
plot(range(nr_samples), range(mse - sd, mse + sd),
|
||||||
|
type = "n", bty = "n", log = "xy", yaxt = "n",
|
||||||
|
xlab = "Nr. Samples", ylab = "MSE",
|
||||||
|
main = "Sample Size Effect (MSE)")
|
||||||
|
groups <- split(sim, warmup)
|
||||||
|
for (i in seq_along(groups)) {
|
||||||
|
with(groups[[i]], {
|
||||||
|
lines(nr_samples, mse, col = i, lwd = 2, type = "b", pch = 16)
|
||||||
|
lines(nr_samples, mse - sd, col = i, lty = 2)
|
||||||
|
lines(nr_samples, mse + sd, col = i, lty = 2)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
right <- nr_samples == max(nr_samples)
|
||||||
|
axis(4, at = mse[right], labels = warmup[right], lwd = 0, lwd.ticks = 1, line = -1.5)
|
||||||
|
mtext("Warmup", side = 4, line = 1.5, at = exp(mean(range(log(mse[right])))))
|
||||||
|
y.pow <- -10:-1
|
||||||
|
axis(2, at = c(1, 10^y.pow),
|
||||||
|
labels = c(1, sapply(y.pow, function(pow) eval(substitute(expression(10^i), list(i = pow))))))
|
||||||
|
|
||||||
|
# Effect warmup length
|
||||||
|
plot(range(warmup + 1), range(mse - sd, mse + sd),
|
||||||
|
type = "n", bty = "n", log = "xy", xaxt = "n", yaxt = "n",
|
||||||
|
xlab = "Warmup", ylab = "MSE",
|
||||||
|
main = "Warmup Effect (MSE)")
|
||||||
|
axis(1, warmup + 1, labels = as.integer(warmup))
|
||||||
|
groups <- split(sim, nr_samples)
|
||||||
|
for (i in seq_along(groups)) {
|
||||||
|
with(groups[[i]], {
|
||||||
|
lines(warmup + 1, mse, col = i, lwd = 2, type = "b", pch = 16)
|
||||||
|
lines(warmup + 1, mse - sd, col = i, lty = 2)
|
||||||
|
lines(warmup + 1, mse + sd, col = i, lty = 2)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
right <- warmup == max(warmup)
|
||||||
|
axis(4, at = mse[right], labels = nr_samples[right], lwd = 0, lwd.ticks = 1, line = -1.5)
|
||||||
|
mtext("Nr. Samples", side = 4, line = 1.5, at = exp(mean(range(log(mse[right])))))
|
||||||
|
axis(2, at = c(1, 10^y.pow),
|
||||||
|
labels = c(1, sapply(y.pow, function(pow) eval(substitute(expression(10^i), list(i = pow))))))
|
||||||
|
|
||||||
|
# Effect of Nr. of samples
|
||||||
|
plot(range(nr_samples), range(merr),
|
||||||
|
type = "n", bty = "n", log = "xy", yaxt = "n",
|
||||||
|
xlab = "Nr. Samples", ylab = "Max Abs Error Mean",
|
||||||
|
main = "Sample Size Effect (Abs Error)")
|
||||||
|
groups <- split(sim, warmup)
|
||||||
|
for (i in seq_along(groups)) {
|
||||||
|
with(groups[[i]], {
|
||||||
|
lines(nr_samples, merr, col = i, lwd = 2, type = "b", pch = 16)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
right <- nr_samples == max(nr_samples)
|
||||||
|
axis(4, at = merr[right], labels = warmup[right], lwd = 0, lwd.ticks = 1, line = -1.5)
|
||||||
|
mtext("Warmup", side = 4, line = 1.5, at = exp(mean(range(log(merr[right])))))
|
||||||
|
y.pow <- -10:-1
|
||||||
|
axis(2, at = c(1, 10^y.pow),
|
||||||
|
labels = c(1, sapply(y.pow, function(pow) eval(substitute(expression(10^i), list(i = pow))))))
|
||||||
|
|
||||||
|
# Effect of warmup length
|
||||||
|
plot(range(warmup + 1), range(merr),
|
||||||
|
type = "n", bty = "n", log = "xy", xaxt = "n", yaxt = "n",
|
||||||
|
xlab = "Warmup", ylab = "Max Abs Error Mean",
|
||||||
|
main = "Warmup Effect (Abs Error)")
|
||||||
|
axis(1, warmup + 1, labels = as.integer(warmup))
|
||||||
|
groups <- split(sim, nr_samples)
|
||||||
|
for (i in seq_along(groups)) {
|
||||||
|
with(groups[[i]], {
|
||||||
|
lines(warmup + 1, merr, col = i, lwd = 2, type = "b", pch = 16)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
right <- warmup == max(warmup)
|
||||||
|
axis(4, at = merr[right], labels = nr_samples[right], lwd = 0, lwd.ticks = 1, line = -1.5)
|
||||||
|
mtext("Nr. Samples", side = 4, line = 1.5, at = exp(mean(range(log(merr[right])))))
|
||||||
|
axis(2, at = c(1, 10^y.pow),
|
||||||
|
labels = c(1, sapply(y.pow, function(pow) eval(substitute(expression(10^i), list(i = pow))))))
|
||||||
|
})
|
||||||
|
|
||||||
|
# Add common title
|
||||||
|
mtext(main, side = 3, line = -2, outer = TRUE, font = 2, cex = 1.5)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# test_unscaled_prob <- function() {
|
||||||
|
# test <- function(p) {
|
||||||
|
# y <- sample.int(2, p, replace = TRUE) - 1L
|
||||||
|
# theta <- vech(matrix(rnorm(p^2), p))
|
||||||
|
|
||||||
|
|
||||||
|
# C <- ising_m2(y, theta)
|
||||||
|
# R <- exp(sum(vech(outer(y, y)) * theta))
|
||||||
|
|
||||||
|
# if (all.equal(C, R) == TRUE) {
|
||||||
|
# cat("\033[92mSUCCESS: ")
|
||||||
|
# } else {
|
||||||
|
# cat("\033[91mFAILED: ")
|
||||||
|
# }
|
||||||
|
# cat(sprintf("p = %d, C = %e, R = %e\n", p, C, R))
|
||||||
|
# cat(" ", paste0(format(seq_along(y) - 1), collapse = " "), "\n")
|
||||||
|
# cat(" y = ", paste0(c(".", "1")[y + 1], collapse = " "))
|
||||||
|
# cat("\033[0m\n\n")
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
|
# devtools::load_all()
|
||||||
|
# for (p in c(1, 10, 30:35, 62:66, 70, 128, 130)) {
|
||||||
|
# test(p)
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
|
# test_ising_sample <- function() {
|
||||||
|
# test <- function(p) {
|
||||||
|
# # theta <- vech(matrix(rnorm(p^2), p))
|
||||||
|
# # theta <- vech(matrix(0, p, p))
|
||||||
|
# theta <- -0.01 * vech(1 - diag(p))
|
||||||
|
# # theta <- vech(0.2 * diag(p))
|
||||||
|
|
||||||
|
# sample <- ising_sample(11, theta)
|
||||||
|
|
||||||
|
|
||||||
|
# print.table(sample, zero.print = ".")
|
||||||
|
# print(mean(sample))
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
|
# devtools::load_all()
|
||||||
|
# for (p in c(1, 10, 30:35, 62:66, 70, 128, 130)) {
|
||||||
|
# test(p)
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# test_ising_partition_func_MC <- function() {
|
||||||
|
# test <- function(p) {
|
||||||
|
# # theta <- vech(matrix(rnorm(p^2), p))
|
||||||
|
# # theta <- vech(matrix(0, p, p))
|
||||||
|
# theta <- -0.01 * vech(1 - diag(p))
|
||||||
|
|
||||||
|
# time_gmlm <- system.time(val_gmlm <- ising_partition_func_MC(theta))
|
||||||
|
# time_gmlm <- round(1000 * time_gmlm[["elapsed"]])
|
||||||
|
|
||||||
|
# if (p < 21) {
|
||||||
|
# time_mvb <- system.time(val_mvb <- 1 / mvbernoulli::ising_prob0(theta))
|
||||||
|
# time_mvb <- round(1000 * time_mvb[["elapsed"]])
|
||||||
|
# } else {
|
||||||
|
# val_mvb <- NaN
|
||||||
|
# time_mvb <- -1
|
||||||
|
# }
|
||||||
|
|
||||||
|
# cat(sprintf(
|
||||||
|
# "dim = %d\n GMLM: time = %4d ms, val = %.4e\n MVB: time = %4d ms, val = %.4e\n",
|
||||||
|
# p, time_gmlm, val_gmlm, time_mvb, val_mvb))
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
|
# devtools::load_all()
|
||||||
|
|
||||||
|
# system.time(
|
||||||
|
# # for (p in c(1, 10, 20, 30:35, 64, 70, 128, 130)) {
|
||||||
|
# for (p in c(1, 10, 20, 30:35, 64)) {
|
||||||
|
# test(p)
|
||||||
|
# }
|
||||||
|
# )
|
||||||
|
# }
|
||||||
|
# # test_ising_partition_func_MC()
|
||||||
|
|
||||||
|
|
||||||
|
# validate_ising_partition_func_MC <- function(theta_func) {
|
||||||
|
# est_var <- function(dim) {
|
||||||
|
# theta <- theta_func(dim)
|
||||||
|
|
||||||
|
# time <- system.time(rep <- replicate(100, ising_partition_func_MC(theta)))
|
||||||
|
|
||||||
|
# cat(sprintf("dim = %d, time = %.2e s, mean = %.2e, std.dev = %.2e\n",
|
||||||
|
# dim, time[["elapsed"]], mean(rep), sd(rep)))
|
||||||
|
# }
|
||||||
|
|
||||||
|
# for (dim in 10 * (1:13)) {
|
||||||
|
# est_var(dim)
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
|
||||||
|
# # validate_ising_partition_func_MC(function(dim) { vech(matrix(rnorm(dim^2), dim)) })
|
||||||
|
# # validate_ising_partition_func_MC(function(dim) { vech(matrix(0, dim, dim)) })
|
||||||
|
# # validate_ising_partition_func_MC(function(dim) { -0.01 * vech(1 - diag(dim)) })
|
||||||
|
# # validate_ising_partition_func_MC(function(dim) { vech(0.2 * diag(dim)) })
|
||||||
|
|
||||||
|
|
||||||
|
# test_ising_partition_func_exact <- function(theta_func) {
|
||||||
|
|
||||||
|
# test <- function(dim) {
|
||||||
|
# theta <- theta_func(dim)
|
||||||
|
|
||||||
|
# reps <- if (dim < 10) 100 else 10
|
||||||
|
|
||||||
|
# time <- system.time(replicate(reps, ising_partition_func_exact(theta)))
|
||||||
|
# time <- time[["elapsed"]] / reps
|
||||||
|
|
||||||
|
# cat(sprintf("dim = %d, time = %.2e s\n", dim, time))
|
||||||
|
# }
|
||||||
|
|
||||||
|
# for (dim in 1:20) {
|
||||||
|
# test(dim)
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
|
||||||
|
# test_ising_partition_func_exact(function(dim) { vech(matrix(rnorm(dim^2), dim)) })
|
||||||
|
|
||||||
|
|
||||||
|
# ### Performance Measurement/Comparison
|
||||||
|
# local({
|
||||||
|
# perft_exact <- local({
|
||||||
|
# dims <- 2:22
|
||||||
|
|
||||||
|
# cat("Exact perft:\n")
|
||||||
|
# times <- sapply(dims, function(dim) {
|
||||||
|
# reps <- if (dim < 10) 1000 else if (dim < 15) 100 else if (dim < 20) 10 else 4
|
||||||
|
# theta <- vech(matrix(rnorm(dim^2), dim))
|
||||||
|
# time <- system.time(replicate(reps,
|
||||||
|
# ising_m2(theta, use_MC = FALSE)
|
||||||
|
# ))
|
||||||
|
# time <- time[["elapsed"]] / reps
|
||||||
|
# cat(sprintf(" dim = %3d, reps = %3d, time per rep = %.2e s\n", dim, reps, time))
|
||||||
|
# time
|
||||||
|
# })
|
||||||
|
# list(dims = dims, times = times)
|
||||||
|
# })
|
||||||
|
|
||||||
|
# perft_MC <- local({
|
||||||
|
# dims <- c(2:21, 30, 40, 70, 100)
|
||||||
|
|
||||||
|
# cat("Monte-Carlo perft:\n")
|
||||||
|
# times <- sapply(dims, function(dim) {
|
||||||
|
# reps <- if (dim < 20) 25 else if (dim < 40) 10 else 4
|
||||||
|
# theta <- vech(matrix(rnorm(dim^2), dim))
|
||||||
|
# time <- system.time(replicate(reps,
|
||||||
|
# ising_m2(theta, use_MC = TRUE)
|
||||||
|
# ))
|
||||||
|
# time <- time[["elapsed"]] / reps
|
||||||
|
# cat(sprintf(" dim = %3d, reps = %3d, time per rep = %.2e s\n", dim, reps, time))
|
||||||
|
# time
|
||||||
|
# })
|
||||||
|
|
||||||
|
# list(dims = dims, times = times)
|
||||||
|
# })
|
||||||
|
|
||||||
|
# perft_MC_thrd <- local({
|
||||||
|
# dims <- c(2:21, 30, 40, 70, 100)
|
||||||
|
|
||||||
|
# cat("Monte-Carlo Multi-Threaded perft:\n")
|
||||||
|
# times <- sapply(dims, function(dim) {
|
||||||
|
# reps <- if (dim < 15) 25 else if (dim < 40) 10 else 4
|
||||||
|
# theta <- vech(matrix(rnorm(dim^2), dim))
|
||||||
|
# time <- system.time(replicate(reps,
|
||||||
|
# ising_m2(theta, use_MC = TRUE, nr_threads = 6L)
|
||||||
|
# ))
|
||||||
|
# time <- time[["elapsed"]] / reps
|
||||||
|
# cat(sprintf(" dim = %3d, reps = %3d, time per rep = %.2e s\n", dim, reps, time))
|
||||||
|
# time
|
||||||
|
# })
|
||||||
|
|
||||||
|
# list(dims = dims, times = times)
|
||||||
|
# })
|
||||||
|
|
||||||
|
# # Plot results
|
||||||
|
# par(mfrow = c(1, 1))
|
||||||
|
# plot(
|
||||||
|
# range(c(perft_MC_thrd$dims, perft_MC$dims, perft_exact$dims)),
|
||||||
|
# range(c(perft_MC_thrd$times, perft_MC$times, perft_exact$times)),
|
||||||
|
# type = "n", log = "xy", xlab = "Dimension p", ylab = "Time", xaxt = "n", yaxt = "n",
|
||||||
|
# main = "Performance Comparison"
|
||||||
|
# )
|
||||||
|
# # Add custom Y-axis
|
||||||
|
# x.major.ticks <- as.vector(outer(c(2, 5, 10), 10^(0:5)))
|
||||||
|
# x.minor.ticks <- as.vector(outer(2:10, 10^(0:5)))
|
||||||
|
# axis(1, x.major.ticks, labels = as.integer(x.major.ticks))
|
||||||
|
# axis(1, x.minor.ticks, labels = NA, tcl = -0.25, lwd = 0, lwd.ticks = 1)
|
||||||
|
# abline(v = x.major.ticks, col = "gray", lty = "dashed")
|
||||||
|
# abline(v = x.minor.ticks, col = "lightgray", lty = "dotted")
|
||||||
|
# # Add custom Y-axis
|
||||||
|
# y.major.ticks <- c(10^(-9:1), 60, 600, 3600)
|
||||||
|
# y.labels <- c(
|
||||||
|
# expression(paste(n, s)),
|
||||||
|
# expression(paste(10, n, s)),
|
||||||
|
# expression(paste(100, n, s)),
|
||||||
|
# expression(paste(mu, s)),
|
||||||
|
# expression(paste(10, mu, s)),
|
||||||
|
# expression(paste(100, mu, s)),
|
||||||
|
# expression(paste(1, m, s)),
|
||||||
|
# expression(paste(10, m, s)),
|
||||||
|
# expression(paste(100, m, s)),
|
||||||
|
# expression(paste(1, s)),
|
||||||
|
# expression(paste(10, s)),
|
||||||
|
# expression(paste(1, min)),
|
||||||
|
# expression(paste(10, min)),
|
||||||
|
# expression(paste(1, h))
|
||||||
|
# )
|
||||||
|
# y.minor.ticks <- c(as.vector(outer((1:10), 10^(-10:0))), 10 * (1:6), 60 * (2:10), 600 * (2:6))
|
||||||
|
# axis(2, at = y.major.ticks, labels = y.labels)
|
||||||
|
# axis(2, at = y.minor.ticks, labels = NA, tcl = -0.25, lwd = 0, lwd.ticks = 1)
|
||||||
|
# abline(h = y.major.ticks, col = "gray", lty = "dashed")
|
||||||
|
# abline(h = y.minor.ticks, col = "lightgray", lty = "dotted")
|
||||||
|
# legend("bottomright", col = c("red", "darkgreen", "blue"), lty = c(1, 1, 1),
|
||||||
|
# bg = "white",
|
||||||
|
# legend = c(
|
||||||
|
# expression(paste("Exact ", O(2^p))),
|
||||||
|
# expression(paste("MC ", O(p^2))),
|
||||||
|
# expression(paste("MC Thrd ", O(p^2)))
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
|
||||||
|
# with(perft_exact, {
|
||||||
|
# points(dims, times, pch = 16, col = "red")
|
||||||
|
# with(list(dims = tail(dims, -4), times = tail(times, -4)), {
|
||||||
|
# lines(dims, exp(predict(lm(log(times) ~ dims))), col = "red")
|
||||||
|
# })
|
||||||
|
# })
|
||||||
|
# with(perft_MC, {
|
||||||
|
# points(dims, times, pch = 16, col = "darkgreen")
|
||||||
|
# lines(dims, predict(lm(sqrt(times) ~ dims))^2, col = "darkgreen")
|
||||||
|
# })
|
||||||
|
# with(perft_MC_thrd, {
|
||||||
|
# points(dims, times, pch = 16, col = "blue")
|
||||||
|
# lines(dims, predict(lm(sqrt(times) ~ dims))^2, col = "blue")
|
||||||
|
# })
|
||||||
|
# })
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# # dim <- 10
|
||||||
|
# # theta <- vech(matrix(rnorm(dim^2, 0, 1), dim, dim))
|
||||||
|
# # nr_threads <- 6L
|
||||||
|
|
||||||
|
# # (m2.exact <- ising_m2(theta, use_MC = FALSE))
|
||||||
|
# # (m2.MC <- ising_m2(theta, use_MC = TRUE))
|
||||||
|
# # (m2.MC_thrd <- ising_m2(theta, use_MC = TRUE, nr_threads = nr_threads))
|
||||||
|
|
||||||
|
# # tcrossprod(ising_sample(1e4, theta)) / 1e4
|
||||||
|
|
||||||
|
|
||||||
|
# local({
|
||||||
|
# dim <- 20
|
||||||
|
# theta <- vech(matrix(rnorm(dim^2, 0, 1), dim, dim))
|
||||||
|
|
||||||
|
# A <- matrix(NA_real_, dim, dim)
|
||||||
|
# A[lower.tri(A, diag = TRUE)] <- theta
|
||||||
|
# A[lower.tri(A)] <- A[lower.tri(A)] / 2
|
||||||
|
# A[upper.tri(A)] <- t(A)[upper.tri(A)]
|
||||||
|
|
||||||
|
# nr_threads <- 6L
|
||||||
|
|
||||||
|
# time.exact <- system.time(m2.exact <- ising_m2(theta, use_MC = FALSE))
|
||||||
|
# time.MC <- system.time(m2.MC <- ising_m2(theta, use_MC = TRUE))
|
||||||
|
# time.MC_thrd <- system.time(m2.MC_thrd <- ising_m2(A, use_MC = TRUE, nr_threads = nr_threads))
|
||||||
|
# time.sample <- system.time(m2.sample <- tcrossprod(ising_sample(1e4, theta)) / 1e4)
|
||||||
|
|
||||||
|
# range <- range(m2.exact, m2.MC, m2.MC_thrd)
|
||||||
|
|
||||||
|
# par(mfrow = c(2, 2))
|
||||||
|
# matrixImage(m2.exact, main = sprintf("M2 Exact (time %.2f s)", time.exact[["elapsed"]]), zlim = range)
|
||||||
|
# matrixImage(m2.MC, main = sprintf("M2 MC (time %.2f s)", time.MC[["elapsed"]]), zlim = range)
|
||||||
|
# matrixImage(m2.MC_thrd, main = sprintf("M2 MC (%d threads, time %.2f s)", nr_threads, time.MC_thrd[["elapsed"]]), zlim = range)
|
||||||
|
# matrixImage(m2.sample, main = sprintf("E_n(X X') (time %.2f s)", time.sample[["elapsed"]]), zlim = range)
|
||||||
|
# # matrixImage(abs(m2.exact - m2.MC), main = "Abs. Error (Exact to MC)", zlim = c(-1, 1))
|
||||||
|
# })
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# # Simulation
|
||||||
|
# dims <- c(5, 10, 15, 20)
|
||||||
|
# config.grid <- expand.grid(
|
||||||
|
# nr_samples = c(10, 100, 1000, 10000),
|
||||||
|
# warmup = c(0, 2, 10, 100),
|
||||||
|
# dim = dims
|
||||||
|
# )
|
||||||
|
# params <- Map(function(dim) vech(matrix(rnorm(dim^2, 0, 1), dim, dim)), dims)
|
||||||
|
# names(params) <- dims
|
||||||
|
# m2s.exact <- Map(ising_m2, params, use_MC = FALSE)
|
||||||
|
|
||||||
|
# sim <- data.frame(t(apply(config.grid, 1, function(conf) {
|
||||||
|
# # get same theta for every dimension
|
||||||
|
# theta <- params[[as.character(conf["dim"])]]
|
||||||
|
|
||||||
|
# m2.exact <- m2s.exact[[as.character(conf["dim"])]]
|
||||||
|
|
||||||
|
# rep <- replicate(25, {
|
||||||
|
# time <- system.time(
|
||||||
|
# m2.MC <- ising_m2(theta, nr_samples = conf["nr_samples"], warmup = conf["warmup"], use_MC = TRUE)
|
||||||
|
# )
|
||||||
|
# c(mse = mean((m2.exact - m2.MC)^2), err = max(abs(m2.exact - m2.MC)), time = time[["elapsed"]])
|
||||||
|
# })
|
||||||
|
|
||||||
|
# cat(sprintf("dim = %d, nr_samples = %6d, warmup = %6d, mse = %.4f\n",
|
||||||
|
# conf["dim"], conf["nr_samples"], conf["warmup"], mean(rep["mse", ])))
|
||||||
|
|
||||||
|
# c(
|
||||||
|
# conf,
|
||||||
|
# mse = mean(rep["mse", ]), sd = sd(rep["mse", ]), merr = mean(rep["err", ]),
|
||||||
|
# time = mean(rep["time", ])
|
||||||
|
# )
|
||||||
|
# })))
|
||||||
|
|
||||||
|
# plot.sim <- function(sim, main) {
|
||||||
|
# with(sim, {
|
||||||
|
# par(mfrow = c(2, 2), mar = c(5, 4, 4, 4) + 0.1)
|
||||||
|
|
||||||
|
# # Effect of Nr. of samples
|
||||||
|
# plot(range(nr_samples), range(mse - sd, mse + sd),
|
||||||
|
# type = "n", bty = "n", log = "xy", yaxt = "n",
|
||||||
|
# xlab = "Nr. Samples", ylab = "MSE",
|
||||||
|
# main = "Sample Size Effect (MSE)")
|
||||||
|
# groups <- split(sim, warmup)
|
||||||
|
# for (i in seq_along(groups)) {
|
||||||
|
# with(groups[[i]], {
|
||||||
|
# lines(nr_samples, mse, col = i, lwd = 2, type = "b", pch = 16)
|
||||||
|
# lines(nr_samples, mse - sd, col = i, lty = 2)
|
||||||
|
# lines(nr_samples, mse + sd, col = i, lty = 2)
|
||||||
|
# })
|
||||||
|
# }
|
||||||
|
# right <- nr_samples == max(nr_samples)
|
||||||
|
# axis(4, at = mse[right], labels = warmup[right], lwd = 0, lwd.ticks = 1, line = -1.5)
|
||||||
|
# mtext("Warmup", side = 4, line = 1.5, at = exp(mean(range(log(mse[right])))))
|
||||||
|
# y.pow <- -10:-1
|
||||||
|
# axis(2, at = c(1, 10^y.pow),
|
||||||
|
# labels = c(1, sapply(y.pow, function(pow) eval(substitute(expression(10^i), list(i = pow))))))
|
||||||
|
|
||||||
|
# # Effect warmup length
|
||||||
|
# plot(range(warmup + 1), range(mse - sd, mse + sd),
|
||||||
|
# type = "n", bty = "n", log = "xy", xaxt = "n", yaxt = "n",
|
||||||
|
# xlab = "Warmup", ylab = "MSE",
|
||||||
|
# main = "Warmup Effect (MSE)")
|
||||||
|
# axis(1, warmup + 1, labels = as.integer(warmup))
|
||||||
|
# groups <- split(sim, nr_samples)
|
||||||
|
# for (i in seq_along(groups)) {
|
||||||
|
# with(groups[[i]], {
|
||||||
|
# lines(warmup + 1, mse, col = i, lwd = 2, type = "b", pch = 16)
|
||||||
|
# lines(warmup + 1, mse - sd, col = i, lty = 2)
|
||||||
|
# lines(warmup + 1, mse + sd, col = i, lty = 2)
|
||||||
|
# })
|
||||||
|
# }
|
||||||
|
# right <- warmup == max(warmup)
|
||||||
|
# axis(4, at = mse[right], labels = nr_samples[right], lwd = 0, lwd.ticks = 1, line = -1.5)
|
||||||
|
# mtext("Nr. Samples", side = 4, line = 1.5, at = exp(mean(range(log(mse[right])))))
|
||||||
|
# axis(2, at = c(1, 10^y.pow),
|
||||||
|
# labels = c(1, sapply(y.pow, function(pow) eval(substitute(expression(10^i), list(i = pow))))))
|
||||||
|
|
||||||
|
# # Effect of Nr. of samples
|
||||||
|
# plot(range(nr_samples), range(merr),
|
||||||
|
# type = "n", bty = "n", log = "xy", yaxt = "n",
|
||||||
|
# xlab = "Nr. Samples", ylab = "Max Abs Error Mean",
|
||||||
|
# main = "Sample Size Effect (Abs Error)")
|
||||||
|
# groups <- split(sim, warmup)
|
||||||
|
# for (i in seq_along(groups)) {
|
||||||
|
# with(groups[[i]], {
|
||||||
|
# lines(nr_samples, merr, col = i, lwd = 2, type = "b", pch = 16)
|
||||||
|
# })
|
||||||
|
# }
|
||||||
|
# right <- nr_samples == max(nr_samples)
|
||||||
|
# axis(4, at = merr[right], labels = warmup[right], lwd = 0, lwd.ticks = 1, line = -1.5)
|
||||||
|
# mtext("Warmup", side = 4, line = 1.5, at = exp(mean(range(log(merr[right])))))
|
||||||
|
# y.pow <- -10:-1
|
||||||
|
# axis(2, at = c(1, 10^y.pow),
|
||||||
|
# labels = c(1, sapply(y.pow, function(pow) eval(substitute(expression(10^i), list(i = pow))))))
|
||||||
|
|
||||||
|
# # Effect of warmup length
|
||||||
|
# plot(range(warmup + 1), range(merr),
|
||||||
|
# type = "n", bty = "n", log = "xy", xaxt = "n", yaxt = "n",
|
||||||
|
# xlab = "Warmup", ylab = "Max Abs Error Mean",
|
||||||
|
# main = "Warmup Effect (Abs Error)")
|
||||||
|
# axis(1, warmup + 1, labels = as.integer(warmup))
|
||||||
|
# groups <- split(sim, nr_samples)
|
||||||
|
# for (i in seq_along(groups)) {
|
||||||
|
# with(groups[[i]], {
|
||||||
|
# lines(warmup + 1, merr, col = i, lwd = 2, type = "b", pch = 16)
|
||||||
|
# })
|
||||||
|
# }
|
||||||
|
# right <- warmup == max(warmup)
|
||||||
|
# axis(4, at = merr[right], labels = nr_samples[right], lwd = 0, lwd.ticks = 1, line = -1.5)
|
||||||
|
# mtext("Nr. Samples", side = 4, line = 1.5, at = exp(mean(range(log(merr[right])))))
|
||||||
|
# axis(2, at = c(1, 10^y.pow),
|
||||||
|
# labels = c(1, sapply(y.pow, function(pow) eval(substitute(expression(10^i), list(i = pow))))))
|
||||||
|
# })
|
||||||
|
|
||||||
|
# # Add common title
|
||||||
|
# mtext(main, side = 3, line = -2, outer = TRUE, font = 2, cex = 1.5)
|
||||||
|
# }
|
||||||
|
|
||||||
|
# plot.sim(subset(sim, sim$dim == 5), main = "Dim = 5")
|
||||||
|
# plot.sim(subset(sim, sim$dim == 10), main = "Dim = 10")
|
||||||
|
# plot.sim(subset(sim, sim$dim == 15), main = "Dim = 15")
|
||||||
|
# plot.sim(subset(sim, sim$dim == 20), main = "Dim = 20")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# dim <- 120
|
||||||
|
# params <- rnorm(dim * (dim + 1) / 2)
|
||||||
|
|
||||||
|
|
||||||
|
# A <- matrix(NA_real_, dim, dim)
|
||||||
|
# A[lower.tri(A, diag = TRUE)] <- params
|
||||||
|
# A[lower.tri(A)] <- A[lower.tri(A)] / 2
|
||||||
|
# A[upper.tri(A)] <- t(A)[upper.tri(A)]
|
||||||
|
|
||||||
|
# seed <- abs(as.integer(100000 * rnorm(1)))
|
||||||
|
# all.equal(
|
||||||
|
# { set.seed(seed); ising_sample(11, params) },
|
||||||
|
# { set.seed(seed); ising_sample(11, A) }
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# x <- sample(0:1, 10, TRUE)
|
||||||
|
|
||||||
|
# sum(vech(outer(x, x)) * params)
|
||||||
|
# sum(x * (A %*% x))
|
||||||
|
|
||||||
|
# # M <- matrix(NA, dim, dim)
|
||||||
|
# # M[lower.tri(M, diag = TRUE)] <- seq_len(dim * (dim + 1) / 2) - 1
|
||||||
|
# # rownames(M) <- (1:dim) - 1
|
||||||
|
# # colnames(M) <- (1:dim) - 1
|
||||||
|
# # print.table(M)
|
||||||
|
|
||||||
|
# # i <- seq(0, dim - 1)
|
||||||
|
# # (i * (2 * dim + 1 - i)) / 2
|
||||||
|
|
||||||
|
# # I <- 0
|
||||||
|
# # for (i in seq(0, dim - 1)) {
|
||||||
|
# # print(I)
|
||||||
|
# # I <- I + dim - i
|
||||||
|
# # }
|
||||||
|
|
||||||
|
# m2.exact <- vech.pinv(ising_m2(params, use_MC = FALSE))
|
||||||
|
# m2.MC <- vech.pinv(ising_m2(params, use_MC = TRUE))
|
||||||
|
# m2.mat <- tcrossprod(ising_sample(1e4, A)) / 1e4
|
||||||
|
# m2.vech <- tcrossprod(ising_sample(1e4, params)) / 1e4
|
||||||
|
|
||||||
|
# par(mfrow = c(2, 2))
|
||||||
|
|
||||||
|
# matrixImage(m2.exact, main = "exact")
|
||||||
|
# matrixImage(m2.MC, main = "MC")
|
||||||
|
# matrixImage(m2.mat, main = "Sample mat")
|
||||||
|
# matrixImage(m2.vech, main = "Sample vech")
|
|
@ -0,0 +1,205 @@
|
||||||
|
#' Some utility function used in simulations
|
||||||
|
|
||||||
|
#' Computes the orthogonal projection matrix onto the span of `A`
|
||||||
|
proj <- function(A) tcrossprod(A, A %*% solve(crossprod(A, A)))
|
||||||
|
|
||||||
|
|
||||||
|
#' Logging Errors and Warnings
|
||||||
|
#'
|
||||||
|
#' Register a global warning and error handler for logging warnings/errors with
|
||||||
|
#' current simulation repetition session informatin allowing to reproduce problems
|
||||||
|
#'
|
||||||
|
#' @examples
|
||||||
|
#' # Usage
|
||||||
|
#' globalCallingHandlers(list(
|
||||||
|
#' message = exceptionLogger("warning.log"),
|
||||||
|
#' warning = exceptionLogger("warning.log"),
|
||||||
|
#' error = exceptionLogger("error.log")
|
||||||
|
#' ))
|
||||||
|
#' # Do some stuff where an error might occure
|
||||||
|
#' for (rep in 1:1000) {
|
||||||
|
#' # Store additional information logged with an error when an exception occures
|
||||||
|
#' storeExceptionInfo(rep = rep)
|
||||||
|
#' # Do work
|
||||||
|
#' stopifnot(rep < 17)
|
||||||
|
#' }
|
||||||
|
#'
|
||||||
|
assign(".exception.info", NULL, env = .GlobalEnv)
|
||||||
|
exceptionLogger <- function(file.name) {
|
||||||
|
force(file.name)
|
||||||
|
function(ex) {
|
||||||
|
log <- file(file.name, open = "a+")
|
||||||
|
cat("\n### Log At: ", format(Sys.time()), "\n", file = log)
|
||||||
|
cat("# Exception:\n", file = log)
|
||||||
|
cat(as.character.error(ex), file = log)
|
||||||
|
cat("\n# Exception Info:\n", file = log)
|
||||||
|
dump(".exception.info", envir = .GlobalEnv, file = log)
|
||||||
|
cat("\n# Traceback:\n", file = log)
|
||||||
|
|
||||||
|
# add Traceback (see: `traceback()` which the following is addapted from)
|
||||||
|
n <- length(x <- .traceback(NULL, max.lines = -1L))
|
||||||
|
if (n == 0L) {
|
||||||
|
cat("No traceback available", "\n", file = log)
|
||||||
|
} else {
|
||||||
|
for (i in 1L:n) {
|
||||||
|
xi <- x[[i]]
|
||||||
|
label <- paste0(n - i + 1L, ": ")
|
||||||
|
m <- length(xi)
|
||||||
|
srcloc <- if (!is.null(srcref <- attr(xi, "srcref"))) {
|
||||||
|
srcfile <- attr(srcref, "srcfile")
|
||||||
|
paste0(" at ", basename(srcfile$filename), "#", srcref[1L])
|
||||||
|
}
|
||||||
|
if (isTRUE(attr(xi, "truncated"))) {
|
||||||
|
xi <- c(xi, " ...")
|
||||||
|
m <- length(xi)
|
||||||
|
}
|
||||||
|
if (!is.null(srcloc)) {
|
||||||
|
xi[m] <- paste0(xi[m], srcloc)
|
||||||
|
}
|
||||||
|
if (m > 1) {
|
||||||
|
label <- c(label, rep(substr(" ", 1L,
|
||||||
|
nchar(label, type = "w")), m - 1L))
|
||||||
|
}
|
||||||
|
cat(paste0(label, xi), sep = "\n", file = log)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(log)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#' Used in conjuntion with `exceptionLogger()`
|
||||||
|
storeExceptionInfo <- function(...) {
|
||||||
|
info <- list(...)
|
||||||
|
info$RNGking <- RNGkind()
|
||||||
|
info$.Random.seed <- get0(".Random.seed", envir = .GlobalEnv)
|
||||||
|
.GlobalEnv$.exception.info <- info
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
### Simulation logging routine
|
||||||
|
#' @examples
|
||||||
|
#' # Create a CSV logger
|
||||||
|
#' logger <- CSV.logger("test.csv", header = c("A", "B", "C"))
|
||||||
|
#'
|
||||||
|
#' # Store some values in current environment
|
||||||
|
#' A <- 1
|
||||||
|
#' B <- 2
|
||||||
|
#' # write values to csv file (order irrelevent)
|
||||||
|
#' logger(C = -3, B = -2)
|
||||||
|
#'
|
||||||
|
#' # another line
|
||||||
|
#' A <- 10
|
||||||
|
#' B <- 20
|
||||||
|
#' logger(B = -20, C = -30)
|
||||||
|
#'
|
||||||
|
#' # read the file back in
|
||||||
|
#' read.csv("test.csv")
|
||||||
|
#'
|
||||||
|
#' # In simulations it is often usefull to time stamp the files
|
||||||
|
#' nr <- 5
|
||||||
|
#' logger <- CSV.logger(
|
||||||
|
#' sprintf("test-nr%03d-%s.csv", nr, format(Sys.time(), "%Y%m%dT%H%M")),
|
||||||
|
#' header = c("A", "B", "C")
|
||||||
|
#' )
|
||||||
|
#'
|
||||||
|
CSV.logger <- function(file.name, header) {
|
||||||
|
force(file.name)
|
||||||
|
|
||||||
|
# CSV header, used to ensure correct value/column mapping when writing to file
|
||||||
|
force(header)
|
||||||
|
cat(paste0(header, collapse = ","), "\n", sep = "", file = file.name)
|
||||||
|
|
||||||
|
function(...) {
|
||||||
|
# get directly provided data
|
||||||
|
arg.data <- list(...)
|
||||||
|
# all arguments must be given with a name
|
||||||
|
if (length(arg.data) && is.null(names(arg.data))) {
|
||||||
|
stop("Arguments must be given with names")
|
||||||
|
}
|
||||||
|
# check if all elements have a described CSV header column
|
||||||
|
unknown <- !(names(arg.data) %in% header)
|
||||||
|
if (any(unknown)) {
|
||||||
|
stop("Got unknown columns: ", paste0(names(arg.data)[unknown], collapse = ", "))
|
||||||
|
}
|
||||||
|
# get missing values from environment
|
||||||
|
missing <- !(header %in% names(arg.data))
|
||||||
|
env <- parent.frame()
|
||||||
|
data <- c(arg.data, mget(header[missing], envir = env))
|
||||||
|
# Format all aguments
|
||||||
|
data <- Map(format, data)
|
||||||
|
# collaps into single line
|
||||||
|
line <- paste0(data[header], collapse = ",")
|
||||||
|
# write data line to file
|
||||||
|
cat(line, "\n", sep = "", file = file.name, append = TRUE)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Set colors for every method
|
||||||
|
methods <- c("gmlm", "pca", "hopca", "tsir", "mgcca", "lpca", "clpca", "tnormal", "sir")
|
||||||
|
col.methods <- palette.colors(n = length(methods), palette = "Okabe-Ito", recycle = FALSE)
|
||||||
|
names(col.methods) <- methods
|
||||||
|
|
||||||
|
|
||||||
|
# Comparison plot of one measure for a simulation
|
||||||
|
plot.sim <- function(sim, measure.name, ..., ylim = c(0, 1)) {
|
||||||
|
par.default <- par(pch = 16, bty = "n", lty = "solid", lwd = 1.5)
|
||||||
|
|
||||||
|
# # Set colors for every method
|
||||||
|
# methods <- c("gmlm", "pca", "hopca", "tsir", "mgcca", "lpca", "clpca", "tnormal")
|
||||||
|
# col.methods <- palette.colors(n = length(methods), palette = "Okabe-Ito", recycle = FALSE)
|
||||||
|
# names(col.methods) <- methods
|
||||||
|
|
||||||
|
# Remain sample size grouping variable to avoid conflicts
|
||||||
|
aggr.mean <- aggregate(sim, list(sampleSize = sim$sample.size), mean)
|
||||||
|
aggr.median <- aggregate(sim, list(sampleSize = sim$sample.size), median)
|
||||||
|
aggr.sd <- aggregate(sim, list(sampleSize = sim$sample.size), sd)
|
||||||
|
aggr.min <- aggregate(sim, list(sampleSize = sim$sample.size), min)
|
||||||
|
aggr.max <- aggregate(sim, list(sampleSize = sim$sample.size), max)
|
||||||
|
|
||||||
|
with(aggr.mean, {
|
||||||
|
plot(range(sampleSize), ylim, type = "n", ...)
|
||||||
|
for (dist.name in ls(pattern = paste0("^", measure.name))) {
|
||||||
|
mean <- get(dist.name)
|
||||||
|
median <- aggr.median[aggr.sd$sampleSize == sampleSize, dist.name]
|
||||||
|
sd <- aggr.sd[aggr.sd$sampleSize == sampleSize, dist.name]
|
||||||
|
min <- aggr.min[aggr.sd$sampleSize == sampleSize, dist.name]
|
||||||
|
max <- aggr.max[aggr.sd$sampleSize == sampleSize, dist.name]
|
||||||
|
method <- tail(strsplit(dist.name, ".", fixed = TRUE)[[1]], 1)
|
||||||
|
col <- col.methods[method]
|
||||||
|
lines(sampleSize, mean, type = "o", col = col, lty = 1, lwd = 2 + (method == "gmlm"))
|
||||||
|
lines(sampleSize, mean + sd, col = col, lty = 2, lwd = 0.8)
|
||||||
|
lines(sampleSize, mean - sd, col = col, lty = 2, lwd = 0.8)
|
||||||
|
lines(sampleSize, median, col = col, lty = 1, lwd = 1)
|
||||||
|
lines(sampleSize, min, col = col, lty = 3, lwd = 0.6)
|
||||||
|
lines(sampleSize, max, col = col, lty = 3, lwd = 0.6)
|
||||||
|
}
|
||||||
|
|
||||||
|
legend("topright", col = col.methods, lty = 1, legend = names(col.methods),
|
||||||
|
bty = "n", lwd = par("lwd"), pch = par("pch"))
|
||||||
|
})
|
||||||
|
|
||||||
|
# reset plotting default prameters
|
||||||
|
par(par.default)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
timer.env <- new.env()
|
||||||
|
start.timer <- function() {
|
||||||
|
assign("start.time", proc.time()[["elapsed"]], envir = timer.env)
|
||||||
|
}
|
||||||
|
clear.timer <- function() {
|
||||||
|
assign("total.time", 0, envir = timer.env)
|
||||||
|
}
|
||||||
|
end.timer <- function() {
|
||||||
|
end.time <- proc.time()[["elapsed"]]
|
||||||
|
start.time <- get("start.time", envir = timer.env)
|
||||||
|
total.time <- get0("total.time", envir = timer.env)
|
||||||
|
if (is.null(total.time)) {
|
||||||
|
total.time <- 0
|
||||||
|
}
|
||||||
|
elapsed <- end.time - start.time
|
||||||
|
total.time <- total.time + elapsed
|
||||||
|
assign("total.time", total.time, envir = timer.env)
|
||||||
|
c(elapsed = elapsed, total.time = total.time)
|
||||||
|
}
|
|
@ -58,7 +58,6 @@ export(kpir.momentum)
|
||||||
export(kpir.new)
|
export(kpir.new)
|
||||||
export(kronperm)
|
export(kronperm)
|
||||||
export(mat)
|
export(mat)
|
||||||
export(matProj)
|
|
||||||
export(matpow)
|
export(matpow)
|
||||||
export(matrixImage)
|
export(matrixImage)
|
||||||
export(mcov)
|
export(mcov)
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
#' A simple higher order (multi-way) canonical correlation analysis.
|
||||||
|
#'
|
||||||
|
#' @param X multi-dimensional array
|
||||||
|
#' @param Y multi-dimensional array with the same nr. of dimensions and equal
|
||||||
|
#' sample axis to `X`.
|
||||||
|
#' @param sample.axis integer indicationg which axis enumerates observations
|
||||||
|
#'
|
||||||
|
#' @export
|
||||||
|
HOCCA <- function(X, Y, sample.axis = length(dim(X)), centerX = TRUE, centerY = TRUE) {
|
||||||
|
|
||||||
|
# ensure sample axis is the last axis
|
||||||
|
if (!missing(sample.axis)) {
|
||||||
|
modes <- seq_along(dim(X))[-sample.axis]
|
||||||
|
X <- aperm(X, c(modes, sample.axis))
|
||||||
|
Y <- aperm(Y, c(modes, sample.axis))
|
||||||
|
}
|
||||||
|
modes <- seq_len(length(dim(X)) - 1L)
|
||||||
|
dimX <- head(dim(X), -1L)
|
||||||
|
dimF <- head(dim(F), -1L)
|
||||||
|
sample.size <- tail(dim(X), 1L)
|
||||||
|
|
||||||
|
# center `X` and `Y`
|
||||||
|
if (centerX) {
|
||||||
|
X <- X - as.vector(rowMeans(X, dims = length(dim(X)) - 1L))
|
||||||
|
}
|
||||||
|
if (centerY) {
|
||||||
|
Y <- Y - as.vector(rowMeans(Y, dims = length(dim(Y)) - 1L))
|
||||||
|
}
|
||||||
|
|
||||||
|
# estimate marginal covariance matrices
|
||||||
|
CovXX <- Map(function(mode) mcrossprod(X, mode = mode) / prod(dim(X)[-mode]), modes)
|
||||||
|
CovYY <- Map(function(mode) mcrossprod(Y, mode = mode) / prod(dim(Y)[-mode]), modes)
|
||||||
|
# and the "covariance tensor"
|
||||||
|
CovXY <- array(tcrossprod(mat(X, modes), mat(Y, modes)) / sample.size, c(dimX, dimF))
|
||||||
|
|
||||||
|
# Compute standardized X and Y correlation tensor
|
||||||
|
SCovXY <- mlm(CovXY, Map(matpow, c(CovXX, CovYY), -1 / 2))
|
||||||
|
|
||||||
|
# mode-wise canonical correlation directions
|
||||||
|
hosvd <- HOSVD(SCovXY, nu = rep(pmin(dimX, dimF), 2L))
|
||||||
|
dirsX <- hosvd$Us[modes]
|
||||||
|
dirsY <- hosvd$Us[modes + length(modes)]
|
||||||
|
|
||||||
|
list(dirsX = dirsX, dirsY = dirsY)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# c(dirsX, dirsY) %<-% HOCCA(X, F)
|
||||||
|
# B.hocca <- Reduce(kronecker, rev(Map(tcrossprod, dirsX, dirsY)))
|
||||||
|
# dist.subspace(B.true, B.hocca, normalize = TRUE)
|
||||||
|
# dist.subspace(B.true, Reduce(kronecker, rev(dirsX)), normalize = TRUE)
|
||||||
|
|
||||||
|
# cca <- cancor(mat(X, 4), mat(F, 4))
|
||||||
|
# B.cca <- tcrossprod(cca$xcoef[, prod(dim(X)[-4])], cca$xcoef[, prod(dim(F)[-4])])
|
||||||
|
# dist.subspace(B.true, cca$xcoef[, prod(dim(X)[-4])], normalize = TRUE)
|
|
@ -0,0 +1,36 @@
|
||||||
|
#' Sliced Inverse Regression
|
||||||
|
#'
|
||||||
|
#' @export
|
||||||
|
SIR <- function(X, y, d, nr.slices = 10L, slice.method = c("cut", "ecdf")) {
|
||||||
|
|
||||||
|
if (!(is.factor(y) || is.integer(y))) {
|
||||||
|
slice.method <- match.arg(slice.method)
|
||||||
|
if (slice.method == "ecdf") {
|
||||||
|
y <- cut(ecdf(y)(y), nr.slices)
|
||||||
|
} else {
|
||||||
|
y <- cut(y, nr.slices)
|
||||||
|
# ensure there are no empty slices
|
||||||
|
if (any(table(y) == 0)) {
|
||||||
|
y <- as.factor(as.integer(y))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Center `X`
|
||||||
|
Z <- scale(X, scale = FALSE)
|
||||||
|
|
||||||
|
# Split `Z` into slices determined by `y`
|
||||||
|
slices <- Map(function(i) Z[i, , drop = FALSE], split(seq_along(y), y))
|
||||||
|
|
||||||
|
# Sizes and Means for each slice
|
||||||
|
slice.sizes <- mapply(nrow, slices)
|
||||||
|
slice.means <- Map(colMeans, slices)
|
||||||
|
|
||||||
|
# Inbetween slice covariances
|
||||||
|
sCov <- Reduce(`+`, Map(function(mean_s, n_s) {
|
||||||
|
n_s * tcrossprod(mean_s)
|
||||||
|
}, slice.means, slice.sizes)) / nrow(X)
|
||||||
|
|
||||||
|
# Compute EDR directions
|
||||||
|
La.svd(sCov, d, 0L)$u
|
||||||
|
}
|
|
@ -109,5 +109,5 @@ TSIR <- function(X, y, d, sample.axis = 1L,
|
||||||
|
|
||||||
# reductions matrices `Omega_k^-1 Gamma_k` where there (reverse) kronecker
|
# reductions matrices `Omega_k^-1 Gamma_k` where there (reverse) kronecker
|
||||||
# product spans the central tensor subspace (CTS) estimate
|
# product spans the central tensor subspace (CTS) estimate
|
||||||
Map(solve, Omegas, Gammas)
|
structure(Map(solve, Omegas, Gammas), mcov = Omegas, Gammas = Gammas)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
#' Determinant of a matrix
|
||||||
|
#'
|
||||||
|
#' @export
|
||||||
|
La.det <- function(A) {
|
||||||
|
storage.mode(A) <- "double"
|
||||||
|
.Call("C_det", A, PACKAGE = "tensorPredictors")
|
||||||
|
}
|
|
@ -4,7 +4,7 @@
|
||||||
#'
|
#'
|
||||||
#' @export
|
#' @export
|
||||||
gmlm_ising <- function(X, F, y = NULL, sample.axis = length(dim(X)),
|
gmlm_ising <- function(X, F, y = NULL, sample.axis = length(dim(X)),
|
||||||
proj.betas = NULL, proj.Omegas = NULL,
|
proj.betas = NULL, proj.Omegas = NULL, Omega.mask = NULL,
|
||||||
max.iter = 1000L,
|
max.iter = 1000L,
|
||||||
eps = sqrt(.Machine$double.eps),
|
eps = sqrt(.Machine$double.eps),
|
||||||
step.size = 1e-3,
|
step.size = 1e-3,
|
||||||
|
@ -112,7 +112,7 @@ gmlm_ising <- function(X, F, y = NULL, sample.axis = length(dim(X)),
|
||||||
matX <- mat(X, sample.axis)
|
matX <- mat(X, sample.axis)
|
||||||
degen <- crossprod(matX) == 0
|
degen <- crossprod(matX) == 0
|
||||||
degen.mask <- which(degen)
|
degen.mask <- which(degen)
|
||||||
# If there are degenerate combination, compute an (arbitrary) bound the
|
# If there are degenerate combination, compute an (arbitrary) bound of the
|
||||||
# log odds parameters of those combinations
|
# log odds parameters of those combinations
|
||||||
if (any(degen.mask)) {
|
if (any(degen.mask)) {
|
||||||
degen.ind <- arrayInd(degen.mask, dim(degen))
|
degen.ind <- arrayInd(degen.mask, dim(degen))
|
||||||
|
@ -145,7 +145,7 @@ gmlm_ising <- function(X, F, y = NULL, sample.axis = length(dim(X)),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Initialize mean squared gradients
|
# Initialize mean squared gradients
|
||||||
grad2_betas <- Map(array, 0, Map(dim, betas))
|
grad2_betas <- Map(array, 0, Map(dim, betas))
|
||||||
grad2_Omegas <- Map(array, 0, Map(dim, Omegas))
|
grad2_Omegas <- Map(array, 0, Map(dim, Omegas))
|
||||||
|
|
||||||
# Keep track of the last loss to accumulate loss difference sign changes
|
# Keep track of the last loss to accumulate loss difference sign changes
|
||||||
|
@ -166,6 +166,11 @@ gmlm_ising <- function(X, F, y = NULL, sample.axis = length(dim(X)),
|
||||||
grad_betas <- Map(matrix, 0, dimX, dimF)
|
grad_betas <- Map(matrix, 0, dimX, dimF)
|
||||||
Omega <- Reduce(kronecker, rev(Omegas))
|
Omega <- Reduce(kronecker, rev(Omegas))
|
||||||
|
|
||||||
|
# Mask Omega, that is to enforce the "linear" constraint `T2`
|
||||||
|
if (!is.null(Omega.mask)) {
|
||||||
|
Omega[Omega.mask] <- 0
|
||||||
|
}
|
||||||
|
|
||||||
# second order residuals accumulator
|
# second order residuals accumulator
|
||||||
# `sum_i (X_i o X_i - E[X o X | Y = y_i])`
|
# `sum_i (X_i o X_i - E[X o X | Y = y_i])`
|
||||||
R2 <- array(0, dim = c(dimX, dimX))
|
R2 <- array(0, dim = c(dimX, dimX))
|
||||||
|
@ -186,7 +191,7 @@ gmlm_ising <- function(X, F, y = NULL, sample.axis = length(dim(X)),
|
||||||
# accumulate loss
|
# accumulate loss
|
||||||
matX_i <- mat(eval(`X[..., i]`), modes)
|
matX_i <- mat(eval(`X[..., i]`), modes)
|
||||||
loss <- loss - (
|
loss <- loss - (
|
||||||
sum(matX_i * (params_i %*% matX_i)) + n_i * log(attr(m2_i, "prob_0"))
|
sum(matX_i * (params_i %*% matX_i)) + n_i * attr(m2_i, "log_prob_0")
|
||||||
)
|
)
|
||||||
|
|
||||||
R2_i <- tcrossprod(matX_i) - n_i * m2_i
|
R2_i <- tcrossprod(matX_i) - n_i * m2_i
|
||||||
|
@ -200,6 +205,12 @@ gmlm_ising <- function(X, F, y = NULL, sample.axis = length(dim(X)),
|
||||||
R2 <- R2 + as.vector(R2_i)
|
R2 <- R2 + as.vector(R2_i)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Apply the `T2` constraint on the Residuals as well (refer to `T2`)
|
||||||
|
# That is, we compute G2 from g2 as in Theorem 2.
|
||||||
|
if (!is.null(Omega.mask)) {
|
||||||
|
R2[Omega.mask] <- 0
|
||||||
|
}
|
||||||
|
|
||||||
grad_Omegas <- Map(function(j) {
|
grad_Omegas <- Map(function(j) {
|
||||||
grad <- mlm(kronperm(R2), Map(as.vector, Omegas[-j]), modes[-j], transposed = TRUE)
|
grad <- mlm(kronperm(R2), Map(as.vector, Omegas[-j]), modes[-j], transposed = TRUE)
|
||||||
dim(grad) <- dim(Omegas[[j]])
|
dim(grad) <- dim(Omegas[[j]])
|
||||||
|
|
|
@ -61,8 +61,14 @@ gmlm_tensor_normal <- function(X, F, sample.axis = length(dim(X)),
|
||||||
# Residuals
|
# Residuals
|
||||||
R <- X - mlm(F, Map(`%*%`, Sigmas, betas))
|
R <- X - mlm(F, Map(`%*%`, Sigmas, betas))
|
||||||
|
|
||||||
|
# Numerically more stable version of `sum(log(mapply(det, Omegas)) / dimX)`
|
||||||
|
# which is itself equivalent to `log(det(Omega)) / prod(nrow(Omega))` where
|
||||||
|
# `Omega <- Reduce(kronecker, rev(Omegas))`.
|
||||||
|
det.Omega <- sum(mapply(function(Omega) {
|
||||||
|
sum(log(eigen(Omega, TRUE, TRUE)$values))
|
||||||
|
}, Omegas) / dimX)
|
||||||
# Initial value of the log-likelihood (scaled and constants dropped)
|
# Initial value of the log-likelihood (scaled and constants dropped)
|
||||||
loss <- mean(R * mlm(R, Omegas)) - sum(log(mapply(det, Omegas)) / dimX)
|
loss <- mean(R * mlm(R, Omegas)) - det.Omega
|
||||||
|
|
||||||
# invoke the logger
|
# invoke the logger
|
||||||
if (is.function(logger)) do.call(logger, list(
|
if (is.function(logger)) do.call(logger, list(
|
||||||
|
@ -88,7 +94,7 @@ gmlm_tensor_normal <- function(X, F, sample.axis = length(dim(X)),
|
||||||
# Residuals
|
# Residuals
|
||||||
R <- X - mlm(F, Map(`%*%`, Sigmas, betas))
|
R <- X - mlm(F, Map(`%*%`, Sigmas, betas))
|
||||||
|
|
||||||
# Covariance Estimates (moment based, TODO: implement MLE estimate!)
|
# Covariance Estimates
|
||||||
Sigmas <- mcov(R, sample.axis, center = FALSE)
|
Sigmas <- mcov(R, sample.axis, center = FALSE)
|
||||||
|
|
||||||
# Computing `Omega_j`s, the j'th mode presition matrices, in conjunction
|
# Computing `Omega_j`s, the j'th mode presition matrices, in conjunction
|
||||||
|
@ -111,9 +117,16 @@ gmlm_tensor_normal <- function(X, F, sample.axis = length(dim(X)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# store last loss and compute new value
|
# store last loss
|
||||||
loss.last <- loss
|
loss.last <- loss
|
||||||
loss <- mean(R * mlm(R, Omegas)) - sum(log(mapply(det, Omegas)) / dimX)
|
# Numerically more stable version of `sum(log(mapply(det, Omegas)) / dimX)`
|
||||||
|
# which is itself equivalent to `log(det(Omega)) / prod(nrow(Omega))` where
|
||||||
|
# `Omega <- Reduce(kronecker, rev(Omegas))`.
|
||||||
|
det.Omega <- sum(mapply(function(Omega) {
|
||||||
|
sum(log(eigen(Omega, TRUE, TRUE)$values))
|
||||||
|
}, Omegas) / dimX)
|
||||||
|
# Compute new loss
|
||||||
|
loss <- mean(R * mlm(R, Omegas)) - det.Omega
|
||||||
|
|
||||||
# invoke the logger
|
# invoke the logger
|
||||||
if (is.function(logger)) do.call(logger, list(
|
if (is.function(logger)) do.call(logger, list(
|
||||||
|
|
|
@ -12,43 +12,7 @@ ising_m2 <- function(
|
||||||
)
|
)
|
||||||
|
|
||||||
M2 <- vech.pinv(m2)
|
M2 <- vech.pinv(m2)
|
||||||
attr(M2, "prob_0") <- attr(m2, "prob_0")
|
attr(M2, "log_prob_0") <- attr(m2, "log_prob_0")
|
||||||
|
|
||||||
M2
|
M2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# library(tensorPredictors)
|
|
||||||
|
|
||||||
# dimX <- c(3, 4)
|
|
||||||
# dimF <- rep(1L, length(dimX))
|
|
||||||
|
|
||||||
# betas <- Map(diag, 1, dimX, dimF)
|
|
||||||
# Omegas <- list(
|
|
||||||
# 1 - diag(dimX[1]),
|
|
||||||
# toeplitz(rev(seq(0, len = dimX[2])) / dimX[2])
|
|
||||||
# )
|
|
||||||
# Omega <- Reduce(kronecker, rev(Omegas))
|
|
||||||
|
|
||||||
# y <- array(1, dimF)
|
|
||||||
# params <- diag(as.vector(mlm(y, betas))) + Omega
|
|
||||||
|
|
||||||
# # params <- array(0, dim(Omega))
|
|
||||||
|
|
||||||
# (prob_0 <- attr(ising_m2(params), "prob_0"))
|
|
||||||
# (probs <- replicate(20, attr(ising_m2(params, use_MC = TRUE, nr_threads = 8), "prob_0")))[1]
|
|
||||||
# m <- mean(probs)
|
|
||||||
# s <- sd(probs)
|
|
||||||
|
|
||||||
# (prob_a <- (function(p, M2) {
|
|
||||||
# (1 + p * (p + 1) / 2 + 2 * sum(M2) - 2 * (p + 1) * sum(diag(M2))) / 2^p
|
|
||||||
# })(prod(dimX), ising_m2(params, use_MC = FALSE)))
|
|
||||||
|
|
||||||
# par(mar = c(1, 2, 1, 2) + 0.1)
|
|
||||||
# plot(probs, ylim = pmax(0, range(probs, prob_0, prob_a)), pch = 16, cex = 1,
|
|
||||||
# xaxt = "n", xlab = "", col = "gray", log = "y", bty = "n")
|
|
||||||
# lines(cumsum(probs) / seq_along(probs), lty = 2, lwd = 2)
|
|
||||||
# abline(h = c(m - s, m, m + s), lty = c(3, 2, 3), col = "red", lwd = 2)
|
|
||||||
# abline(h = c(prob_0, prob_a), lwd = 2)
|
|
||||||
# axis(4, at = prob_0, labels = sprintf("%.1e", prob_0))
|
|
||||||
# axis(4, at = prob_a, labels = sprintf("%.1e", prob_a))
|
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
#' Kronecker Permutation of an array
|
||||||
|
#'
|
||||||
|
#' Computes a permutation and reshaping of `A` such that
|
||||||
|
#' kronperm(B %o% C) == kronecker(B, C)
|
||||||
|
#'
|
||||||
|
#' @param A multi-dimensional array
|
||||||
|
#' @param dims dimensions `A` should have overwriting the actuall dimensions
|
||||||
|
#' @param ncomp number of "components" counting the elements of an outer product
|
||||||
|
#' used to generate `A` if it is the result of an outer product.
|
||||||
|
#'
|
||||||
|
#' @examples
|
||||||
|
#' A <- array(rnorm(24), dim = c(2, 3, 4))
|
||||||
|
#' B <- array(rnorm(15), dim = c(5, 3, 1))
|
||||||
|
#' C <- array(rnorm(84), dim = c(7, 4, 3))
|
||||||
|
#'
|
||||||
|
#' all.equal(
|
||||||
|
#' kronperm(outer(A, B)),
|
||||||
|
#' kronecker(A, B)
|
||||||
|
#' )
|
||||||
|
#' all.equal(
|
||||||
|
#' kronperm(Reduce(outer, list(A, B, C)), ncomp = 3L),
|
||||||
|
#' Reduce(kronecker, list(A, B, C))
|
||||||
|
#' )
|
||||||
|
#'
|
||||||
|
#' @export
|
||||||
|
kronperm <- function(A, dims = dim(A), ncomp = 2L) {
|
||||||
|
# force `A` to have a multiple of `ncomp` dimensions
|
||||||
|
dim(A) <- c(dims, rep(1L, length(dims) %% ncomp))
|
||||||
|
# compute axis permutation
|
||||||
|
perm <- as.vector(t(matrix(seq_along(dim(A)), ncol = ncomp)[, ncomp:1]))
|
||||||
|
# permute elements of A
|
||||||
|
K <- aperm(A, perm, resize = FALSE)
|
||||||
|
# collapse/set dimensions
|
||||||
|
dim(K) <- apply(matrix(dim(K), ncol = ncomp), 1, prod)
|
||||||
|
K
|
||||||
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
#' @rdname matProj
|
||||||
|
#' @export
|
||||||
|
projSym <- function(A) 0.5 * (A + t(A))
|
||||||
|
#' @rdname matProj
|
||||||
|
#' @export
|
||||||
|
projDiag <- function(A) diag(diag(A))
|
||||||
|
#' @rdname matProj
|
||||||
|
#' @export
|
||||||
|
.projBand <- function(dims, low, high) {
|
||||||
|
diag.index <- .row(dims) - .col(dims)
|
||||||
|
mask <- (diag.index <= low) & (-high <= diag.index)
|
||||||
|
function(A) A * mask
|
||||||
|
}
|
||||||
|
#' @rdname matProj
|
||||||
|
#' @export
|
||||||
|
.projSymBand <- function(dims, low, high) {
|
||||||
|
diag.index <- .row(dims) - .col(dims)
|
||||||
|
mask <- (diag.index <= low) & (-high <= diag.index)
|
||||||
|
function(A) projSym(A) * mask
|
||||||
|
}
|
||||||
|
#' @rdname matProj
|
||||||
|
#' @export
|
||||||
|
.projPSD <- function(sym = FALSE) {
|
||||||
|
if (sym) {
|
||||||
|
function(A) {
|
||||||
|
eig <- eigen(A, symmetric = TRUE)
|
||||||
|
eig$vectors %*% (pmax(0, eig$values) * t(eig$vectors))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
function(A) {
|
||||||
|
eig <- eigen(0.5 * (A + t(A)), symmetric = TRUE)
|
||||||
|
eig$vectors %*% (pmax(0, eig$values) * t(eig$vectors))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#' @rdname matProj
|
||||||
|
#' @export
|
||||||
|
.projRank <- function(rank) {
|
||||||
|
force(rank)
|
||||||
|
function(A) {
|
||||||
|
rank <- min(dim(A), rank)
|
||||||
|
svdA <- La.svd(A, rank, rank)
|
||||||
|
svdA$u %*% (svdA$d[seq_len(rank)] * svdA$vt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#' @rdname matProj
|
||||||
|
#' @export
|
||||||
|
.projSymRank <- function(rank) {
|
||||||
|
force(rank)
|
||||||
|
function(A) {
|
||||||
|
rank <- min(dim(A), rank)
|
||||||
|
svdA <- La.svd(0.5 * (A + t(A)), rank, rank)
|
||||||
|
svdA$u %*% (svdA$d[seq_len(rank)] * svdA$vt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#' @rdname matProj
|
||||||
|
#' @export
|
||||||
|
projStiefel <- function(A) {
|
||||||
|
# Using a polar decomposition of `A = Q P` via SVD `A = U D V^T`. Compaired
|
||||||
|
# to a QR decomposition the polar decomposition is unique, making it "stabel".
|
||||||
|
svdA <- La.svd(A)
|
||||||
|
svdA$u %*% svdA$vt # = Q
|
||||||
|
}
|
||||||
|
# .projKron <- function(dims) {
|
||||||
|
# ... # TODO: Implement this!
|
||||||
|
# }
|
||||||
|
|
||||||
|
#' @rdname matProj
|
||||||
|
#' @export
|
||||||
|
.projMaskedMean <- function(mask) {
|
||||||
|
force(mask)
|
||||||
|
function(A) {
|
||||||
|
`[<-`(matrix(0, nrow(A), ncol(A)), mask, mean(A[mask]))
|
||||||
|
}
|
||||||
|
}
|
|
@ -47,8 +47,10 @@ matrixImage <- function(A, add.values = FALSE,
|
||||||
x <- seq(1, ncol(A), by = 1)
|
x <- seq(1, ncol(A), by = 1)
|
||||||
y <- seq(1, nrow(A))
|
y <- seq(1, nrow(A))
|
||||||
if (axes && new.plot) {
|
if (axes && new.plot) {
|
||||||
axis(1, at = x - 0.5, labels = x, lwd = 0, lwd.ticks = 1)
|
if (!is.character(xlabels <- colnames(A))) { xlabels <- x }
|
||||||
axis(2, at = y - 0.5, labels = rev(y), lwd = 0, lwd.ticks = 1, las = 1)
|
if (!is.character(ylabels <- rownames(A))) { ylabels <- y }
|
||||||
|
axis(1, at = x - 0.5, labels = xlabels, lwd = 0, lwd.ticks = 1)
|
||||||
|
axis(2, at = y - 0.5, labels = rev(ylabels), lwd = 0, lwd.ticks = 1, las = 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Writes matrix values
|
# Writes matrix values
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
#' Moore-Penrose Pseudo inverse
|
||||||
|
#'
|
||||||
|
#' @param A any matrix
|
||||||
|
#'
|
||||||
|
#' @returns another matrix
|
||||||
|
#'
|
||||||
|
#' @export
|
||||||
|
pinv <- function(A) {
|
||||||
|
A <- as.matrix(A)
|
||||||
|
if (nrow(A) < ncol(A)) {
|
||||||
|
crossprod(A, matpow(tcrossprod(A), -1))
|
||||||
|
} else if (nrow(A) > ncol(A)) {
|
||||||
|
tcrossprod(matpow(crossprod(A), -1), A)
|
||||||
|
} else {
|
||||||
|
matpow(A, -1)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,67 @@
|
||||||
|
#' Slice index selection
|
||||||
|
#'
|
||||||
|
#' @examples
|
||||||
|
#' # Exquivalent to
|
||||||
|
#' array(A[slice.index(A, mode) == index], dim = dim(A)[-mode])
|
||||||
|
#'
|
||||||
|
#' @export
|
||||||
|
slice.select <- function(A, mode, index) {
|
||||||
|
arg <- rep("", length(dim(A)))
|
||||||
|
arg[mode] <- "i"
|
||||||
|
expr <- str2lang(paste0("A[", paste0(arg, collapse = ","), "]", collapse = ""))
|
||||||
|
slice <- eval(expr, list(i = index))
|
||||||
|
dim(slice) <- dim(A)[-mode]
|
||||||
|
slice
|
||||||
|
}
|
||||||
|
|
||||||
|
#'
|
||||||
|
#' @export
|
||||||
|
slice.expr <- function(A, mode, index = "i", drop = TRUE, nr.axis = length(dim(A))) {
|
||||||
|
str <- as.character(substitute(A))
|
||||||
|
arg <- rep("", nr.axis)
|
||||||
|
arg[mode] <- as.character(substitute(index))
|
||||||
|
str2lang(paste0(str, "[", paste0(arg, collapse = ","),
|
||||||
|
if (drop) "]" else ",drop=FALSE]", collapse = ""))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#' @export
|
||||||
|
slice.assign.expr <- function(obj, nr.axis) {
|
||||||
|
assign.call <- as.call(c(
|
||||||
|
list(`[<-`, substitute(obj)),
|
||||||
|
rep(list(alist(a = )$a), nr.axis - 1L), # replicate empty symbol
|
||||||
|
substitute(index), substitute(x)
|
||||||
|
))
|
||||||
|
function(i, val) {
|
||||||
|
eval(assign.call, envir = list(index = i, x = val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# n <- 1000
|
||||||
|
# p <- c(2, 4, 3)
|
||||||
|
# A <- array(seq_len(prod(n, p)), dim = c(p, n))
|
||||||
|
|
||||||
|
# mode <- 4
|
||||||
|
# index <- 7
|
||||||
|
|
||||||
|
# stopifnot(all.equal(
|
||||||
|
# A[, , , index],
|
||||||
|
# array(A[slice.index(A, mode) == index], dim = dim(A)[-mode])
|
||||||
|
# ))
|
||||||
|
# stopifnot(all.equal(
|
||||||
|
# A[, , , index],
|
||||||
|
# slice.select(A, mode, index)
|
||||||
|
# ))
|
||||||
|
|
||||||
|
# arg <- rep("", length(dim(A)))
|
||||||
|
# arg[mode] <- "i"
|
||||||
|
# `A[..., i]` <- str2lang(paste0("A[", paste0(arg, collapse = ","), "]", collapse = ""))
|
||||||
|
|
||||||
|
# microbenchmark::microbenchmark(
|
||||||
|
# A[, , , index],
|
||||||
|
# eval(`A[..., i]`, list(i = index)),
|
||||||
|
# slice.select(A, mode, index),
|
||||||
|
# array(A[slice.index(A, mode) == index], dim = dim(A)[-mode])
|
||||||
|
# )
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
#' Linear Equation Solver
|
||||||
|
#'
|
||||||
|
#' Using Lapack DGESV, similar to the base solve routine of R.
|
||||||
|
#'
|
||||||
|
#' @note for testing purposes
|
||||||
|
#'
|
||||||
|
#' @export
|
||||||
|
La.solve <- function(A, B = diag(nrow(A))) {
|
||||||
|
storage.mode(A) <- storage.mode(B) <- "double"
|
||||||
|
.Call("C_solve", A, as.matrix(B), PACKAGE = "tensorPredictors")
|
||||||
|
}
|
|
@ -0,0 +1,81 @@
|
||||||
|
#' Generale a matrix of all permutations of `n` elements
|
||||||
|
permutations <- function(n) {
|
||||||
|
if (n <= 0) {
|
||||||
|
matrix(nrow = 0, ncol = 0)
|
||||||
|
} else if (n == 1) {
|
||||||
|
matrix(1)
|
||||||
|
} else {
|
||||||
|
sub.perm <- permutations(n - 1)
|
||||||
|
p <- nrow(sub.perm)
|
||||||
|
A <- matrix(NA, n * p, n)
|
||||||
|
for (i in 1:n) {
|
||||||
|
A[(i - 1) * p + 1:p, ] <- cbind(i, sub.perm + (sub.perm >= i))
|
||||||
|
}
|
||||||
|
A
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#' General symmetrization opperation for tensors (arrays) of equal dimensions
|
||||||
|
#'
|
||||||
|
#' @param A array of dimensions c(p, ..., p)
|
||||||
|
#'
|
||||||
|
#' @returns array of same dimensions as `A`
|
||||||
|
#'
|
||||||
|
#' @export
|
||||||
|
tsym <- function(A) {
|
||||||
|
stopifnot(all(dim(A) == nrow(A)))
|
||||||
|
|
||||||
|
if (is.matrix(A)) {
|
||||||
|
return(0.5 * (A + t(A)))
|
||||||
|
}
|
||||||
|
|
||||||
|
axis.perm <- permutations(length(dim(A)))
|
||||||
|
|
||||||
|
S <- array(0, dim(A))
|
||||||
|
for (i in seq_len(nrow(axis.perm))) {
|
||||||
|
S <- S + aperm(A, axis.perm[i, ])
|
||||||
|
}
|
||||||
|
|
||||||
|
S / nrow(axis.perm)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#' Genralized (pseudo) symmetrication for generel multi-dimensional arrays
|
||||||
|
sym <- function(A, FUN = `+`, scale = factorial(length(dim(A)))) {
|
||||||
|
FUN <- match.fun(FUN)
|
||||||
|
|
||||||
|
if (is.matrix(A) && (nrow(A) == ncol(A))) {
|
||||||
|
A <- FUN(A, t(A))
|
||||||
|
return(if (is.numeric(scale)) A / scale else A)
|
||||||
|
}
|
||||||
|
|
||||||
|
A.copy <- A
|
||||||
|
perm <- seq_along(dim(A))
|
||||||
|
while (length(pivot <- which(diff(perm) > 0))) {
|
||||||
|
pivot <- max(pivot)
|
||||||
|
successor <- max(which(perm[seq_along(perm) > pivot] > perm[pivot])) + pivot
|
||||||
|
perm[c(pivot, successor)] <- perm[c(successor, pivot)]
|
||||||
|
suffix <- seq(pivot + 1, length(perm))
|
||||||
|
perm <- c(perm[-suffix], perm[rev(suffix)])
|
||||||
|
|
||||||
|
modes <- which(perm != seq_along(perm))
|
||||||
|
sub.dimA <- dim(A)
|
||||||
|
sub.dimA[modes] <- min(dim(A)[modes])
|
||||||
|
sub.indices <- Map(seq_len, sub.dimA)
|
||||||
|
sub.selection <- do.call(`[`, c(list(A.copy), sub.indices, drop = FALSE))
|
||||||
|
|
||||||
|
|
||||||
|
sub.assign <- do.call(call, c(list("[<-", quote(quote(A))), sub.indices,
|
||||||
|
quote(quote(FUN(
|
||||||
|
do.call(`[`, c(list(A), sub.indices)),
|
||||||
|
aperm(do.call(`[`, c(list(A.copy), sub.indices, drop = FALSE)), perm)
|
||||||
|
)))
|
||||||
|
))
|
||||||
|
|
||||||
|
A <- eval(sub.assign)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is.numeric(scale)) A / scale else A
|
||||||
|
}
|
|
@ -0,0 +1,73 @@
|
||||||
|
#include "det.h"
|
||||||
|
|
||||||
|
|
||||||
|
// For reference see: `det_ge_real` in "R-4.2.1/src/modules/lapack/Lapack.c"
|
||||||
|
// of the R source code.
|
||||||
|
double det(
|
||||||
|
/* dim */ const int dimA,
|
||||||
|
/* matrix */ const double* A, const int ldA,
|
||||||
|
double* work_mem, int* info
|
||||||
|
) {
|
||||||
|
// if working memory size query, return immediately
|
||||||
|
if (work_mem == NULL) {
|
||||||
|
*info = dimA * (dimA + 1);
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// determinant of "zero size" matrix is 1 (by definition)
|
||||||
|
if (dimA == 0) {
|
||||||
|
return 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy `A` to (continuous) `work_mem` cause `dgetrf` works "in place"
|
||||||
|
for (int i = 0; i < dimA; ++i) {
|
||||||
|
memcpy(work_mem + i * dimA, A + i * ldA, dimA * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
// L U factorization of `A`
|
||||||
|
int error = 0;
|
||||||
|
int* ipvt = (int*)(work_mem + dimA * dimA);
|
||||||
|
F77_CALL(dgetrf)(&dimA, &dimA, work_mem, &dimA, ipvt, &error);
|
||||||
|
|
||||||
|
// check if an error occured which is the case iff `dgetrf` gives a negative error
|
||||||
|
*info |= (error < 0) * error;
|
||||||
|
// in both cases we return zero (ether the determinant is zero or error occured)
|
||||||
|
if (error) {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// res <- det(A) = sign(P) * prod(diag(L)) where P is the pivoting permutation
|
||||||
|
double res = 1.0;
|
||||||
|
for (int i = 0; i < dimA; ++i) {
|
||||||
|
res *= (ipvt[i] != (i + 1) ? -1.0 : 1.0) * work_mem[i * (dimA + 1)];
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* R bindong to `det`
|
||||||
|
*/
|
||||||
|
extern SEXP R_det(SEXP A) {
|
||||||
|
// check if A is a real valued square matrix
|
||||||
|
if (!Rf_isReal(A) || !Rf_isMatrix(A) || Rf_nrows(A) != Rf_ncols(A)) {
|
||||||
|
Rf_error("`A` must be a real valued squae matrix");
|
||||||
|
}
|
||||||
|
|
||||||
|
// allocate working memory
|
||||||
|
int work_size;
|
||||||
|
(void)det(Rf_nrows(A), NULL, Rf_nrows(A), NULL, &work_size);
|
||||||
|
double* work_mem = (double*)R_alloc(work_size, sizeof(double));
|
||||||
|
|
||||||
|
// compute determinant (followed only by if-statement, no protection required)
|
||||||
|
int error = 0;
|
||||||
|
SEXP res = Rf_ScalarReal(
|
||||||
|
det(Rf_nrows(A), REAL(A), Rf_nrows(A), work_mem, &error)
|
||||||
|
);
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
Rf_error("Encountered error code %d in `det`", error);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
#ifndef INCLUDE_GUARD_DET_H
|
||||||
|
#define INCLUDE_GUARD_DET_H
|
||||||
|
|
||||||
|
#include "R_api.h"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Determinant of a matrix (or log of determinant)
|
||||||
|
*/
|
||||||
|
double det(
|
||||||
|
/* dim */ const int dimA,
|
||||||
|
/* matrix */ const double* A, const int ldA,
|
||||||
|
double* work_mem, int* info
|
||||||
|
);
|
||||||
|
|
||||||
|
#endif /* INCLUDE_GUARD_DET_H */
|
|
@ -27,6 +27,8 @@
|
||||||
* [ * * * * * * ] [ * * * 17 * * ] [ * ]
|
* [ * * * * * * ] [ * * * 17 * * ] [ * ]
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
#include <float.h> // DBL_MAX
|
||||||
|
|
||||||
#include "R_api.h"
|
#include "R_api.h"
|
||||||
#include "bit_utils.h"
|
#include "bit_utils.h"
|
||||||
#include "int_utils.h"
|
#include "int_utils.h"
|
||||||
|
@ -121,7 +123,7 @@ double ising_m2_exact(const size_t dim, const double* params, double* M2) {
|
||||||
const double prob_X = exp(dot_X);
|
const double prob_X = exp(dot_X);
|
||||||
sum_0 += prob_X;
|
sum_0 += prob_X;
|
||||||
|
|
||||||
// Accumulate set bits probability for the first end second moment `E[X X']`
|
// Accumulate set bits probability for the first and second moment `E[X X']`
|
||||||
for (uint32_t Y = X; Y; Y &= Y - 1) {
|
for (uint32_t Y = X; Y; Y &= Y - 1) {
|
||||||
const int i = bitScanLS32(Y);
|
const int i = bitScanLS32(Y);
|
||||||
const int I = (i * (2 * dim - 1 - i)) / 2;
|
const int I = (i * (2 * dim - 1 - i)) / 2;
|
||||||
|
@ -137,7 +139,7 @@ double ising_m2_exact(const size_t dim, const double* params, double* M2) {
|
||||||
M2[i] *= prob_0;
|
M2[i] *= prob_0;
|
||||||
}
|
}
|
||||||
|
|
||||||
return prob_0;
|
return log(prob_0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -162,7 +164,7 @@ double ising_m2_MC(
|
||||||
int* X = (int*)R_alloc(dim, sizeof(int));
|
int* X = (int*)R_alloc(dim, sizeof(int));
|
||||||
|
|
||||||
// Accumulator for Monte-Carlo estimate for zero probability `P(X = 0)`
|
// Accumulator for Monte-Carlo estimate for zero probability `P(X = 0)`
|
||||||
double accum = 0.0;
|
double accum = 0.0, max_mdot_X = -DBL_MAX;
|
||||||
|
|
||||||
// Create/Update R's internal PRGN state
|
// Create/Update R's internal PRGN state
|
||||||
GetRNGstate();
|
GetRNGstate();
|
||||||
|
@ -182,7 +184,11 @@ double ising_m2_MC(
|
||||||
dot_X += (X[i] & X[j]) ? params[I + j] : 0.0;
|
dot_X += (X[i] & X[j]) ? params[I + j] : 0.0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
accum += exp(-dot_X);
|
if (-dot_X > max_mdot_X) {
|
||||||
|
accum *= exp(max_mdot_X + dot_X);
|
||||||
|
max_mdot_X = -dot_X;
|
||||||
|
}
|
||||||
|
accum += exp(-dot_X - max_mdot_X);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write R's internal PRNG state back (if needed)
|
// Write R's internal PRNG state back (if needed)
|
||||||
|
@ -190,11 +196,11 @@ double ising_m2_MC(
|
||||||
|
|
||||||
// Compute means from counts
|
// Compute means from counts
|
||||||
for (size_t i = 0; i < len; ++i) {
|
for (size_t i = 0; i < len; ++i) {
|
||||||
M2[i] = (double)counts[i] / (double)(nr_samples);
|
M2[i] = (double)counts[i] / (double)nr_samples;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prob. of zero event (Ising p.m.f. scaling constant)
|
// Prob. of zero event (Ising p.m.f. scaling constant)
|
||||||
return accum / (exp2((double)dim) * (double)nr_samples);
|
return log(accum) + max_mdot_X - log(2) * (double)dim - log((double)nr_samples);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -210,7 +216,7 @@ typedef struct thrd_data {
|
||||||
const double* params; // Ising model parameters
|
const double* params; // Ising model parameters
|
||||||
int* X; // Working memory to store current binary sample
|
int* X; // Working memory to store current binary sample
|
||||||
uint32_t* counts; // (output) count of single and two way interactions
|
uint32_t* counts; // (output) count of single and two way interactions
|
||||||
double accum; // (output) Monte-Carlo accumulator for `P(X = 0)`
|
double log_sum_exp; // (output) Monte-Carlo inbetween value for `log(P(X = 0))`
|
||||||
} thrd_data_t;
|
} thrd_data_t;
|
||||||
|
|
||||||
// Worker thread function
|
// Worker thread function
|
||||||
|
@ -222,8 +228,8 @@ int thrd_worker(thrd_data_t* data) {
|
||||||
// Initialize counts to zero
|
// Initialize counts to zero
|
||||||
(void)memset(data->counts, 0, (dim * (dim + 1) / 2) * sizeof(uint32_t));
|
(void)memset(data->counts, 0, (dim * (dim + 1) / 2) * sizeof(uint32_t));
|
||||||
|
|
||||||
// Init Monte-Carlo estimate for zero probability `P(X = 0)`
|
// Accumulator for Monte-Carlo estimate for zero probability `P(X = 0)`
|
||||||
data->accum = 0.0;
|
double accum = 0.0, max_mdot_X = -DBL_MAX;
|
||||||
|
|
||||||
// Spawn Monte-Carlo Chains (one foe every sample)
|
// Spawn Monte-Carlo Chains (one foe every sample)
|
||||||
for (size_t sample = 0; sample < data->nr_samples; ++sample) {
|
for (size_t sample = 0; sample < data->nr_samples; ++sample) {
|
||||||
|
@ -240,8 +246,14 @@ int thrd_worker(thrd_data_t* data) {
|
||||||
dot_X += (X[i] & X[j]) ? data->params[I + j] : 0.0;
|
dot_X += (X[i] & X[j]) ? data->params[I + j] : 0.0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
data->accum += exp(-dot_X);
|
if (-dot_X > max_mdot_X) {
|
||||||
}
|
accum *= exp(max_mdot_X + dot_X);
|
||||||
|
max_mdot_X = -dot_X;
|
||||||
|
}
|
||||||
|
accum += exp(-dot_X - max_mdot_X);
|
||||||
|
}
|
||||||
|
|
||||||
|
data->log_sum_exp = log(accum) + max_mdot_X;
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -286,13 +298,19 @@ double ising_m2_MC_thrd(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accumulate worker results into first (tid = 0) worker counts result
|
// Accumulate worker results into first (tid = 0) worker counts result
|
||||||
// as well as accumulate the Monte-Carlo accumulators into one
|
// while determining the log sum exp maximum
|
||||||
double accum = threads_data[0].accum;
|
double max = threads_data[0].log_sum_exp;
|
||||||
for (size_t tid = 1; tid < nr_threads; ++tid) {
|
for (size_t tid = 1; tid < nr_threads; ++tid) {
|
||||||
for (size_t i = 0; i < len; ++i) {
|
for (size_t i = 0; i < len; ++i) {
|
||||||
counts[i] += counts[tid * len + i];
|
counts[i] += counts[tid * len + i];
|
||||||
}
|
}
|
||||||
accum += threads_data[tid].accum;
|
max = (max < threads_data[tid].log_sum_exp) ? threads_data[tid].log_sum_exp : max;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accum all `log(P(X = 0))` via "LogSumExp" trick into final result
|
||||||
|
double accum = 0;
|
||||||
|
for (size_t tid = 0; tid < nr_threads; ++tid) {
|
||||||
|
accum += exp(threads_data[tid].log_sum_exp - max);
|
||||||
}
|
}
|
||||||
|
|
||||||
// convert discreat counts into means
|
// convert discreat counts into means
|
||||||
|
@ -300,8 +318,8 @@ double ising_m2_MC_thrd(
|
||||||
M2[i] = (double)counts[i] / (double)nr_samples;
|
M2[i] = (double)counts[i] / (double)nr_samples;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prob. of zero event (Ising p.m.f. scaling constant)
|
// Log of Prob. of zero event (Ising p.m.f. scaling constant)
|
||||||
return accum / (exp2((double)dim) * (double)nr_samples);
|
return log(accum) + max - log(2) * (double)dim - log((double)nr_samples);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif /* !__STDC_NO_THREADS__ */
|
#endif /* !__STDC_NO_THREADS__ */
|
||||||
|
@ -363,9 +381,9 @@ extern SEXP R_ising_m2(
|
||||||
SEXP _M2 = PROTECT(Rf_allocVector(REALSXP, dim * (dim + 1) / 2));
|
SEXP _M2 = PROTECT(Rf_allocVector(REALSXP, dim * (dim + 1) / 2));
|
||||||
++protect_count;
|
++protect_count;
|
||||||
|
|
||||||
// asside computed zero event probability (inverse partition function), the
|
// asside computed log of zero event probability (log of recibrocal partition
|
||||||
// scaling factor for the Ising model p.m.f.
|
// function), the log scaling factor for the Ising model p.m.f.
|
||||||
double prob_0 = -1.0;
|
double log_prob_0 = -1.0;
|
||||||
|
|
||||||
if (use_MC) {
|
if (use_MC) {
|
||||||
// Convert and validate arguments for the Monte-Carlo methods
|
// Convert and validate arguments for the Monte-Carlo methods
|
||||||
|
@ -384,15 +402,15 @@ extern SEXP R_ising_m2(
|
||||||
|
|
||||||
if (nr_threads == 1) {
|
if (nr_threads == 1) {
|
||||||
// Single threaded Monte-Carlo method
|
// Single threaded Monte-Carlo method
|
||||||
prob_0 = ising_m2_MC(nr_samples, warmup, dim, params, REAL(_M2));
|
log_prob_0 = ising_m2_MC(nr_samples, warmup, dim, params, REAL(_M2));
|
||||||
} else {
|
} else {
|
||||||
// Multi-Threaded Monte-Carlo method if provided, otherwise use
|
// Multi-Threaded Monte-Carlo method if provided, otherwise use
|
||||||
// the single threaded version with a warning
|
// the single threaded version with a warning
|
||||||
#ifdef __STDC_NO_THREADS__
|
#ifdef __STDC_NO_THREADS__
|
||||||
Rf_warning("Multi-Threading NOT supported, using fallback.");
|
Rf_warning("Multi-Threading NOT supported, using fallback.");
|
||||||
prob_0 = ising_m2_MC(nr_samples, warmup, dim, params, REAL(_M2));
|
log_prob_0 = ising_m2_MC(nr_samples, warmup, dim, params, REAL(_M2));
|
||||||
#else
|
#else
|
||||||
prob_0 = ising_m2_MC_thrd(
|
log_prob_0 = ising_m2_MC_thrd(
|
||||||
nr_samples, warmup, nr_threads,
|
nr_samples, warmup, nr_threads,
|
||||||
dim, params, REAL(_M2)
|
dim, params, REAL(_M2)
|
||||||
);
|
);
|
||||||
|
@ -405,13 +423,13 @@ extern SEXP R_ising_m2(
|
||||||
}
|
}
|
||||||
|
|
||||||
// and call the exact method
|
// and call the exact method
|
||||||
prob_0 = ising_m2_exact(dim, params, REAL(_M2));
|
log_prob_0 = ising_m2_exact(dim, params, REAL(_M2));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set log-lokelihood as an attribute to the computed second moment
|
// Set log-lokelihood as an attribute to the computed second moment
|
||||||
SEXP _prob_0 = PROTECT(Rf_ScalarReal(prob_0));
|
SEXP _log_prob_0 = PROTECT(Rf_ScalarReal(log_prob_0));
|
||||||
++protect_count;
|
++protect_count;
|
||||||
Rf_setAttrib(_M2, Rf_install("prob_0"), _prob_0);
|
Rf_setAttrib(_M2, Rf_install("log_prob_0"), _log_prob_0);
|
||||||
|
|
||||||
// release SEPXs to the garbage collector
|
// release SEPXs to the garbage collector
|
||||||
UNPROTECT(protect_count);
|
UNPROTECT(protect_count);
|
||||||
|
|
|
@ -0,0 +1,100 @@
|
||||||
|
#include "solve.h"
|
||||||
|
|
||||||
|
void solve(
|
||||||
|
/* dims */ const int dimA, const int nrhs,
|
||||||
|
/* matrix */ const double* A, const int ldA,
|
||||||
|
/* matrix */ const double* B, const int ldB,
|
||||||
|
/* matrix */ double* X, const int ldX,
|
||||||
|
double* work_mem, int* info
|
||||||
|
) {
|
||||||
|
// Compute required working memory size if requested
|
||||||
|
if (work_mem == NULL) {
|
||||||
|
*info = dimA * (dimA + 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy `A` to (continuous) working memory
|
||||||
|
for (int i = 0; i < dimA; ++i) {
|
||||||
|
memcpy(work_mem + i * dimA, A + i * ldA, dimA * sizeof(double));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy `B` to `X` or set `X` to identity
|
||||||
|
if (B == NULL) {
|
||||||
|
double* X_col = X;
|
||||||
|
for (int j = 0; j < dimA; ++j, X_col += ldX) {
|
||||||
|
for (int i = 0; i < dimA; ++i) {
|
||||||
|
*(X_col + i) = (i == j) ? 1.0 : 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < nrhs; ++i) {
|
||||||
|
memcpy(X + i * ldX, B + i * ldB, dimA * sizeof(double));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lapack routine DGESV to solve linear system A X = B which writes
|
||||||
|
// result into `A`, `B` which are copied into working memory and the result
|
||||||
|
// memory `X`
|
||||||
|
int error = 0;
|
||||||
|
F77_CALL(dgesv)(
|
||||||
|
/* dims */ &dimA, &nrhs,
|
||||||
|
/* matrix A */ work_mem, &dimA, /* [in,out] A -> P L U */
|
||||||
|
/* ipiv */ (int*)(work_mem + dimA * dimA), /* [out] */
|
||||||
|
/* matrix B */ X, &ldX, /* [in,out] B -> X */
|
||||||
|
&error /* [out] */
|
||||||
|
);
|
||||||
|
|
||||||
|
// update error flag
|
||||||
|
*info |= error;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* R binding to `solve` which solves A X = B for X
|
||||||
|
*/
|
||||||
|
extern SEXP R_solve(SEXP A, SEXP B) {
|
||||||
|
// Check types
|
||||||
|
if (!(Rf_isReal(A) && Rf_isMatrix(A))
|
||||||
|
|| !(Rf_isReal(B) && Rf_isMatrix(B))) {
|
||||||
|
Rf_error("All arguments must be real valued matrices");
|
||||||
|
}
|
||||||
|
|
||||||
|
// check dimensions
|
||||||
|
if (Rf_nrows(A) != Rf_ncols(A)
|
||||||
|
|| Rf_ncols(A) != Rf_nrows(B)) {
|
||||||
|
Rf_error("Dimension missmatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate result matrix `X`
|
||||||
|
SEXP X = PROTECT(Rf_allocMatrix(REALSXP, Rf_nrows(B), Rf_ncols(B)));
|
||||||
|
|
||||||
|
// Allocate required working memory
|
||||||
|
int work_size = 0;
|
||||||
|
solve(
|
||||||
|
Rf_nrows(A), Rf_ncols(B),
|
||||||
|
NULL, Rf_nrows(A),
|
||||||
|
NULL, Rf_nrows(B),
|
||||||
|
NULL, Rf_nrows(X),
|
||||||
|
NULL, &work_size
|
||||||
|
);
|
||||||
|
double* work_mem = (double*)R_alloc(work_size, sizeof(double));
|
||||||
|
|
||||||
|
// Solve the system A X = B an write results into `X`
|
||||||
|
int error = 0;
|
||||||
|
solve(
|
||||||
|
Rf_nrows(A), Rf_ncols(B),
|
||||||
|
REAL(A), Rf_nrows(A),
|
||||||
|
REAL(B), Rf_nrows(B),
|
||||||
|
REAL(X), Rf_nrows(X),
|
||||||
|
work_mem, &error
|
||||||
|
);
|
||||||
|
|
||||||
|
// release `X` to the garbage collector
|
||||||
|
UNPROTECT(1);
|
||||||
|
|
||||||
|
// check error after unprotect
|
||||||
|
if (error) {
|
||||||
|
Rf_error("Solve ended with error code %d", error);
|
||||||
|
}
|
||||||
|
|
||||||
|
return X;
|
||||||
|
}
|
|
@ -0,0 +1,19 @@
|
||||||
|
#ifndef INCLUDE_GUARD_SOLVE_H
|
||||||
|
#define INCLUDE_GUARD_SOLVE_H
|
||||||
|
|
||||||
|
#include "R_api.h"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Solves a linear equation system of the form
|
||||||
|
* A X = B
|
||||||
|
* for `X` where `A` is a `dimA` x `dimA` matrix and `B` is `dimA` x `nrhs`.
|
||||||
|
*/
|
||||||
|
void solve(
|
||||||
|
/* dims */ const int dimA, const int nrhs,
|
||||||
|
/* matrix */ const double* A, const int ldA,
|
||||||
|
/* matrix */ const double* B, const int ldB,
|
||||||
|
/* matrix */ double* X, const int ldX,
|
||||||
|
double* work_mem, int* info
|
||||||
|
);
|
||||||
|
|
||||||
|
#endif /* INCLUDE_GUARD_SOLVE_H */
|
Loading…
Reference in New Issue