add: chess data example with Rchess

This commit is contained in:
Daniel Kapla 2023-12-11 14:29:51 +01:00
parent 4b4b30ceb0
commit 6792cf93a9
31 changed files with 6271 additions and 2 deletions

8
.gitignore vendored
View File

@ -43,10 +43,11 @@
*.idb
*.pdb
## R environment, data and pacakge build intermediate files/folders
## R environment, data and package build intermediate files/folders
# R Data Object files
*.Rds
*.rds
*.Rdata
# Example code in package build process
*-Ex.R
@ -110,7 +111,10 @@ simulations/
mlda_analysis/
References/
dataAnalysis/
dataAnalysis/*
!dataAnalysis/chess/
dataAnalysis/chess/*.fen
*.csv
*.csv.log

View File

@ -0,0 +1,26 @@
# Source dependency setup
INCLUDE = Rchess/inst/include/SchachHoernchen
SRC = $(notdir $(wildcard $(INCLUDE)/*.cpp))
OBJ = $(SRC:.cpp=.o)
# Compiler config
CC = g++
FLAGS = -I$(INCLUDE) -Wall -Wextra -Wpedantic -pedantic -pthread -O3 -march=native -mtune=native
CPPFLAGS = $(FLAGS) -std=c++17
LDFLAGS = $(FLAGS)
.PHONY: all clean
%.o: $(INCLUDE)/%.cpp
$(CC) $(CPPFLAGS) -c $< -o $(notdir $@)
pgn2fen.o: pgn2fen.cpp
$(CC) $(CPPFLAGS) -c $< -o $@
pgn2fen: pgn2fen.o $(OBJ)
$(CC) $(LDFLAGS) -o $@ $^
all: pgn2fen
clean:
rm -f *.out *.o *.h.gch pgn2fen

View File

@ -0,0 +1,15 @@
Package: Rchess
Type: Package
Title: Wrapper to the SchachHoernchen engine
Version: 1.0
Date: 2022-06-12
Author: loki
Maintainer: Your Name <your@email.com>
Description: Basic wrapper to the underlying C++ code of the SchachHoernchen
engine. Primarely intended to provide chess specific data processing.
Encoding: UTF-8
License: GPL (>= 2)
Imports: Rcpp (>= 1.0.8)
LinkingTo: Rcpp
SystemRequirements: c++17
RoxygenNote: 7.2.0

View File

@ -0,0 +1,4 @@
useDynLib(Rchess, .registration=TRUE)
importFrom(Rcpp, evalCpp)
exportPattern("^[[:alpha:]]+")
S3method(print, board)

View File

@ -0,0 +1,78 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
#' Specialized version of `read_cyclic.cpp` taylored to work in conjunction with
#' `gmlm_chess()` as data generator to provide random draws from a FEN data set
#' with scores filtered to be in in the range `score_min` to `score_max`.
#'
data.gen <- function(file, sample_size, score_min, score_max) {
.Call(`_Rchess_data_gen`, file, sample_size, score_min, score_max)
}
#' Convert a legal FEN string to a 3D binary (integer with 0-1 entries) array
fen2int <- function(boards) {
.Call(`_Rchess_fen2int`, boards)
}
#' Reads lines from a text file with recycling.
#'
read.cyclic <- function(file, nrows = 1000L, skip = 100L, start = 1L, line_len = 64L) {
.Call(`_Rchess_read_cyclic`, file, nrows, skip, start, line_len)
}
#' Samples a legal move from a given position
sample.move <- function(pos) {
.Call(`_Rchess_sample_move`, pos)
}
#' Samples a random FEN (position) by applying `ply` random moves to the start
#' position.
#'
#' @param nr number of positions to sample
#' @param min_depth minimum number of random ply's to generate random positions
#' @param max_depth maximum number of random ply's to generate random positions
sample.fen <- function(nr, min_depth = 4L, max_depth = 20L) {
.Call(`_Rchess_sample_fen`, nr, min_depth, max_depth)
}
#' Converts a FEN string to a Board (position) or return the current internal state
board <- function(fen = "") {
.Call(`_Rchess_board`, fen)
}
print.board <- function(fen = "") {
invisible(.Call(`_Rchess_print_board`, fen))
}
print.moves <- function(fen = "") {
invisible(.Call(`_Rchess_print_moves`, fen))
}
print.bitboards <- function(fen = "") {
invisible(.Call(`_Rchess_print_bitboards`, fen))
}
position <- function(pos, moves, san = FALSE) {
.Call(`_Rchess_position`, pos, moves, san)
}
perft <- function(depth = 6L) {
invisible(.Call(`_Rchess_perft`, depth))
}
go <- function(depth = 6L) {
.Call(`_Rchess_go`, depth)
}
ucinewgame <- function() {
invisible(.Call(`_Rchess_ucinewgame`))
}
.onLoad <- function(libname, pkgname) {
invisible(.Call(`_Rchess_onLoad`, libname, pkgname))
}
.onUnload <- function(libpath) {
invisible(.Call(`_Rchess_onUnload`, libpath))
}

View File

@ -0,0 +1,82 @@
#ifndef INCLUDE_GUARD_RCHESS_TYPES
#define INCLUDE_GUARD_RCHESS_TYPES
#include <RcppCommon.h>
#include "SchachHoernchen/Move.h"
#include "SchachHoernchen/Board.h"
#include "SchachHoernchen/uci.h"
namespace Rcpp {
template <> Move as(SEXP);
template <> SEXP wrap(const Move&);
template <> Board as(SEXP);
template <> SEXP wrap(const Board&);
} /* namespace Rcpp */
#include <Rcpp.h>
namespace Rcpp {
// Convert a coordinate encoded move string into a Move
template <>
Move as(SEXP obj) {
// parse (and validate) the move
bool parseError = false;
Move move = UCI::parseMove(Rcpp::as<std::string>(obj), parseError);
if (parseError) {
Rcpp::stop("Error parsing move");
}
return move;
}
// Convert a Move into an `R` character
template <>
SEXP wrap(const Move& move) {
return Rcpp::CharacterVector::create(UCI::formatMove(move));
}
// Convert a FEN string to a board
template <>
Board as(SEXP obj) {
bool parseError = false;
Board board;
std::string fen = Rcpp::as<std::string>(obj);
if (fen != "startpos") {
board.init(fen, parseError);
}
if (parseError) {
Rcpp::stop("Parsing FEN failed");
}
return board;
}
// Convert board to `R` board class (as character FEN string)
template <>
SEXP wrap(const Board& board) {
auto obj = Rcpp::CharacterVector::create(board.fen());
obj.attr("class") = "board";
return obj;
}
// Convert a character vector or list to a vector of Boards
template <>
std::vector<Board> as(SEXP obj) {
// Convert SEXP to be a vector of string
auto fens = Rcpp::as<std::vector<std::string>>(obj);
// Try to parse every string as a Board from a FEN
std::vector<Board> boards(fens.size());
for (int i = 0; i < fens.size(); ++i) {
bool parseError = false;
if (fens[i] != "startpos") {
boards[i].init(fens[i], parseError);
}
if (parseError) {
Rcpp::stop("Parsing FEN nr. %d failed", i + 1);
}
}
return boards;
}
} /* namespace Rcpp */
#endif /* INCLUDE_GUARD_RCHESS_TYPES */

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,224 @@
#ifndef INCLUDE_GUARD_BOARD_H
#define INCLUDE_GUARD_BOARD_H
#include <string>
#include <array>
#include <functional> // for std::hash
#include "types.h"
// Forward declarations
class Move;
class MoveList;
// pseudo-random-number-generator to fill Zobrist lookup hash table
static constexpr u64 rot(u64 val, int shift) {
return (val << shift) | (val >> (64 - shift));
}
static constexpr u64 random(u64& a, u64& b, u64& c, u64& d) {
u64 e = a - rot(b, 7);
a = b ^ rot(b, 13);
b = c + rot(d, 37);
c = d + e;
d = e + a;
return d;
}
/**
* BitBoard indexing;
* @verbatim
* rank rankIndex
* | index |
* v +-------------------------+ v
* 8 | 0 1 2 3 4 5 6 7 | 0
* 7 | 8 9 10 11 12 13 14 15 | 1
* 6 | 16 17 18 19 20 21 22 23 | 2
* 5 | 24 25 26 27 28 29 30 31 | 3
* 4 | 32 33 34 35 36 37 38 39 | 4
* 3 | 40 41 42 43 44 45 46 47 | 5
* 2 | 48 49 50 51 52 53 54 55 | 6
* 1 | 56 57 58 59 60 61 62 63 | 7
* +-------------------------+
* a b c d e f g h <- file
* 0 1 2 3 4 5 6 7 <- fileIndex
* @endverbatim
*/
class Board {
public:
Board()
: _castle{true, true, true, true}
, _enPassant{static_cast<Index>(64)}
, _halfMoveClock{0}
, _plyCount{1}
, _bitBoard{
0xFFFF000000000000, // white
0x000000000000FFFF, // black
0x00FF00000000FF00, // pawns
0x4200000000000042, // knights
0x2400000000000024, // bishops
0x8100000000000081, // rooks
0x0800000000000008, // queens
0x1000000000000010, // kings
} {
_hash = rehash(); // This makes it non default constructible
}; // but the main performance requirement is in the
// copy constructor and assignment operator
// copy constructor
Board(const Board& board)
: _castle{board._castle}
, _enPassant{board._enPassant}
, _halfMoveClock{board._halfMoveClock}
, _plyCount{board._plyCount}
, _hash{board._hash}
{
for (int i = 0; i < 8; i++)
_bitBoard[i] = board._bitBoard[i];
};
// copy assignment operator
Board& operator=(const Board& board) = default;
// accessor functions for not directly given values
enum piece piece(const Index) const;
enum piece color(const Index) const;
// accessor functions for internal states (private, only altered by Board)
u64 bb(enum piece piece) const { return _bitBoard[piece]; }
bool castleRight(enum piece color, enum location side) const;
Index enPassant() const { return _enPassant; }
unsigned halfMoveClock() const { return _halfMoveClock; }
unsigned plyCount() const { return _plyCount; }
u64 knightMoves(Index from) const { return knightMoveLookup[from]; }
u64 kingMoves(Index from) const { return kingMoveLookup[from]; }
// Ply Count interpretation of who's move it is
bool isWhiteTurn() const { return static_cast<bool>(_plyCount % 2); }
enum piece sideToMove() const {
return static_cast<enum piece>(!(_plyCount % 2));
}
// Formatting functions of pieces, squares, board, ...?
char formatPiece(Index fileIndex, Index rankIndex) const;
char formatPiece(Index squareIndex) const;
std::string format() const;
// Get board hash (Zobrist-hash), for complete computation see `rehash`
u64 hash() const { return _hash; };
// Get FEN string from board
std::string fen(const bool withCounts = true) const;
// Set board state given by FEN string
void init(const std::string& fen, bool& parseError);
// Complete (new) calculation of Zobrist hash for the current board position
u64 rehash() const;
// Validate a move (all moves from input checked before passed to `make`)
// and returns a legal (extended with all move info's) or non-move if case of
// illegal moves.
Move isLegal(const Move) const;
// Make a given move (assumed legal, otherwise undefined behavior)
void make(const Move); // passing move by value
// Generate all legal moves
void moves(MoveList&, bool = false) const;
void moves(MoveList&, const enum piece, bool = false) const;
// Check if current side to move is in check
bool isCheck() const;
// Check if the current position is a quiet position (no piece is attacked)
bool isQuiet() const;
// Checks if the position is a terminal position, meaning if the game ended
// by mate, stale mate or the 50 modes rule. Three-Fold repetition is NOT
// checked, therefore a seperate game history is required which the board
// does NOT track.
bool isTerminal() const;
// Checks if there is sufficient mating material on the board, meaning if it
// possible for any side to deliver a check mate. More specifically, it
// checks if the pieces on the board are KK, KNK or KBK.
bool isInsufficient() const;
// TODO: validate further possibilities like KBKB for two same color bishobs?!
// Board evaluation, gives a heuristic score of the current board position
// viewed from whites side
Score evalMaterial(enum piece) const;
Score evalPawns(enum piece) const;
Score evalKingSafety(enum piece) const;
Score evalRooks(enum piece) const;
Score evaluate() const;
// Static Exchange Evaluation
Score see(const Move) const;
// Subroutines for move generation, ...
u64 attacks(const enum piece, const u64) const;
u64 pinned(const enum piece, const u64, const u64) const;
u64 checkers(const enum piece, const u64, const u64) const;
// Equality operator used for comparing boards (exact equality in contrast
// to hash equality)
friend bool operator==(const Board& lhs, const Board& rhs) {
bool isEqual = true;
for (int i = 0; i < 8; i++) {
isEqual &= (lhs._bitBoard[i] == rhs._bitBoard[i]);
}
return isEqual
&& (lhs._castle.whiteKingSide == rhs._castle.whiteKingSide)
&& (lhs._castle.whiteQueenSide == rhs._castle.whiteQueenSide)
&& (lhs._castle.blackKingSide == rhs._castle.blackKingSide)
&& (lhs._castle.blackQueenSide == rhs._castle.blackQueenSide)
&& (lhs._enPassant == rhs._enPassant);
}
friend bool operator!=(const Board& lhs, const Board& rhs) {
return !(lhs == rhs);
}
private:
struct {
bool whiteKingSide : 1;
bool whiteQueenSide : 1;
bool blackKingSide : 1;
bool blackQueenSide : 1;
} _castle;
Index _enPassant;
unsigned _halfMoveClock;
unsigned _plyCount;
u64 _hash;
u64 _bitBoard[8];
// Lookup table for "random" hash value for piece-square, ... used in
// Zobrist board hashing.
static constexpr std::array<u64, 781> hashTable = []() {
std::array<u64, 781> table{};
u64 a = 0xABCD1729, b = 0, c = 0, d = 0;
for (int i = -11; i < 781; i++) {
u64 rand = random(a, b, c, d);
if (i >= 0) table[i] = rand;
}
return table;
}();
// Lookup helper for Zobrist hash lookup table
// hash index for 6 pieces of both colors for all squares
constexpr Index hashIndex(enum piece piece, enum piece color, Index sq) const {
return 12 * sq + (6 * (color == black)) + piece - 2;
}
// hash index the e.p. target square (file)
constexpr Index hashIndex(Index sq) const {
return 768 + (sq % 8);
}
enum piece sqPiece(const u64) const;
};
// Inject specialization hash<Board> into standard name space
namespace std {
template <>
struct hash<Board> {
using result_type = u64;
u64 operator()(const Board& board) const noexcept {
return board.hash();
};
};
}
#endif /* INCLUDE_GUARD_BOARD_H */

View File

@ -0,0 +1,150 @@
#ifndef UCI_GUARD_HASHTABLE_H
#define UCI_GUARD_HASHTABLE_H
#include "types.h"
#include <utility> // std::pair
#include <functional> // std::hash
#include <type_traits>
#include <cassert>
// Default policy; replace everything with the newest entry
template <typename T>
struct policy {
constexpr bool operator()(const T&, const T&) {
return true;
}
};
template <
typename Key,
typename Entry,
typename Policy = policy<Entry>,
typename Hashing = std::hash<Key>
>
class HashTable {
public:
using Hash = typename Hashing::result_type;
using Line = typename std::pair<Hash, Entry>;
HashTable() : _size{0}, _occupied{0}, _table{nullptr}, _hash{}, _policy{} { };
HashTable(Index size)
: _size{size}
, _occupied{0}
, _table{new (std::nothrow) Line[size]}
, _hash{}
, _policy{}
{
if (_table) {
for (Index i = 0; i < _size; i++) {
_table[i].first = static_cast<Hash>(0);
}
} else {
_size = 0;
}
};
~HashTable() { if (_table) { delete[] _table; } }
Index size() const { return _size; }
Index used() const { return 1000 * _occupied / _size; }
bool reserve(Index size) noexcept {
_occupied = 0;
if (size != _size) {
if (_table) { delete[] _table; }
_table = new (std::nothrow) Line[size];
}
if (_table) {
_size = size;
for (Index i = 0; i < _size; i++) {
_table[i].first = static_cast<Hash>(0);
}
return false;
} else {
_size = 0;
return true;
}
}
void clear() noexcept {
_occupied = 0;
for (Index i = 0; i < _size; i++) {
_table[i].first = static_cast<Hash>(0);
}
}
void erase() {
_occupied = 0;
_size = 0;
if (_table) { delete[] _table; }
_table = nullptr;
}
void insert(const Key& key, const Entry& entry) {
assert(0 < _size);
Hash hash = _hash(key);
Index index = static_cast<Index>(hash) % _size;
if (_table[index].first == static_cast<Hash>(0)) {
_occupied++;
_table[index] = Line(hash, entry);
}
if (_policy(_table[index].second, entry)) {
_table[index] = Line(hash, entry);
}
}
struct Iter {
using iterator_category = std::forward_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = Line;
using pointer = value_type*;
using reference = Entry&;
Iter(pointer ptr) : _ptr{ptr} { };
const Hash& hash() const { return _ptr->first; };
const Entry& value() const { return _ptr->second; };
// dereference operators
reference operator*() { return _ptr->second; };
pointer operator->() { return _ptr; };
// comparison operators
friend bool operator==(const Iter& lhs, const Iter& rhs) {
return lhs._ptr == rhs._ptr;
};
friend bool operator!=(const Iter& lhs, const Iter& rhs) {
return lhs._ptr != rhs._ptr;
};
private:
pointer _ptr;
};
Iter begin() noexcept { return Iter(_table); }
Iter end() noexcept { return Iter(_table + _size); }
Iter find(const Key& key) noexcept {
if (_size) {
Hash hash = _hash(key);
if (hash == static_cast<Hash>(0)) {
return Iter(_table + _size);
}
Index index = static_cast<Index>(hash) % _size;
if (_table[index].first == hash) {
return Iter(_table + index);
}
}
return Iter(_table + _size);
}
private:
Index _size;
Index _occupied;
Line* _table;
Hashing _hash;
Policy _policy;
};
#endif /* UCI_GUARD_HASHTABLE_H */

View File

@ -0,0 +1,305 @@
#ifndef INCLUDE_GUARD_MOVE_H
#define INCLUDE_GUARD_MOVE_H
#include <iterator>
#include <algorithm>
#include <cassert>
#include "utils.h"
#include "types.h"
class Move {
public:
using base_type = uint64_t;
private:
base_type _bits;
public:
Move() = default; // required by MoveList
// Reverse of cast to base_type
Move(base_type bits) : _bits{bits} { };
explicit Move(Index from, Index to)
: Move(from, to, static_cast<enum piece>(0)) { };
explicit Move(Index from, Index to, enum piece promotion)
: Move(static_cast<enum piece>(0), static_cast<enum piece>(0), from, to,
static_cast<enum piece>(0), promotion) { };
explicit Move(enum piece color, enum piece piece, Index from, Index to,
enum piece victim)
: Move(color, piece, from, to, victim, static_cast<enum piece>(0)) { };
// General Move constructor
explicit Move(enum piece color, enum piece piece, Index from, Index to,
enum piece victim, enum piece promotion)
: _bits{static_cast<base_type>(
( color << 21)
| ( victim << 18)
| ( piece << 15)
| (promotion << 12)
| ( to << 6)
| from
)}
{
assert(((color == white)
|| (color == black)));
assert(((victim == knight)
|| (victim == bishop)
|| (victim == rook)
|| (victim == queen)
|| (victim == pawn)
|| !victim));
assert(((piece == knight)
|| (piece == bishop)
|| (piece == rook)
|| (piece == queen)
|| (piece == pawn)
|| (piece == king)
|| !piece));
assert(((promotion == knight)
|| (promotion == bishop)
|| (promotion == rook)
|| (promotion == queen)
|| !promotion));
assert(( to < 64));
assert((from < 64));
};
Index from() const { return _bits & 63U; }
Index to() const { return (_bits >> 6) & 63U; }
enum piece promote() const {
return static_cast<enum piece>((_bits >> 12) & 7U);
}
enum piece piece() const {
return static_cast<enum piece>((_bits >> 15) & 7U);
}
enum piece victim() const {
return static_cast<enum piece>((_bits >> 18) & 7U);
}
enum piece color() const {
return static_cast<enum piece>((_bits >> 21) & 7U);
}
uint32_t score() const {
return static_cast<uint32_t>(_bits >> 32);
}
static constexpr uint32_t killerScore[2] = { 4096U, 4095U };
void setScore(const uint32_t val) {
constexpr base_type mask = (1ULL << 22) - 1ULL;
_bits = (_bits & mask) | (static_cast<base_type>(val) << 32);
}
uint32_t calcScore() const {
const uint32_t mvv_lva = ((victim() << 3) + (7U - piece()));
const bool winning = victim() > piece();
// add 128 to ensure the PST values are positive
const Index t = color() == white ? to() : 63 - to();
const Index f = color() == white ? from() : 63 - from();
const uint32_t pst = pieceSquareTables[piece()][t] - pieceSquareTables[piece()][f] + 128;
return ((static_cast<bool>(victim()) * mvv_lva) << (14 + winning * 6)) + pst;
}
// Allows to be cast to the base type (as number)
operator base_type() { return _bits; }
bool operator!() const { return this->from() == this->to(); }
// Comparison of <from><to>[<promotion>] part of the move. Everything else
// is redundent and allows simple parse move routine (use Board::isLegal
// for move legality test and augmentation)
bool operator==(const Move& rhs) const {
constexpr base_type mask = (1ULL << 22) - 1ULL;
return (_bits & mask) == (rhs._bits & mask);
}
bool operator!=(const Move& rhs) const {
constexpr base_type mask = (1ULL << 22) - 1ULL;
return (_bits & mask) != (rhs._bits & mask);
}
// default sort order is decreasing (for scored moves the best move first)
friend bool operator<(const Move& lhs, const Move& rhs) {
return lhs._bits > rhs._bits;
}
};
class MoveList {
private:
static constexpr unsigned _max_size = 256;
unsigned _size;
Move _moves[_max_size]; // apparently, 218 are the max nr. of moves
public:
MoveList() : _size{0} { };
// Copy Constructor
MoveList(const MoveList& moveList) : _size{moveList._size} {
assert(_size <= _max_size);
std::copy(moveList._moves, moveList._moves + _size, _moves);
}
// Copy Assignement Operator
MoveList& operator=(const MoveList&) = default;
unsigned size() const { return _size; };
bool empty() const { return !static_cast<bool>(_size); };
static constexpr unsigned max_size() { return _max_size; };
void clear() { _size = 0; };
Move* data() { return _moves; };
const Move* data() const { return _moves; };
template <typename... Args>
void emplace_back(Args&&... args) {
assert(_size < _max_size);
_moves[_size++] = Move(std::forward<Args>(args)...);
}
void push_back(const Move& move) {
assert(_size < _max_size);
_moves[_size++] = move;
}
void append(const MoveList& list) {
assert((_size + list._size) < _max_size);
for (unsigned i = 0; i < list._size; i++) {
_moves[_size++] = list._moves[i];
}
}
bool contains(const Move& move) const {
for (unsigned i = 0; i < _size; i++) {
if (_moves[i] == move) {
return true;
}
}
return false;
}
/**
* Searches for an occurence of `move` in the list and places the `move` as
* the first element in the list. If the provided `move` is not an element
* no opperation is performed and `false` will be returned.
*
* @param move a move to be searched for and placed as first element.
* @return `true` if `move` is found and placed as first element,
* `false` otherwise.
*/
bool move_front(const Move& move) {
for (unsigned i = 0; i < _size; i++) {
if (_moves[i] == move) {
std::swap(_moves[0], _moves[i]);
return true;
}
}
return false;
}
Move& operator[](unsigned i) { return _moves[i]; }
const Move& operator[](unsigned i) const { return _moves[i]; }
struct MoveIter {
using iterator_category = std::forward_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = Move;
using pointer = value_type*;
using reference = value_type&;
MoveIter(pointer ptr) : _ptr{ptr} { };
// Dereference operators
reference operator*() { return *_ptr; };
pointer operator->() { return _ptr; };
// Arithmetic operators
MoveIter& operator++() { _ptr++; return *this; };
MoveIter operator++(int) {
MoveIter copy = *this; ++(*this); return copy;
};
MoveIter& operator--() { _ptr--; return *this; };
MoveIter operator--(int) {
MoveIter copy = *this; --(*this); return copy;
};
MoveIter& operator+=(int i) {
_ptr += i; return *this;
};
friend MoveIter operator+(MoveIter it, int i) { return (it += i); };
MoveIter& operator-=(int i) {
_ptr -= i; return *this;
};
friend MoveIter operator-(MoveIter it, int i) { return (it -= i); };
friend difference_type operator-(const MoveIter& lhs, const MoveIter& rhs) {
return lhs._ptr - rhs._ptr;
};
// Comparison operators
friend bool operator==(const MoveIter& lhs, const MoveIter& rhs) {
return lhs._ptr == rhs._ptr;
};
friend bool operator<=(const MoveIter& lhs, const MoveIter& rhs) {
return lhs._ptr <= rhs._ptr;
};
friend bool operator>=(const MoveIter& lhs, const MoveIter& rhs) {
return lhs._ptr >= rhs._ptr;
};
friend bool operator!=(const MoveIter& lhs, const MoveIter& rhs) {
return lhs._ptr != rhs._ptr;
};
friend bool operator<(const MoveIter& lhs, const MoveIter& rhs) {
return lhs._ptr < rhs._ptr;
};
friend bool operator>(const MoveIter& lhs, const MoveIter& rhs) {
return lhs._ptr > rhs._ptr;
};
private:
pointer _ptr;
};
MoveIter begin() { return MoveIter(_moves); };
MoveIter end() { return MoveIter(_moves + _size); };
struct ConstMoveIter {
using iterator_category = std::forward_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = const Move;
using pointer = const value_type*;
using reference = const value_type&;
ConstMoveIter(pointer ptr) : _ptr{ptr} { };
// Dereference operators
reference operator*() { return *_ptr; };
pointer operator->() { return _ptr; };
// Arithmetic operators
ConstMoveIter& operator++() { _ptr++; return *this; };
ConstMoveIter operator++(int) {
ConstMoveIter copy = *this; ++(*this); return copy;
};
ConstMoveIter& operator--() { _ptr--; return *this; };
ConstMoveIter operator--(int) {
ConstMoveIter copy = *this; --(*this); return copy;
};
ConstMoveIter& operator+=(int i) {
_ptr += i; return *this;
};
// Comparison operators
friend bool operator==(const ConstMoveIter& lhs, const ConstMoveIter& rhs) {
return lhs._ptr == rhs._ptr;
};
friend bool operator!=(const ConstMoveIter& lhs, const ConstMoveIter& rhs) {
return lhs._ptr != rhs._ptr;
};
private:
pointer _ptr;
};
ConstMoveIter begin() const { return ConstMoveIter(_moves); };
ConstMoveIter end() const { return ConstMoveIter(_moves + _size); };
};
#endif /* INCLUDE_GUARD_MOVE_H */

View File

@ -0,0 +1,63 @@
#ifndef INCLUDE_GUARD_BIT_H
#define INCLUDE_GUARD_BIT_H
#include <cstddef>
// Polyfill for `<bit>`, it compiled with proper standard this falls back to
// the STL implementation. Only provides functionality if not supported.
// TODO: extend to others than gcc (and maybe clang?)
// Inject polyfill for `std::endian` (since C++20)
namespace std {
enum class endian {
little = __ORDER_LITTLE_ENDIAN__,
big = __ORDER_BIG_ENDIAN__,
native = __BYTE_ORDER__
};
// Polyfill `std::byteswap` (since C++23)
template <typename T>
constexpr T byteswap(T x) noexcept {
#ifdef __GNUC__
if constexpr (sizeof(T) == 8) {
uint64_t swapped = __builtin_bswap64(*reinterpret_cast<uint64_t*>(&x));
return *reinterpret_cast<double*>(&swapped);
} else if constexpr (sizeof(T) == 4) {
uint32_t swapped = __builtin_bswap32(*reinterpret_cast<uint32_t*>(&x));
return *reinterpret_cast<double*>(&swapped);
} else if constexpr (sizeof(T) == 2) {
uint16_t swapped = __builtin_bswap16(*reinterpret_cast<uint16_t*>(&x));
return *reinterpret_cast<double*>(&swapped);
} else {
static_assert(sizeof(T) == 1);
return x;
}
#else
if constexpr (sizeof(T) == 8) {
uint64_t swapped = *reinterpret_cast<uint64_t*>(&x);
swapped = (((swapped & static_cast<uint64_t>(0xFF00FF00FF00FF00ULL)) >> 8)
| ((swapped & static_cast<uint64_t>(0x00FF00FF00FF00FFULL)) << 8));
swapped = (((swapped & static_cast<uint64_t>(0xFFFF0000FFFF0000ULL)) >> 16)
| ((swapped & static_cast<uint64_t>(0x0000FFFF0000FFFFULL)) << 16));
swapped = ((swapped >> 32) | (swapped << 32));
return *reinterpret_cast<T*>(&swapped);
} else if constexpr (sizeof(T) == 4) {
uint32_t swapped = *reinterpret_cast<uint32_t*>(&x);
swapped = (((swapped & static_cast<uint32_t>(0xFF00FF00)) >> 8)
| ((swapped & static_cast<uint32_t>(0x00FF00FF)) << 8));
swapped = ((swapped >> 16) | (swapped << 16));
return *reinterpret_cast<T*>(&swapped);
} else if constexpr (sizeof(T) == 2) {
uint16_t swapped = *reinterpret_cast<uint16_t*>(&x);
swapped = ((swapped << 8) | (swapped >> 8));
return *reinterpret_cast<T*>(&swapped);
} else {
static_assert(sizeof(T) == 1);
return x;
}
#endif
}
} /* namespace std */
#endif /* INCLUDE_GUARD_BIT_H */

View File

@ -0,0 +1,17 @@
#ifndef INCLUDE_GUARD_NULLSTREAM_H
#define INCLUDE_GUARD_NULLSTREAM_H
#include <ostream>
class nullstream : public std::ostream {
public:
nullstream() : std::ostream(nullptr) { }
nullstream(const nullstream&) : std::ostream(nullptr) { }
template <typename T>
nullstream& operator<<(const T& rhs) { return *this; }
nullstream& operator<<(std::ostream& (*fn)(std::ostream&)) { return *this; }
};
#endif /* INCLUDE_GUARD_NULLSTREAM_H */

View File

@ -0,0 +1,700 @@
#include <iostream>
#include <vector>
#include <algorithm>
#include <thread>
#include <atomic>
#include <chrono>
#include <type_traits>
#include "types.h"
#include "Board.h"
#include "Move.h"
#include "uci.h"
#include "syncstream.h"
#include "search.h"
#include "HashTable.h"
namespace Search {
// Node of a transposition table and its type
enum TTType { exact, upper_bound, lower_bound, unknown };
struct TTEntry {
enum TTType type;
unsigned age;
int depth;
Score score;
Move move;
};
// Transposition Table collision policy
struct TTPolicy {
bool operator()(const TTEntry& entry, const TTEntry& candidate) {
return (entry.age < candidate.age)
|| ((entry.type == candidate.type) && (entry.depth <= candidate.depth))
|| (candidate.type == exact);
}
};
// Implementation specific global search state, workers and concurrency controls
State state;
HashTable<Board, TTEntry, TTPolicy> tt;
std::atomic<bool> isRunning;
std::vector<std::thread> workers;
// Set initial static age of search
template <typename Evaluator>
unsigned PVS<Evaluator>::_age = 0;
void init(unsigned ttSize) {
// Init internal state
state.depth = 16; // default max search depth
state.wtime = 24 * 60 * 60 * 1000; // 1 day in milliseconds
state.btime = 24 * 60 * 60 * 1000; // 1 day in milliseconds
state.winc = 0;
state.binc = 0;
state.movetime = 0;
// Init Transposition Table
constexpr unsigned ttLineSize = sizeof(HashTable<Board, TTEntry, TTPolicy>::Line);
// Reserve/resize tt to ttSize in MB of memory
if (tt.reserve(ttSize * 1024U * 1024U / ttLineSize)) {
// Report error of reserve failed
osyncstream(cout) << "info string error Unable to resize "
"Transposition Table to " << ttSize << "MB" << std::endl;
}
}
double ttSize() {
// Init Transposition Table
constexpr unsigned ttLineSize = sizeof(HashTable<Board, TTEntry, TTPolicy>::Line);
// compute size in MB from number of lines (size)
return static_cast<double>(tt.size() * ttLineSize) / 1048576.0;
}
void newgame() {
// Reset internal state
state.depth = 16; // default max search depth
state.wtime = 24 * 60 * 60 * 1000; // 1 day in milliseconds
state.btime = 24 * 60 * 60 * 1000; // 1 day in milliseconds
state.winc = 0;
state.binc = 0;
state.movetime = 0;
// Delete all data from TT table
tt.clear();
}
// Perft implementation
unsigned perft_subroutine(Board& board, int depth) {
if (depth <= 0) { return 1; }
Index nodeCount = 0;
MoveList searchmoves;
board.moves(searchmoves);
if (depth == 1) {
nodeCount = searchmoves.size();
} else {
Board boardCopy = board;
for (Move move : searchmoves) {
board.make(move);
nodeCount += perft_subroutine(board, depth - 1);
board = boardCopy; // unmake move
}
}
return nodeCount;
}
/**
* Prints for all moves in passed move list (assuming legal moves) number of
* `depth - 1` performance test node count of this child node count.
*
* @param board Reference root node board position.
* @param depth search depth.
* @param moveList List of legal moves at current board position to search.
* If moveList is empty, search all legal moves (Note: In this case, the
* moveList is filled with all legal moves manipulating the list in place).
*
* @example Output generated from start position for depth `5` with moveList
* containing the moves "d4", "Nf3" and "Nc3".
* position startpos
* go perft 5 searchmoves d2d4 g1f3 b1c3
* d2d4: 361790
* g1f3: 233491
* b1c3: 234656
*
* Nodes searched: 829937
*/
void perft(Board& board, int depth, MoveList& searchmoves) {
if (depth <= 0) {
osyncstream(cout) << "Nodes searched: 0" << std::endl;
return;
}
if (!searchmoves.size()) {
board.moves(searchmoves);
}
Index totalCount = 0;
Board boardCopy = board;
for (Move move : searchmoves) {
// check if shutdown/stopping is requested
if (!isRunning.load(std::memory_order_relaxed)) { break; }
// continue traversing
board.make(move);
Index nodeCount;
nodeCount = perft_subroutine(board, depth - 1);
totalCount += nodeCount;
board = boardCopy; // unmake move
// report moves of node
osyncstream(cout) << move << ": " << nodeCount << std::endl;
}
osyncstream(cout)
<< std::endl << "Nodes searched: " << totalCount << std::endl;
}
/**
* Specialization of `perft(Board&, int, MoveList&)` for all legal moves.
*/
void perft(Board& board, int depth) {
MoveList moveList; // init empty move list -> perft for all legal moves
perft(board, depth, moveList);
}
/**
* Search initialization
*/
template <typename Evaluator>
PVS<Evaluator>::PVS(const std::vector<Board>& game, const State& config)
: _root(game.back())
// Restrict max search depth by the principal variation capacity
// Upper depth bound allows more than 200 plys which is enough.
, _max_depth{std::min(config.depth, static_cast<int>(MoveList::max_size()))}
// Init selective depth (max reached pvs + qsearch depth, max search ply)
, _seldepth{0}
// Set visited nodes to zero
, _nodes{0}
// break condition flag
, _isStopped{false}
// search start time for time dependend break condition and reporting
, _start_time{clock::now()}
, _searchmoves(config.searchmoves)
{
// Fill search moves with all legal moves if no moves are specified
if (_searchmoves.empty()) {
_root.moves(_searchmoves);
}
// Increment age (for TT entries)
++_age;
// Time control; set time to stop searching (ignoring potental loss of 2 ms)
unsigned half_move_time;
if (config.movetime) {
half_move_time = config.movetime / 2U;
} else {
if (_root.isWhiteTurn()) {
half_move_time = std::max(config.wtime / 60U, 50U);
} else {
half_move_time = std::max(config.btime / 60U, 50U);
}
}
_mid_time = _start_time + milliseconds( half_move_time);
_end_time = _start_time + milliseconds(2U * half_move_time);
// fill repitition detection history
_historySize = _root.halfMoveClock() + 1 < game.size()
? _root.halfMoveClock() + 1
: game.size();
auto gameIter = game.rbegin();
for (size_t i = _historySize; i-- > 0; ) {
_history[i] = (*gameIter++).hash();
}
}
/**
* Principal Variation Search at the Root Position
*/
template <typename Evaluator>
Score PVS<Evaluator>::operator()() {
using std::chrono::duration_cast;
// Color keeps track of static evaluation sign for the current player
// in the negated maximization (negaMax) framework
Score color = _root.isWhiteTurn() ? +1 : -1;
// Tracks principal variation
MoveList pv_line;
// set killers all to !move (TODO: is this required?)
std::fill_n(_killers[0], MoveList::max_size(), Move(0));
std::fill_n(_killers[1], MoveList::max_size(), Move(0));
// set history heuristic to zero // TODO: can I replace that by zero-initialization?!
for (Index p = 0; p < 6; ++p) {
std::fill_n(_historyTable[0][p], 64, 0);
std::fill_n(_historyTable[1][p], 64, 0);
}
// Iterative deepening
Score prevScore = 0;
for (int depth = 1; depth <= _max_depth; depth++) {
// Clear temporary pv line
pv_line.clear();
// Aspiration window search
Score alpha, score, beta;
Score delta = 34;
// Set initial aspriation window bounds (shallow searches -> full width)
if (depth < 3) {
alpha = limits<Score>::lower();
beta = limits<Score>::upper();
} else {
alpha = prevScore - delta;
beta = prevScore + delta;
}
while (true) {
// search current depth and aspiration window
score = pvs(_root, depth, 0, alpha, beta, color, pv_line);
// Increase aspiration window
delta += 2 * delta + 5;
// Increase window (iff fail high/low)
if (score <= alpha) {
alpha = std::max(prevScore - delta, limits<Score>::lower() - 1);
} else if (beta <= score) {
beta = std::max(prevScore + delta, limits<Score>::upper() + 1);
} else {
break;
}
}
// Getting deeper, set previous iteration score in iterative deepening
prevScore = score;
// Check search break condition (before updating the pv and reporting
// partial results)
if (_isStopped
|| !isRunning.load(std::memory_order_relaxed)
|| (_end_time < clock::now())) {
break;
}
// Copy principal variation
_pv.clear();
_pv.append(pv_line);
auto now = clock::now();
// Time spend in search till now (only for reporting)
auto duration = duration_cast<milliseconds>(now - _start_time);
// Check if found mate and report mate instead of centi pawn score
if (limits<Score>::upper() <= (std::abs(score) + depth)) {
// Note: +-1 for converting 0-indexed ply count to a 1-indexed
// move count
if (score < 0) {
score = limits<Score>::lower() - score - 1;
} else {
score = limits<Score>::upper() - score + 1;
}
// ply to move count
score /= 2;
// Report search stats and mate in score moves
osyncstream(cout)
<< "info depth " << depth
<< " seldepth " << _seldepth
<< " score mate " << score
<< " time " << duration.count()
<< " nodes " << _nodes
<< " pv " << _pv
<< std::endl;
// stop iterative deepening, found check mate -> no need to continue
break;
} else {
// Report search progress
osyncstream(cout)
<< "info depth " << depth
<< " seldepth " << _seldepth
<< " score cp " << score
<< " time " << duration.count()
<< " nodes " << _nodes
<< " pv " << _pv
<< std::endl;
}
// Check current time againt search mid time, iff passed the mid point
// a deeper iteration in the iterative deepening will most likely not
// finish. Therefore, stop and save the time for later.
if (_mid_time <= now) { break; }
}
if (_pv.size()) {
auto out = osyncstream(cout);
out << "bestmove " << _pv[0];
if (1 < _pv.size()) {
out << " ponder " << _pv[1];
}
out << std::endl;
} else {
// If there is no move in _pv (can happen in extreme short time control)
// just make a random (first) move, better than failing.
_root.moves(_pv); // _pv is empty
if (_pv.size()) {
osyncstream(cout) << "bestmove " << _pv[0] << std::endl;
} else {
osyncstream(cout) << "bestmove !move" << std::endl;
}
}
// Before leaving search, update search state (remaining time)
auto duration = duration_cast<milliseconds>(clock::now() - _start_time);
if (_root.isWhiteTurn()) {
state.wtime += state.winc - (duration.count() + 1);
} else {
state.btime += state.binc - (duration.count() + 1);
}
// Reset running flag
isRunning.store(false, std::memory_order_relaxed);
return (_root.isWhiteTurn() ? 1 : -1) * prevScore;
}
template <class Evaluator>
Score PVS<Evaluator>::qSearch(const PVS<Evaluator>::Position& pos, int ply,
Score alpha, Score beta, Score color
) {
// Track selective search depth counter (reported while searching)
_seldepth = std::max(_seldepth, ply);
// Static evaluation
Score score = color * pos.eval();
bool capturesOnly = true;
// Check if in check, if in check generate all moves (check evations).
// Otherwise, check if not moving at all (stand pat) is already too good.
if (pos.isCheck()) {
capturesOnly = false;
} else if (beta <= score) {
++_nodes;
return beta;
}
// Generate all captures (captures only, if not in check)
MoveList moveList;
pos.moves(moveList, capturesOnly);
// Check for terminal node
if (moveList.empty()) {
++_nodes;
// If there are no captures left, return static evaluation score
if (capturesOnly) {
return score;
// When generated all moves (in case of check evations), the fact that
// there are no moves means mate.
} else {
return color * (pos.isWhiteTurn() ? +1 : -1)
* (limits<Score>::lower() + ply);
}
}
// Ensure capture is better than stand pat
alpha = std::max(alpha, score);
// if (capturesOnly) {
// for (Move& move : moveList) {
// // add 65538 to ensure the set move score is positive
// move.setScore(pos.see(move) + 65538);
// }
// } else {
for (Move& move : moveList) {
// MVV-LVA capture move order
move.setScore(move.calcScore());
}
// }
std::sort(moveList.begin(), moveList.end());
// Recursively capture pieces
Position board = pos;
for (Move move : moveList) {
// SEE pruning (drop losing captures)
if (capturesOnly && (pos.see(move) < 0)) {
continue;
}
board.make(move);
score = -qSearch(board, ply + 1, -beta, -alpha, -color);
if (beta <= score) {
return beta; // fail-hard beta-cutoff
}
alpha = std::max(alpha, score);
board = pos; // unmake
}
return alpha;
}
/**
* Principal Variation Search routine (all non-root) positions
*/
template <class Evaluator>
Score PVS<Evaluator>::pvs(const PVS<Evaluator>::Position& pos, int depth, int ply,
Score alpha, Score beta, Score color, MoveList& pv_line
) {
// Check for search shutdown and in case of a repetition return draw score
if (_isStopped || isRepetition(pos)) { return 0; }
// Check break condition every few thousend nodes
if (!(_nodes % (32 * 1024))) {
// Just return 0, partial search will be discarded
if (!isRunning.load(std::memory_order_relaxed)
|| (_end_time < clock::now())) {
_isStopped = true;
return 0;
}
}
// Save original alpha bound before TT entry update
Score original_alpha = alpha;
// Move list containing all moves to be searched (usually all legal moves)
MoveList moveList;
if (!ply) {
moveList.append(_searchmoves);
}
// Transposition table lookup
auto tt_line = tt.find(pos);
if ((tt_line != tt.end()) && ((*tt_line).depth >= depth)) {
// Need to create all moves here to check if the TT move is legal
// (or restricted to searchmoves)
if (moveList.empty()) {
pos.moves(moveList);
}
// Only use the TT entry if TT move is legal
if (moveList.contains((*tt_line).move)) {
switch ((*tt_line).type) {
case lower_bound:
alpha = std::max(alpha, (*tt_line).score);
break;
case exact:
++_nodes;
pv_line.clear();
pv_line.push_back((*tt_line).move);
return (*tt_line).score;
case upper_bound:
beta = std::min(beta, (*tt_line).score);
break;
default:
assert((false && "Unknown TT node type entry in TT lookup"));
}
// Check altered alpha/beta bounds
if (beta <= alpha) {
++_nodes;
return (*tt_line).score;
}
} else {
// Reset TT entry to "not-found"
tt_line = tt.end();
}
}
// Max search depth reached
if (depth <= 0) {
return qSearch(pos, ply, alpha, beta, color);
}
// Generate all moves (if not already generated or constraint)
if (moveList.empty()) {
pos.moves(moveList);
}
// Check terminal node
if (moveList.empty()) {
// Increment node count
++_nodes;
// Evaluate terminal node score (check/stale mate)
if (pos.isCheck()) {
return color * (pos.isWhiteTurn() ? +1 : -1)
* (limits<Score>::lower() + ply);
} else {
return 0;
}
}
// Sort moves, iff hash move its the first, then captures in MVV-LVA
// followed by the rest
auto moveOrder = moveList.begin();
if (tt_line != tt.end()) {
if (moveList.move_front((*tt_line).move)) {
++moveOrder;
}
}
// score moves
for (Move& move : moveList) {
auto mScore = move.calcScore();
// Add killer move score -> killers after captures
mScore += (move == _killers[0][ply]) * Move::killerScore[0];
mScore += (move == _killers[1][ply]) * Move::killerScore[1];
// add history heuristic score -> sort of non-captures
mScore += static_cast<bool>(move.victim())
* _historyTable[move.color()][move.piece() - 2][move.to()];
move.setScore(mScore);
}
// TODO: best move selection instead of sort
std::sort(moveOrder, moveList.end(),
[](Move::base_type lhs, Move::base_type rhs) { return lhs > rhs; });
// New child pv line
MoveList cur_line;
// New TT entry (tracks "best" move)
TTEntry tt_entry{unknown, _age, depth, 0, Move{}};
// Current board
Position board = pos;
Index moveCount = 0;
Score score = limits<Score>::lower();
for (Move move : moveList) {
board.make(move);
pushHistory(board);
Score value;
if (!moveCount) {
value = -pvs(board, depth - 1, ply + 1,
-beta, -alpha, -color, cur_line);
} else {
// LMR (Late Move Reduction), reduce search depth of later moves
int R = 0; // TODO: disabled!!!
// int R = (moveCount > 3) & (ply > 2) & !move.victim() & !pos.isCheck();
// Zero-window search (with possible reduction)
value = -pvs(board, depth - 1 - R, ply + 1,
-(alpha + 1), -alpha, -color, cur_line);
if (alpha < value) {
// re-search (with clean current pv line and without reduction)
cur_line.clear();
value = -pvs(board, depth - 1, ply + 1,
-beta, -alpha, -color, cur_line);
}
}
++moveCount;
score = std::max(score, value);
// Check for beta cutoff
if (beta <= score) {
// Set possible lower bound hash hash move
tt_entry.move = move;
// Store killer move which is a quiet move causing a beta cut-off
if (!move.victim()) {
// If move isn't the first killer move we insert which ensures
// (assuming killers are different or !move) that two different
// killers are stored
if (_killers[0][ply] != move) {
_killers[1][ply] = _killers[0][ply];
_killers[0][ply] = move;
}
// Increment history heuristic
_historyTable[move.color()][move.piece() - 2][move.to()] += depth * depth;
}
popHistory();
break;
}
// Check if alpha raised; iff add new pv/best move
if (alpha < score) {
alpha = score;
// Track pv (implicitly the best move)
pv_line.clear();
pv_line.push_back(move);
pv_line.append(cur_line);
// Track best move as hash move
tt_entry.move = move;
}
board = pos; // unmake move
popHistory();
}
// Add TT entry
tt_entry.score = score;
if (score <= original_alpha) {
tt_entry.type = upper_bound;
} else if (beta <= score) {
tt_entry.type = lower_bound;
} else {
tt_entry.type = exact;
}
tt.insert(pos, tt_entry);
return score;
}
// Starts searching the given board position by dispatching a worker thread
void start(std::vector<Board>& game, State& config) {
if (isRunning.load(std::memory_order_consume)) {
// Ignore further search start attempts if allready searching, requires
// a stop first!
return;
} else {
// Dispose of finished/stopped workers
for (std::thread& worker : workers) {
if (worker.joinable()) {
worker.join();
}
}
workers.clear();
}
// sets worker thread stop condition to false (before dispatch)
isRunning.store(true, std::memory_order_release);
// Dispatch search worker
switch (config.search) {
case State::Type::perft:
state.search = State::Type::perft;
workers.emplace_back([&]() {
// Copy working variables, not subject to change
Board pos(game.back());
MoveList searchmoves = config.searchmoves;
// Start perft
perft(pos, config.depth, searchmoves);
// Reset running flag
isRunning.store(false, std::memory_order_relaxed);
});
break;
case State::Type::search:
std::cout << "info string HCE evaluation" << std::endl;
workers.emplace_back(PVS<Board>(game, config));
break;
case State::Type::ponder:
std::cerr << "info string error pondering not implemented!" << std::endl;
break;
default:
assert((false && "Search::start got request for unknown search type."));
}
}
void stop() {
// revoke isRunning flag -> workers stop as soon as possible
isRunning.store(false, std::memory_order_relaxed);
// then join and dispose all workers
for (std::thread& worker : workers) {
if (worker.joinable()) {
worker.join();
}
}
workers.clear();
}
// Explicit instantiations of PVS types using different evaluators
template class PVS<Board>;
} /* namespace Search */

View File

@ -0,0 +1,203 @@
#ifndef INCLUDE_GUARD_SEARCH_H
#define INCLUDE_GUARD_SEARCH_H
#include <iostream>
#include <iomanip>
#include <vector>
#include <chrono>
#include <atomic>
#include "types.h"
#include "Move.h"
#include "Board.h"
namespace Search {
struct State {
enum class Type { perft, search, ponder };
enum Type search;
int depth;
unsigned wtime;
unsigned btime;
unsigned winc;
unsigned binc;
unsigned movetime;
MoveList searchmoves;
State() = default;
State(const State& state)
: search{Type::search}
, depth{state.depth}
, wtime{state.wtime}
, btime{state.btime}
, winc{state.winc}
, binc{state.binc}
, movetime{state.movetime}
, searchmoves() { };
};
extern State state;
extern std::atomic<bool> isRunning;
/**
* Initialization of search specific (global) variables
*/
void init(unsigned ttSize = 32);
/**
* Diagnostic routine to get the allocated TT size in MB
*/
double ttSize();
/**
* Resets all search internal variables to initial configuration state
*/
void newgame();
/**
* Performance Test Subroutine, simple perft function without any I/O.
*
* Used in `perft.cpp` testing utility (therefore accessable).
*
* @param board Reference root node board position.
* @param depth search depth.
*
* @return number of moves/nodes from root node of search depth.
*/
unsigned perft_subroutine(Board& board, int depth);
/**
* Board wrapper in combination with evaluation to abstract the different
* position evaluation options (HCE Board evaluate and NNUE evaluate)
*
* @note The NNUE evaluation is removed from this code base!
*/
template <class CRTP>
class GameState : public Board {
public:
GameState(const Board& pos) : Board(pos) { };
Score eval() { return static_cast<CRTP*>(this)->eval(); }
};
/**
* Specialized CRTP derived HCE Board evaluation position
*/
class BoardState : public GameState<BoardState> {
public:
BoardState(const Board& pos) : GameState<BoardState>(pos) { };
Score eval() const { return this->evaluate(); }
};
/**
* Search routine class handling all search relevant data and executes the search
*
* This is a Principal Variabtion Search with alpha-beta pruning and a
* transposition table implemented in an iterative deepening framework.
*
* There are two Evaluation types, ether the HCE Board evaluation `PVS<Board>`
* or the NNUE evaluation `PVS<NNUE>`.
*
* @note The NNUE evaluation is removed from this code base!
*/
template <typename Evaluator>
class PVS {
public:
PVS(const std::vector<Board>&, const State&);
Score operator()();
Move bestMove() const { return _pv[0]; }
// TODO: there is a BUG collecting the PV!
const MoveList& pv() const { return _pv; }
// Analytic routine to get the number of searched leave nodes
unsigned nodes() const { return _nodes; }
private:
using clock = std::chrono::high_resolution_clock;
using milliseconds = std::chrono::milliseconds;
using time_point = decltype(clock::now());
using Position = BoardState;
// Root search position (HCE Board state)
Position _root;
// Max search depth (excluding QSearch)
int _max_depth;
// Selective search counter (needs tracking)
int _seldepth;
// Number of nodes visited (leave nodes)
unsigned _nodes;
// Set to exit the search as soon as possible (breaks out of all recursions)
bool _isStopped;
// Start time of the search (for break condition check)
time_point _start_time;
// Middle time between _start_time and _end_time used to check if an
time_point _mid_time;
// Time till the search should stop (if now() is bigger, stop the search)
time_point _end_time;
// Moves to be searched, if empty all legal moves are searched
MoveList _searchmoves;
// Principal Variation
MoveList _pv;
// Killer Moves (initialized to contain !move entries)
Move _killers[2][MoveList::max_size()];
// History Heuristic Table (for move ordering)
Score _historyTable[2][6][64]; // [color][piece][to Square]
// Number of elements in `_history`
size_t _historySize;
// Game history list storing previous board hashes for repetition detection
u64 _history[100 + MoveList::max_size()];
// Increasing age for TT entry
static unsigned _age;
// adds a position to the history
void pushHistory(const Position& pos) {
_history[_historySize++] = pos.hash();
}
// and removes the last position from the history
void popHistory() { --_historySize; }
// two-fold repetition detection
bool isRepetition(const Position& pos) const {
u64 hash = pos.hash();
for (size_t i = _historySize - 1; i-- > 0; ) {
if (_history[i] == hash) {
return true;
}
}
return false;
}
// Principal Variation Search routine (for not root nodes)
Score pvs(const Position& pos, int depth, int ply,
Score alpha, Score beta, Score color, MoveList& pv_line);
// Quiesence Search
Score qSearch(const Position& pos, int ply,
Score alpha, Score beta, Score color);
};
/**
* Starts searching on the given board position.
*/
void start(std::vector<Board>& game, struct State& config);
/**
* Stops all workers.
*
* Tells all workers to stop as soon as possible and closes working threads.
*
* @note: This function blocks
*/
void stop();
} /* namespace Search */
#endif /* INCLUDE_GUARD_SEARCH_H */

View File

@ -0,0 +1,38 @@
#ifndef INCLUDE_GUARD_SYNCSTREAM_H
#define INCLUDE_GUARD_SYNCSTREAM_H
#include <ostream>
#include <mutex>
class osyncstream {
public:
osyncstream(std::ostream& base_ostream)
: _base_ostream(base_ostream)
{
mtx().lock();
}
~osyncstream() {
mtx().unlock();
}
template <typename T>
osyncstream& operator<<(const T& rhs) {
_base_ostream << rhs;
return *this;
}
osyncstream& operator<<(std::ostream& (*fn)(std::ostream&)) {
_base_ostream << fn;
return *this;
}
private:
std::mutex& mtx() {
static std::mutex _mtx;
return _mtx;
};
std::ostream& _base_ostream;
};
#endif /* INCLUDE_GUARD_SYNCSTREAM_H */

View File

@ -0,0 +1,195 @@
#ifndef INCLUDE_GUARD_TYPES_H
#define INCLUDE_GUARD_TYPES_H
#include <cstdint> // uint64_t
#include <limits> // std::numeric_limits
/** square, file and rank index (index > 63 indicates illegal or off board) */
using Index = unsigned;
/** Bit board, exactly 64 bits (one bit per square) */
using u64 = uint64_t; // easy on the eyes (well, my eyes)
/**
* Board position score from white point of view in centipawns.
* (1 pawn ~ 100 centipawns)
*/
using Score = int;
template <typename T>
struct limits {
static constexpr T upper();
static constexpr T lower();
};
template <>
struct limits<Score> {
static constexpr Score upper() { return static_cast<Score>(+32768); };
static constexpr Score lower() { return static_cast<Score>(-32768); };
static constexpr bool isMate(const Score score) {
constexpr Score mateBound = upper() - 512;
return (score < -mateBound) || (mateBound < score);
}
};
enum piece {
none = 0,
white = 0,
black = 1,
pawn = 2,
knight = 3,
bishop = 4,
rook = 5,
queen = 6,
king = 7
};
enum square : Index {
a8, b8, c8, d8, e8, f8, g8, h8,
a7, b7, c7, d7, e7, f7, g7, h7,
a6, b6, c6, d6, e6, f6, g6, h6,
a5, b5, c5, d5, e5, f5, g5, h5,
a4, b4, c4, d4, e4, f4, g4, h4,
a3, b3, c3, d3, e3, f3, g3, h3,
a2, b2, c2, d2, e2, f2, g2, h2,
a1, b1, c1, d1, e1, f1, g1, h1
};
enum location {
Square,
Up, Down, Left, Right,
QueenSide = Left, KingSide = Right,
File, Rank,
Diag, AntiDiag,
RightUp, RightDown,
LeftUp, LeftDown,
Plus, Cross, Star,
WhiteSquares, BlackSquares
};
// Material weighting per piece
constexpr Score pieceValues[8] = {
0, 0, // white, black (irrelevant)
100, // pawn
300, // knight
300, // bishop
500, // rook
900, // queen
0 // king (irrelevant, always 2 opposite kings)
};
// Move lookup tables for knights and kings
constexpr u64 knightMoveLookup[64] = {
0x0000000000020400, 0x0000000000050800, 0x00000000000a1100, 0x0000000000142200,
0x0000000000284400, 0x0000000000508800, 0x0000000000a01000, 0x0000000000402000,
0x0000000002040004, 0x0000000005080008, 0x000000000a110011, 0x0000000014220022,
0x0000000028440044, 0x0000000050880088, 0x00000000a0100010, 0x0000000040200020,
0x0000000204000402, 0x0000000508000805, 0x0000000a1100110a, 0x0000001422002214,
0x0000002844004428, 0x0000005088008850, 0x000000a0100010a0, 0x0000004020002040,
0x0000020400040200, 0x0000050800080500, 0x00000a1100110a00, 0x0000142200221400,
0x0000284400442800, 0x0000508800885000, 0x0000a0100010a000, 0x0000402000204000,
0x0002040004020000, 0x0005080008050000, 0x000a1100110a0000, 0x0014220022140000,
0x0028440044280000, 0x0050880088500000, 0x00a0100010a00000, 0x0040200020400000,
0x0204000402000000, 0x0508000805000000, 0x0a1100110a000000, 0x1422002214000000,
0x2844004428000000, 0x5088008850000000, 0xa0100010a0000000, 0x4020002040000000,
0x0400040200000000, 0x0800080500000000, 0x1100110a00000000, 0x2200221400000000,
0x4400442800000000, 0x8800885000000000, 0x100010a000000000, 0x2000204000000000,
0x0004020000000000, 0x0008050000000000, 0x00110a0000000000, 0x0022140000000000,
0x0044280000000000, 0x0088500000000000, 0x0010a00000000000, 0x0020400000000000
};
constexpr u64 kingMoveLookup[64] = {
0x0000000000000302, 0x0000000000000705, 0x0000000000000E0A, 0x0000000000001C14,
0x0000000000003828, 0x0000000000007050, 0x000000000000E0A0, 0x000000000000C040,
0x0000000000030203, 0x0000000000070507, 0x00000000000E0A0E, 0x00000000001C141C,
0x0000000000382838, 0x0000000000705070, 0x0000000000E0A0E0, 0x0000000000C040C0,
0x0000000003020300, 0x0000000007050700, 0x000000000E0A0E00, 0x000000001C141C00,
0x0000000038283800, 0x0000000070507000, 0x00000000E0A0E000, 0x00000000C040C000,
0x0000000302030000, 0x0000000705070000, 0x0000000E0A0E0000, 0x0000001C141C0000,
0x0000003828380000, 0x0000007050700000, 0x000000E0A0E00000, 0x000000C040C00000,
0x0000030203000000, 0x0000070507000000, 0x00000E0A0E000000, 0x00001C141C000000,
0x0000382838000000, 0x0000705070000000, 0x0000E0A0E0000000, 0x0000C040C0000000,
0x0003020300000000, 0x0007050700000000, 0x000E0A0E00000000, 0x001C141C00000000,
0x0038283800000000, 0x0070507000000000, 0x00E0A0E000000000, 0x00C040C000000000,
0x0302030000000000, 0x0705070000000000, 0x0E0A0E0000000000, 0x1C141C0000000000,
0x3828380000000000, 0x7050700000000000, 0xE0A0E00000000000, 0xC040C00000000000,
0x0203000000000000, 0x0507000000000000, 0x0A0E000000000000, 0x141C000000000000,
0x2838000000000000, 0x5070000000000000, 0xA0E0000000000000, 0x40C0000000000000
};
// Declare I/O streams (allows to globaly replace the I/O streams)
#ifdef RCPP_RCOUT
#include <Rcpp.h>
// Set I/O streams to Rcpp I/O streams
static Rcpp::Rostream<true> cout;
static Rcpp::Rostream<false> cerr;
#elif NULLSTREAM
#include "nullstream.h"
// Set I/O streams to "null"
static nullstream cout;
static nullstream cerr;
#else
#include <iostream>
// Default STL I/O streams
using std::cout;
using std::cerr;
#endif
// Piece Square tables (from TSCP)
// see: https://www.chessprogramming.org/Simplified_Evaluation_Function
constexpr Score pieceSquareTables[8][64] = {
{ }, { }, // white, black (empty)
{ // pawn (white)
0, 0, 0, 0, 0, 0, 0, 0,
5, 10, 15, 20, 20, 15, 10, 5,
4, 8, 12, 16, 16, 12, 8, 4,
3, 6, 9, 12, 12, 9, 6, 3,
2, 4, 6, 8, 8, 6, 4, 2,
1, 2, 3, -10, -10, 3, 2, 1,
0, 0, 0, -40, -40, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0 },
{ // knight (white)
-10, -10, -10, -10, -10, -10, -10, -10,
-10, 0, 0, 0, 0, 0, 0, -10,
-10, 0, 5, 5, 5, 5, 0, -10,
-10, 0, 5, 10, 10, 5, 0, -10,
-10, 0, 5, 10, 10, 5, 0, -10,
-10, 0, 5, 5, 5, 5, 0, -10,
-10, 0, 0, 0, 0, 0, 0, -10,
-10, -30, -10, -10, -10, -10, -30, -10 },
{ // bishop (white)
-10, -10, -10, -10, -10, -10, -10, -10,
-10, 0, 0, 0, 0, 0, 0, -10,
-10, 0, 5, 5, 5, 5, 0, -10,
-10, 0, 5, 10, 10, 5, 0, -10,
-10, 0, 5, 10, 10, 5, 0, -10,
-10, 0, 5, 5, 5, 5, 0, -10,
-10, 0, 0, 0, 0, 0, 0, -10,
-10, -10, -20, -10, -10, -20, -10, -10 },
{ // rook (white)
0, 0, 0, 0, 0, 0, 0, 0,
20, 20, 20, 20, 20, 20, 20, 20, // rook on seventh bonus
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0 },
{ // queen (white)
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0 },
{ // king middle game (white)
-40, -40, -40, -40, -40, -40, -40, -40,
-40, -40, -40, -40, -40, -40, -40, -40,
-40, -40, -40, -40, -40, -40, -40, -40,
-40, -40, -40, -40, -40, -40, -40, -40,
-40, -40, -40, -40, -40, -40, -40, -40,
-40, -40, -40, -40, -40, -40, -40, -40,
-20, -20, -20, -20, -20, -20, -20, -20,
0, 20, 40, -20, 0, -20, 40, 20 }
};
#endif /* INCLUDE_GUARD_TYPES_H */

View File

@ -0,0 +1,964 @@
#include <iostream>
#include <iomanip>
#include <string>
#include <sstream>
#include <vector>
#include <tuple>
#include <algorithm>
#include <cassert>
#include "types.h"
#include "Move.h"
#include "Board.h"
#include "uci.h"
#include "search.h"
#if __cplusplus > 201703L
#include <syncstream>
using osyncstream = std::osyncstream;
#else
#include "syncstream.h"
#endif
// Inject into std namespace
namespace std {
ostream& operator<<(ostream& out, const Move& move) {
if (!move) {
out << "!move";
return out;
}
if (UCI::writeLAN) {
switch (move.piece()) {
case pawn: break;
case knight: out << 'N'; break;
case bishop: out << 'B'; break;
case rook: out << 'R'; break;
case queen: out << 'Q'; break;
case king: out << 'K'; break;
default: out << '?';
}
}
out << static_cast<char>('a' + (move.from() % 8))
<< static_cast<char>('8' - (move.from() / 8));
if (UCI::writeLAN) {
switch (move.victim()) {
case pawn: out << 'x'; break;
case knight: out << "xN"; break;
case bishop: out << "xB"; break;
case rook: out << "xR"; break;
case queen: out << "xQ"; break;
default: out << '-';
}
}
out << static_cast<char>('a' + (move.to() % 8))
<< static_cast<char>('8' - (move.to() / 8));
if (UCI::writeLAN) {
switch (move.promote()) {
case rook: out << 'R'; break;
case knight: out << 'N'; break;
case bishop: out << 'B'; break;
case queen: out << 'Q'; break;
default: break;
}
} else {
switch (move.promote()) {
case rook: out << 'r'; break;
case knight: out << 'n'; break;
case bishop: out << 'b'; break;
case queen: out << 'q'; break;
default: break;
}
}
return out;
}
ostream& operator<<(ostream& out, const MoveList& moveList) {
if (moveList.size()) {
out << moveList[0];
}
for (unsigned i = 1; i < moveList.size(); i++) {
out << ' ' << moveList[i];
}
return out;
}
}
namespace UCI {
// General UCI protocol configuration
bool readSAN = false; // if moves should be read in SAN
bool writeLAN = false; // moves written in Long Algebraic Notation
Index parseSquare(const std::string& str, bool& parseError) {
if (str.length() != 2) {
parseError = true;
return static_cast<Index>(64);
}
if (str[0] < 'a' || 'h' < str[0] || str[1] < '1' || '8' < str[1]) {
parseError = true;
return static_cast<Index>(64);
}
Index fileIndex = static_cast<Index>(str[0] - 'a');
Index rankIndex = static_cast<Index>('8' - str[1]);
return 8 * rankIndex + fileIndex;
}
Move parseMove(const std::string& str, bool& parseError) {
// check length <from><to> = 4 plus 1 if promotion
if ((str.length() < 4) || (5 < str.length())) {
parseError = true;
return Move(0);
}
// parse <from><to> squares
Index fromSq = parseSquare(str.substr(0, 2), parseError);
if (parseError) { return Move(0); }
Index toSq = parseSquare(str.substr(2, 2), parseError);
if (parseError) { return Move(0); }
// handle promotions
if (str.length() == 5) {
switch (str[4]) {
case 'r': return Move(fromSq, toSq, rook);
case 'n': return Move(fromSq, toSq, knight);
case 'b': return Move(fromSq, toSq, bishop);
case 'q': return Move(fromSq, toSq, queen);
default: parseError = true; break;
}
}
return Move(fromSq, toSq);
}
const std::string formatMove(const Move move) {
// redirect parsing to `<<` overload
std::ostringstream out;
out << move;
// return as a string
return out.str();
}
// <SAN piece moves> ::= <Piece symbol>[<from file>|<from rank>|<from square>]['x']<to square>
// <SAN pawn captures> ::= <from file>[<from rank>] 'x' <to square>[<promoted to>]
// <SAN pawn push> ::= <to square>[<promoted to>]
Move parseSAN(const std::string& str, const Board& pos, bool& parseError) {
// check length, at least 2 characters for pawn pushes and max 6 plus 2
// additional characters as move annotations (for example '!?', '+' or '#').
if ((str.length() < 2) || (9 < str.length())) {
parseError = true;
return Move(0);
}
// Start with extracting the piece type.
// The first letter (strictly speaking upper case) gives the moving
// piece type or a pawn if not given.
enum piece color = pos.isWhiteTurn() ? white : black;
enum piece piece = pawn;
switch(str[0]) {
// Piece moves (all normal piece moves)
case 'R': piece = rook; break;
case 'N': piece = knight; break;
case 'B': piece = bishop; break;
case 'Q': piece = queen; break;
case 'K': piece = king; break;
// Castling
case 'O': case '0':
if (startsWith(str, "O-O-O") || startsWith(str, "0-0-0")) {
if (color == white) {
return Move(e1, c1);
} else {
return Move(e8, c8);
}
} else if (startsWith(str, "O-O") || startsWith(str, "0-0")) {
if (color == white) {
return Move(e1, g1);
} else {
return Move(e8, g8);
}
} else {
parseError = true;
return Move(0);
}
break;
// Pawn moves always begin with from/to file
case 'a': case 'b': case 'c': case 'd':
case 'e': case 'f': case 'g': case 'h':
break;
default:
parseError = true;
return Move(0);
}
// Normalize moves.
// Drop annotation symbols '?', '+', ... as well as '=', '/' for promotions.
// Thats quite generes but as long as the reminder leads to pseudo legal
// moves, I'm happy.
char part[6];
int part_length = 0;
for (char c : str) {
switch (c) {
case '+': case '#': case '!': case '?': case '=': case '/':
continue;
default:
part[part_length++] = c;
}
if (part_length >= 6) {
break;
}
}
if (piece == pawn) {
// Pawn push
if ((part_length == 2) || (part_length == 3)) {
Index toSq = parseSquare(std::string{part[0], part[1]}, parseError);
if (parseError) { return Move(0); }
u64 from;
if (color == white) {
from = fill7dump<Down>(bitMask<Square>(toSq), ~pos.bb(pawn));
from &= pos.bb(white);
} else {
from = fill7dump<Up>(bitMask<Square>(toSq), ~pos.bb(pawn));
from &= pos.bb(black);
}
if (bitCount(from) != 1) {
parseError = true;
return Move(0);
}
Index fromSq = bitScanLS(from);
if (part_length == 3) {
switch(toLower(part[2])) {
case 'q': return Move(fromSq, toSq, queen);
case 'r': return Move(fromSq, toSq, rook);
case 'b': return Move(fromSq, toSq, bishop);
case 'n': return Move(fromSq, toSq, knight);
default:
parseError = true;
return Move(0);
}
}
return Move(fromSq, toSq);
// Pawn Capture
} else if ((part_length == 4) || (part_length == 5)) {
if (part[1] != 'x') {
parseError = true;
return Move(0);
}
Index toSq = parseSquare(std::string{part[2], part[3]}, parseError);
if (parseError) { return Move(0); }
u64 from;
if (('a' <= part[0]) && (part[0] <= 'h')) {
from = bitMask<File>(static_cast<Index>(str[0] - 'a'));
} else {
parseError = true;
return Move(0);
}
if (color == white) {
from &= shift<Down>(bitMask<Rank>(toSq));
from &= pos.bb(white) & pos.bb(pawn);
} else {
from &= shift<Up>(bitMask<Rank>(toSq));
from &= pos.bb(black) & pos.bb(pawn);
}
if (bitCount(from) != 1) {
parseError = true;
return Move(0);
}
Index fromSq = bitScanLS(from);
if (part_length == 5) {
switch(toLower(part[4])) {
case 'q': return Move(fromSq, toSq, queen);
case 'r': return Move(fromSq, toSq, rook);
case 'b': return Move(fromSq, toSq, bishop);
case 'n': return Move(fromSq, toSq, knight);
default:
parseError = true;
return Move(0);
}
}
return Move(fromSq, toSq);
}
// Illegal length, ether 2, 3 for pawn push or 4, 5 for pawn captures
parseError = true;
return Move(0);
}
// Only normal piece moves remain (castling and pawn moves handled above)
if ((part_length < 3) || (6 < part_length)) {
parseError = true;
return Move(0);
}
Index toSq = parseSquare(
std::string{part[part_length - 2], part[part_length - 1]}, parseError);
if (parseError) { return Move(0); }
u64 to = bitMask<Square>(toSq);
u64 from = pos.bb(color) & pos.bb(piece);
u64 empty = ~(pos.bb(white) | pos.bb(black));
switch(piece) {
case queen:
from &= fill7dump<Left> (to, empty)
| fill7dump<Right> (to, empty)
| fill7dump<Up> (to, empty)
| fill7dump<Down> (to, empty)
| fill7dump<LeftUp> (to, empty)
| fill7dump<RightUp> (to, empty)
| fill7dump<LeftDown> (to, empty)
| fill7dump<RightDown>(to, empty);
break;
case rook:
from &= fill7dump<Left> (to, empty)
| fill7dump<Right>(to, empty)
| fill7dump<Up> (to, empty)
| fill7dump<Down> (to, empty);
break;
case bishop:
from &= fill7dump<LeftUp> (to, empty)
| fill7dump<RightUp> (to, empty)
| fill7dump<LeftDown> (to, empty)
| fill7dump<RightDown>(to, empty);
break;
case knight:
from &= pos.knightMoves(toSq);
break;
case king:
from &= pos.kingMoves(toSq);
break;
default:
assert(false && "Unreachable!");
}
// Capture Move
if (part[part_length - 3] == 'x') {
Index fromSq;
switch (part_length) {
case 4:
break;
case 5:
if (('1' <= part[1]) && (part[1] <= '8')) {
from &= bitMask<Rank>(8 * static_cast<Index>('8' - part[1]));
} else if (('a' <= part[1]) && (part[1] <= 'h')) {
from &= bitMask<File>(static_cast<Index>(part[1] - 'a'));
} else {
parseError = true;
return Move(0);
}
break;
case 6:
fromSq = parseSquare(std::string{part[1], part[2]}, parseError);
if (parseError) { return Move(0); }
from &= bitMask<Square>(fromSq);
break;
default:
parseError = true;
return Move(0);
}
// Non Capture / Quiet Move
} else {
Index fromSq;
switch (part_length) {
case 3:
break;
case 4:
if (('1' <= part[1]) && (part[1] <= '8')) {
from &= bitMask<Rank>(8 * static_cast<Index>('8' - part[1]));
} else if (('a' <= part[1]) && (part[1] <= 'h')) {
from &= bitMask<File>(static_cast<Index>(part[1] - 'a'));
} else {
parseError = true;
return Move(0);
}
break;
case 5:
fromSq = parseSquare(std::string{part[1], part[2]}, parseError);
if (parseError) { return Move(0); }
from &= bitMask<Square>(fromSq);
break;
default:
parseError = true;
return Move(0);
}
}
if (bitCount(from) != 1) {
// If there is still an umbiguity, some candidates might be pinned.
// Remove pinned candidates unable to reach the target square.
u64 cKing = pos.bb(color) & pos.bb(king);
Index cKingSq = bitScanLS(cKing);
u64 pinMask = pos.pinned((color == white) ? black : white, cKing, empty);
for (u64 can = from; can; can &= can - 1) { // can ... candidates
Index canSq = bitScanLS(can);
// check if current candidate is pinned
if ((can & -can) & pinMask) {
// validate if tha candidate is able to reach the target square
if (!(lineMask(cKingSq, canSq) & to)) {
// Remove candidate
from ^= can & -can;
}
}
}
// If there are still multiple candidates or none -> parse error
if (bitCount(from) != 1) {
parseError = true;
return Move(0);
}
}
// Finaly, create the move
return Move(bitScanLS(from), toSq);
}
void parseMoves(std::istream& istream, const Board& pos, MoveList& moveList,
bool& parseError
) {
// setup movelist of all legal moves to validate move lagality
MoveList legalMoves;
pos.moves(legalMoves);
// holds input tokens to be parsed as moves
std::string word;
istream >> std::skipws;
// For each input token
while (istream >> word) {
// parse token as a move (depending on the move format)
Move move;
if (readSAN) {
move = parseSAN(word, pos, parseError);
} else {
move = parseMove(word, parseError);
}
// stop in case of a parse error
if (parseError) { return; }
// validate legality of the (successfully parsed) move and add the
// generated (not the parsed move) to the move list containing all the
// additional move information
bool isLegal = false;
for (Move legal : legalMoves) {
if ((move.from() == legal.from())
&& (move.to() == legal.to())
&& (move.promote() == legal.promote()))
{
isLegal = true;
moveList.push_back(legal);
break;
}
}
// if the parse move was not found in the legal moves list stop and
// report a parse error
if (!isLegal) {
parseError = true;
return;
}
}
}
char formatPiece(enum piece piece) {
switch (piece) {
case rook: return 'R';
case knight: return 'N';
case bishop: return 'B';
case queen: return 'Q';
case king: return 'K';
case pawn: return 'P';
default: return '?';
}
}
void printBoard(const Board& board) {
osyncstream out(cout);
const char* rankNames = "87654321";
const char* fileNames = "abcdefgh";
for (Index line = 0; line < 17; line++) {
if (line % 2) {
Index rankIndex = line / 2;
out << rankNames[rankIndex];
for (Index fileIndex = 0; fileIndex < 8; fileIndex++) {
Index squareIndex = 8 * rankIndex + fileIndex;
if (board.color(squareIndex) == black) {
out << " | \033[1m\033[94m"
<< formatPiece(board.piece(squareIndex)) << "\033[0m";
} else if (board.color(squareIndex) == white) {
out << " | \033[1m\033[97m"
<< formatPiece(board.piece(squareIndex)) << "\033[0m";
} else {
out << " | ";
}
}
out << " |";
} else {
out << " +---+---+---+---+---+---+---+---+";
}
out << " ";
switch (line) {
case 1:
out << "How's turn: "
<< (board.isWhiteTurn() ? "white" : "black");
break;
case 2:
out << "Move Count: " << ((board.plyCount() + 1U) / 2U);
break;
case 3:
out << "Half move clock: " << board.halfMoveClock();
break;
case 4:
out << "Castling Rights: ";
if (board.castleRight(white, KingSide)) { out << 'K'; };
if (board.castleRight(white, QueenSide)) { out << 'Q'; };
if (board.castleRight(black, KingSide)) { out << 'k'; };
if (board.castleRight(black, QueenSide)) { out << 'q'; };
break;
case 5:
out << "en-passange target: ";
if (board.enPassant() < 64) {
out << static_cast<char>('a' + (board.enPassant() % 8))
<< static_cast<char>('8' - (board.enPassant() / 8));
}
else { out << "-"; };
break;
case 6:
out << "evaluate: " << std::dec << board.evaluate();
break;
case 7:
out << "hash: " << std::hex << board.hash() << std::dec;
break;
default:
break;
}
out << "\033[0K\n"; // clear rest of line (remove potential leftovers)
}
out << " ";
for (Index fileIndex = 0; fileIndex < 8; fileIndex++) {
out << " " << fileNames[fileIndex];
}
out << std::endl;
}
void printBitBoards(const Board& board) {
osyncstream out(cout);
std::string lRanks(" 8\n 7\n 6\n 5\n 4\n 3\n 2\n 1");
std::string rRanks("8\n7\n6\n5\n4\n3\n2\n1");
out << std::endl << " "
<< std::left << std::setw(18) << "white"
<< std::left << std::setw(18) << "black"
<< std::left << std::setw(18) << "pawns"
<< std::left << std::setw(18) << "kings"
<< std::endl << " " << std::setfill('0') << std::hex
<< std::right << std::setw(16) << board.bb(white) << " "
<< std::right << std::setw(16) << board.bb(black) << " "
<< std::right << std::setw(16) << board.bb(pawn) << " "
<< std::right << std::setw(16) << board.bb(king) << " "
<< std::left << std::setfill(' ') << std::dec << std::endl
<< aside(
lRanks,
rbits(board.bb(white), '\n', ' '),
rbits(board.bb(black), '\n', ' '),
rbits(board.bb(pawn), '\n', ' '),
rbits(board.bb(king), '\n', ' '),
rRanks
) << std::endl
<< " a b c d e f g h a b c d e f g h"
<< " a b c d e f g h a b c d e f g h" << std::endl;
out << std::endl << " "
<< std::left << std::setw(18) << "rooks"
<< std::left << std::setw(18) << "knights"
<< std::left << std::setw(18) << "bishops"
<< std::left << std::setw(18) << "queens"
<< std::endl << " " << std::setfill('0') << std::hex
<< std::right << std::setw(16) << board.bb(rook) << " "
<< std::right << std::setw(16) << board.bb(knight) << " "
<< std::right << std::setw(16) << board.bb(bishop) << " "
<< std::right << std::setw(16) << board.bb(queen) << " "
<< std::left << std::setfill(' ') << std::dec << std::endl
<< aside(
lRanks,
rbits(board.bb(rook), '\n', ' '),
rbits(board.bb(knight), '\n', ' '),
rbits(board.bb(bishop), '\n', ' '),
rbits(board.bb(queen), '\n', ' '),
rRanks
) << std::endl
<< " a b c d e f g h a b c d e f g h"
<< " a b c d e f g h a b c d e f g h" << std::endl;
const enum piece color = board.isWhiteTurn() ? white : black;
const enum piece enemy = (color == white) ? black : white;
const u64 empty = ~(board.bb(white) | board.bb(black));
const u64 cKing = board.bb(color) & board.bb(king);
out << std::endl << " "
<< std::left << std::setw(18) << "attacks"
<< std::left << std::setw(18) << "pinned"
<< std::left << std::setw(18) << "checkers"
<< std::endl
<< aside(
lRanks,
rbits(board.attacks(enemy, empty), '\n', ' '),
rbits(board.pinned(enemy, cKing, empty), '\n', ' '),
rbits(board.checkers(enemy, cKing, empty), '\n', ' '),
rRanks
) << std::endl
<< " a b c d e f g h a b c d e f g h"
<< " a b c d e f g h" << std::endl;
}
void printEval(const Board& board) {
osyncstream out(cout);
// Partial Scores
Score pScoreWhite, pScoreBlack;
out << " | White | Black | Total\n"
<< "-------------+---------+---------+---------\n"
<< " Material | ";
pScoreWhite = board.evalMaterial(white);
pScoreBlack = board.evalMaterial(black);
out << std::setw(7) << std::right << pScoreWhite << " | "
<< std::setw(7) << std::right << pScoreBlack << " | "
<< std::setw(7) << std::right << pScoreWhite - pScoreBlack << "\n"
<< " Pawns | ";
pScoreWhite = board.evalPawns(white);
pScoreBlack = board.evalPawns(black);
out << std::setw(7) << std::right << pScoreWhite << " | "
<< std::setw(7) << std::right << pScoreBlack << " | "
<< std::setw(7) << std::right << pScoreWhite - pScoreBlack << "\n"
<< " King Safety | ";
pScoreWhite = board.evalKingSafety(white);
pScoreBlack = board.evalKingSafety(black);
out << std::setw(7) << std::right << pScoreWhite << " | "
<< std::setw(7) << std::right << pScoreBlack << " | "
<< std::setw(7) << std::right << pScoreWhite - pScoreBlack << "\n"
<< " Rooks | ";
pScoreWhite = board.evalRooks(white);
pScoreBlack = board.evalRooks(black);
out << std::setw(7) << std::right << pScoreWhite << " | "
<< std::setw(7) << std::right << pScoreBlack << " | "
<< std::setw(7) << std::right << pScoreWhite - pScoreBlack << "\n";
out << "\nTotal: " << board.evaluate() << std::endl;
}
void printMoves(const Board& board, bool capturesOnly) {
MoveList moveList;
board.moves(moveList, capturesOnly);
for (Move& move : moveList) {
move.setScore(move.calcScore());
}
std::sort(moveList.begin(), moveList.end());
osyncstream out(cout);
if (capturesOnly) {
out << "info string captures";
} else {
out << "info string moves";
}
for (Move& move : moveList) {
out << ' ' << move;
}
out << std::endl;
}
// position [fen <fenstring> | startpos ] moves <move1> <move2> .... <moveN>
// \______________________ remainder in cmd ______________________/
void position(std::vector<Board>& game, std::istream& cmd, bool& parseError) {
std::string word;
// setup a new game in case of a parse error (no changes in case of an error)
std::vector<Board> newGame;
// Setup position (and consume moves cmd)
Board pos;
cmd >> word;
if (word == "startpos") {
// Consume and check "moves" word
if (cmd >> word && word != "moves") {
parseError = true;
return;
}
// set position to start position
pos = Board();
newGame.push_back(pos);
} else if (word == "fen") {
std::string fen;
while (cmd >> word && word != "moves") {
fen += word + " ";
}
pos.init(fen, parseError);
if (parseError) { return; }
newGame.push_back(pos);
} else if (word == "this") { // (UCI extention)
if (game.empty()) {
parseError = true;
return;
}
pos = game.back();
// Consume and check "moves" word
if (cmd >> word && word != "moves") {
parseError = true;
return;
}
// keep the game
newGame = game;
} else {
parseError = true;
return;
}
// Apply (legal) moves (if any) and append new positions
while (cmd >> word) {
Move move;
if (readSAN) {
move = UCI::parseSAN(word, pos, parseError);
} else {
move = UCI::parseMove(word, parseError);
}
// validate legality (and extend the move with piece, victim, ... info)
move = pos.isLegal(move); // validate move legality and extend move info
if (parseError || !move) {
parseError = true;
return;
}
pos.make(move);
newGame.push_back(pos);
}
// Finaly replace game with new game (no errors occured)
game = newGame;
}
void setoption(std::istream& cmd, bool& parseError) {
std::string word;
// Consume "name"
if ((!(cmd >> word)) || (word != "name")) {
parseError = true;
return;
}
// read option id
std::string id = "";
while ((cmd >> word) && (word != "value")) {
id += word;
}
// set option value
if (id == "Hash") {
Index ttSize = 0;
cmd >> ttSize;
if ((0 < ttSize) && (ttSize < 1025)) {
Search::init(ttSize);
} else {
parseError = true;
}
} else if (id == "readSAN") {
cmd >> word;
if (word == "true") {
readSAN = true;
} else if (word == "false") {
readSAN = false;
} else {
parseError = true;
}
} else if (id == "writeLAN") {
cmd >> word;
if (word == "true") {
writeLAN = true;
} else if (word == "false") {
writeLAN = false;
} else {
parseError = true;
}
} else {
parseError = true;
}
}
// go <cmd>
// with <cmd>;
// ...
// TODO: implement
//
void go(std::vector<Board>& game, std::istream& cmd, bool& parseError) {
std::string word;
Search::State config(Search::state);
std::string IGNORED; // TODO: complete UCI implementation
while ((cmd >> word) && !parseError) {
if (word == "perft") {
config.search = Search::State::Type::perft;
if (!(cmd >> config.depth)) {
parseError = true;
}
}
else if (word == "searchmoves") {
parseMoves(cmd, game.back(), config.searchmoves, parseError);
}
else if (word == "depth") { cmd >> config.depth; }
else if (word == "wtime") { cmd >> config.wtime; }
else if (word == "btime") { cmd >> config.btime; }
else if (word == "winc") { cmd >> config.winc; }
else if (word == "binc") { cmd >> config.binc; }
else if (word == "movestogo") { cmd >> IGNORED; }
else if (word == "nodes") { cmd >> IGNORED; }
else if (word == "mate") { cmd >> IGNORED; }
else if (word == "movetime") { cmd >> config.movetime; }
else if (word == "ponder") { cmd >> IGNORED; }
else if (word == "infinite") {
config.depth = MoveList::max_size(); // max PV line length
config.wtime = 24 * 60 * 60 * 1000; // 1 day in milliseconds
config.btime = 24 * 60 * 60 * 1000; // 1 day in milliseconds
}
else {
parseError = true;
}
}
if (parseError) {
return;
}
// Write new global settings to Search state
Search::state.wtime = config.wtime;
Search::state.btime = config.btime;
Search::state.winc = config.winc;
Search::state.binc = config.binc;
Search::start(game, config);
}
void print(std::vector<Board>& game, std::istream& cmd, bool& parseError) {
Board& board = game.back();
std::string word;
// Get print command (or set default)
if (!(cmd >> word)) {
word = "board";
}
if (word == "board") { printBoard(board); }
else if (word == "game") { for (Board& pos : game) { printBoard(pos); } }
else if (word == "moves") { printMoves(board); }
else if (word == "eval") { printEval(board); }
else if (word == "captures") { printMoves(board, true); }
else if (word == "bits") { printBitBoards(board); }
else if (word == "fen") {
osyncstream(cout) << "info string fen "
<< board.fen() << std::endl; }
else { parseError = true; }
}
void stdin_listen(std::vector<Board>& game) {
// read line by line from stdin and dispatch appropriate command handler
std::string line;
while (getline(std::cin, line)) {
// a game consists of at least one position
assert(game.size());
std::istringstream cmd(line);
std::string cmdName;
// Extract command name (skip white spaces)
cmd >> std::skipws >> cmdName;
// Dispatch commands (or handle directly)
bool parseError = false; // tracks comand argument parse status
if (cmdName == "quit") { Search::stop(); break; } // stop stdin loop -> shutdown
else if (cmdName == "exit") { Search::stop(); break; } // quit alias
else if (cmdName == "stop") { Search::stop(); }
else if (cmdName == "ucinewgame") { Search::newgame(); }
else if (cmdName == "uci") {
osyncstream(cout)
<< "id name Schach Hoernchen"
"\nid author Daniel Kapla"
"\noption name Hash type spin default 32 min 1 max 1024"
"\noption name readSAN type check default false"
"\noption name writeLAN type check default false"
"\nuciok" // Ready (there are no options yet)
<< std::endl;
}
else if (cmdName == "isready") { cout << "readyok" << std::endl; }
else if (cmdName == "setoption") { setoption(cmd, parseError); }
else if (cmdName == "go") { go(game, cmd, parseError); }
else if (cmdName == "position") { position(game, cmd, parseError); }
// UCI Extention (not part of the UCI protocol)
else if (cmdName == "d") { print(game, cmd, parseError); } // print alias (as in stockfish)
else if (cmdName == "print") { print(game, cmd, parseError); }
else if (cmdName == "getoptions") {
osyncstream(cout)
<< "option name Hash value " << std::setprecision(2)
<< Search::ttSize() << " MB"
<< "\noption name readSAN value " << std::boolalpha << readSAN
<< "\noption name writeLAN value " << std::boolalpha << writeLAN
<< std::endl;
}
else if (cmdName == "help" || cmdName == "?") {
osyncstream(std::cerr)
<< "Commands:\n"
" uci\n\033[2m"
" responds with self identification/options and finishes "
"with uciok\033[0m\n"
" ucinewgame\n\033[2m"
" starts a new game; resets all internal game states\033[0m\n"
" isready\n\033[2m"
" Should be responding 'readyok' immediately\033[0m\n"
" position\n"
" position startpos [moves <move1> [<move2> ...]]\n"
" position fen <fen> [moves <move1> [<move2> ...]]\n"
" position this moves <move1> [<move2> ...]\n"
" go\n"
" go perft <depth>\n"
" go [depth <depth>] [searchmoves <move> [<move> ...]]\n"
" stop\n\033[2m"
" Stops what ever the engine is doing right now as "
"soon as possible\033[0m\n"
" print\n"
" d\n"
" print [board]\033[2m\n"
" Prity prints the board with state information\033[0m\n"
" print game\033[2m\n"
" Prity prints entire game history\033[0m\n"
" print moves\033[2m\n"
" Gives a list of all legal moves sorted by move ordering "
"heuristic\033[0m\n"
" print captures\033[2m\n"
" Same as print moves but captures only\033[0m\n"
" print bits\033[2m\n"
" Prints the internal bit-boards plus attacks, pinnned and "
"checkers bit-boards\033[0m\n"
" print fen\033[2m\n"
" Print FEN string of current internal state\033[0m\n"
" quit\n"
" exit\033[2m\n"
" Shutdown the engine as soon as possible\033[0m\n"
" setoption\n"
" setoption name Hash value <size>\n"
" setoption name readSAN value [true | false]\n"
" setoption name writeLAN value [true | false]\n"
" getoptions\n\033[2m"
" Prints set values of all options\033[0m\n"
<< std::endl; // TODO: complete help!!!
} else {
osyncstream(std::cerr)
<< "info string error Unknown command!" << std::endl;
}
if (parseError) {
osyncstream(std::cerr)
<< "info string error Missformed or illegal command!"
<< std::endl;
}
}
osyncstream(cout) << "info string shutdown" << std::endl;
}
} /* namespace UCI */

View File

@ -0,0 +1,103 @@
#ifndef UCI_GUARD_MOVE_H
#define UCI_GUARD_MOVE_H
#include <string>
#include <vector>
#include "types.h"
// Forward declarations
class Board;
class Move;
class MoveList;
namespace std {
ostream& operator<<(ostream& out, const Move& move);
ostream& operator<<(ostream& out, const MoveList& moveList);
}
/**
* UCI ... Universal Chess Interface
*
* Implements the UCI interface for interaction between the internal engine
* and an GUI using the UCI standard.
*/
namespace UCI {
// General UCI protocol configuration
extern bool readSAN; // if moves should be read in SAN
extern bool writeLAN; // moves written in Long Algebraic Notation
/**
* Parses square in algebraic notation as internal square index.
*
* @param str string representation of a square /[a-h][1-8]/.
* @param parseError output variable set to true if illegal square
* representation is encountered (aka. parse error occured).
*
* @returns index between 0 and 63 if legal, 64 otherwise.
*/
Index parseSquare(const std::string& str, bool& parseError);
/**
* Parses a move given in pure coordinate notation
*
* @param str string representation of a move in `<from><to>[<promotion>]` format
* @param parseError output variable set to false if successfully parsed, true
* otherwise.
*/
Move parseMove(const std::string& str, bool& parseError);
/**
* Formats a move in coordinate notation (Unary Function equiv to the pipe
* operator for Move)
*
* @param move move to be formated as a string
*
* @return string representation of the move in coordinate notation
*/
const std::string formatMove(const Move move);
/**
* Parses a move given in standard algebraic notation (SAN)
*
* @param str string representation of a move in SAN format
* @param pos current position, needed for move interpretation
* @param parseError output variable set to false if successfully parsed, true
* otherwise.
*/
Move parseSAN(const std::string& str, const Board& pos, bool& parseError);
/**
* Parses multiple moves from an input stream into a `MoveList`
*/
void parseMoves(std::istream&, MoveList&, bool&);
char formatPiece(enum piece piece);
void printBoard(const Board&);
void printMoves(const Board&, bool = false);
void printBitBoards(const Board&);
/**
* Handles `position ...` command by setting given board (including moves).
*
* @param board internal board representation to be manipulated.
* @param cmd argument `...` of the position command. For example; given the
* command `position startpos moves e2e4`, str should contain `startpos moves e2e4`.
* @param parseError output, true if a parse error occured, false otherwise.
*/
void position(std::vector<Board>&, std::istream&, bool&);
/**
* handles `go ...` commands. Starts searches, performance tests, ...
*/
void go(std::vector<Board>&, std::istream&, bool&);
/**
* Processes/Parses input from stdin and dispatches appropriate command handler.
*/
void stdin_listen(std::vector<Board>& game);
} /* namespace UCI */
#endif /* UCI_GUARD_MOVE_H */

View File

@ -0,0 +1,698 @@
#ifndef INCLUDE_GUARD_UTILS_H
#define INCLUDE_GUARD_UTILS_H
#include <cstdint>
#include <cstdlib> // for strto* (instead of std::strto*)
#include <vector>
#include <string>
#include <sstream>
#include <cassert>
#include <iostream>
#include <iomanip>
#include "types.h"
// Helpfull in debugging
template <typename T>
inline std::string rbits(T mask, char sepByte = 0, char sep = 0,
Index mark = -1, char marker = 'X'
) {
std::ostringstream out;
for (unsigned i = 0; i < sizeof(T); i++) {
out << sep;
for (unsigned j = 0; j < 8; j++) {
Index index = 8 * i + j;
if (index == mark) {
out << marker;
} else {
out << ((static_cast<T>(1) << index) & mask ? '1' : '.');
}
if (sep) out << sep;
}
if (sepByte && (i < (sizeof(T) - 1))) out << sepByte;
}
return out.str();
}
template <typename T>
inline std::string bits(T mask, char sepByte = 0, char sep = 0,
Index mark = -1, char marker = 'X'
) {
std::string str(rbits<T>(mask, sepByte, sep, mark, marker));
return std::string(str.rbegin(), str.rend());
}
inline std::vector<std::string> split(const std::string& str, char delim = ' ') {
std::vector<std::string> parts;
std::istringstream iss(str);
std::string part;
while (std::getline(iss, part, delim)) {
if (!part.empty()) {
parts.push_back(part);
}
}
return parts;
}
inline std::vector<std::string> parse(const std::string& str, const std::string& format) {
std::vector<std::string> parts;
std::string part;
unsigned start = 0, end = 0, len = 0;
for (const std::string& delim : split(format, '*')) {
while (end + delim.length() < str.length()) {
for (len = 0; len < delim.length(); len++) {
if (str[end + len] != delim[len]) {
break;
}
}
if (len == delim.length()) {
parts.push_back(str.substr(start, end - start));
end += len;
start = end;
break;
} else {
end++;
}
}
}
parts.push_back(str.substr(end));
return parts;
}
inline std::string aside(std::string lhs, std::string rhs) {
std::ostringstream out;
std::vector<std::string> lLines = split(lhs, '\n');
std::vector<std::string> rLines = split(rhs, '\n');
std::size_t lMax = 0;
for (std::size_t i = 0; i < lLines.size(); i++) {
lMax = std::max(lMax, lLines[i].length());
}
for (std::size_t i = 0, n = std::min(lLines.size(), rLines.size()); i < n; i++) {
out << lLines[i]
<< std::string(lMax - lLines[i].length() + 1, ' ')
<< rLines[i] << '\n';
}
if (lLines.size() < rLines.size()) {
for (std::size_t i = lLines.size(); i < rLines.size(); i++) {
out << std::string(lMax + 1, ' ')
<< rLines[i] << '\n';
}
} else {
for (std::size_t i = rLines.size(); i < lLines.size(); i++) {
out << lLines[i] << '\n';
}
}
return out.str();
}
template <typename... Args>
inline std::string aside(std::string col1, std::string col2, Args... cols) {
return aside(
col1,
aside(col2, cols...)
);
}
inline void fourBitBoards(
const std::string& title1, const u64 bb1,
const std::string& title2, const u64 bb2,
const std::string& title3, const u64 bb3,
const std::string& title4, const u64 bb4,
Index mark = -1, char marker = 'X'
) {
std::string lRanks(" 8\n 7\n 6\n 5\n 4\n 3\n 2\n 1");
std::string rRanks("8\n7\n6\n5\n4\n3\n2\n1");
cout << std::endl << " "
<< std::left << std::setw(18) << title1
<< std::left << std::setw(18) << title2
<< std::left << std::setw(18) << title3
<< std::left << std::setw(18) << title4
<< std::endl
<< aside(
lRanks,
rbits(bb1, '\n', ' ', mark, marker),
rbits(bb2, '\n', ' ', mark, marker),
rbits(bb3, '\n', ' ', mark, marker),
rbits(bb4, '\n', ' ', mark, marker),
rRanks
) << std::endl
<< " a b c d e f g h a b c d e f g h"
<< " a b c d e f g h a b c d e f g h" << std::endl;
}
inline char toLower(const char c) {
if ('A' <= c && c <= 'Z') {
return c + ' ';
}
return c;
}
inline char toUpper(const char c) {
if ('a' <= c && c <= 'z') {
return c - ' ';
}
return c;
}
inline bool isUpper(const char c) {
return ('A' <= c && c <= 'Z');
}
inline bool isLower(const char c) {
return ('a' <= c && c <= 'z');
}
inline bool startsWith(const std::string& str, const std::string& prefix) {
if (str.length() < prefix.length()) {
return false;
}
for (unsigned i = 0; i < prefix.length(); i++) {
if (str[i] != prefix[i]) {
return false;
}
}
return true;
}
// see: https://stackoverflow.com/questions/194465/how-to-parse-a-string-to-an-int-in-c
// see: https://en.cppreference.com/w/cpp/string/byte/strtol
// Note: T must be a unsigned type
// Only applicable for types smaller/equal to unsigned long long.
// (only used for uint8_t, uint16_t, uint32_t and uint64_t)
template <typename T>
T parseUnsigned(const std::string& str, bool& parseError) {
if (str.empty()) {
parseError = true;
return static_cast<T>(-1);
}
char* end;
errno = 0;
unsigned long long num = strtoull(str.c_str(), &end, 10);
if (errno == ERANGE) {
errno = 0;
parseError = true;
return static_cast<T>(-1);
}
if (*end != '\0') {
parseError = true;
return static_cast<T>(-1);
}
// check overflow (using complement trick for unsigned integer types)
if (num > static_cast<unsigned long long>(static_cast<T>(-1))) {
parseError = true;
return static_cast<T>(-1);
}
return static_cast<T>(num);
}
template <enum location>
constexpr u64 bitMask();
template <enum location>
constexpr u64 bitMask(Index);
template <enum location L>
inline constexpr u64 bitMask(Index file, Index rank) {
return bitMask<L>(8 * rank + file);
}
template <>
inline constexpr u64 bitMask<Square>(Index sq) {
return static_cast<u64>(1) << sq;
}
template <>
inline constexpr u64 bitMask<File>(Index sq) {
return static_cast<u64>(0x0101010101010101) << (sq & 7);
}
template <>
inline constexpr u64 bitMask<Rank>(Index sq) {
return static_cast<u64>(0xFF) << (sq & 56);
}
template <>
inline constexpr u64 bitMask<Left>(Index sq) {
return bitMask<Rank>(sq) & (bitMask<Square>(sq) - 1);
}
template <>
inline constexpr u64 bitMask<Right>(Index sq) {
return bitMask<Rank>(sq) & (static_cast<u64>(-2) << sq);
}
template <>
inline constexpr u64 bitMask<Up>(Index sq) {
return bitMask<File>(sq) & (bitMask<Square>(sq) - 1);
}
template <>
inline constexpr u64 bitMask<Down>(Index sq) {
return bitMask<File>(sq) & (static_cast<u64>(-2) << sq);
}
template <>
inline constexpr u64 bitMask<Diag>(Index sq) {
const u64 diag = static_cast<u64>(0x8040201008040201);
int offset = 8 * static_cast<int>(sq & 7) - static_cast<int>(sq & 56);
int nort = -offset & ( offset >> 31);
int sout = offset & (-offset >> 31);
return (diag >> sout) << nort;
}
template <>
inline constexpr u64 bitMask<AntiDiag>(Index sq) {
const u64 diag = static_cast<u64>(0x0102040810204080);
int offset = 56 - 8 * static_cast<int>(sq & 7) - static_cast<int>(sq & 56);
int nort = -offset & ( offset >> 31);
int sout = offset & (-offset >> 31);
return (diag >> sout) << nort;
}
template <>
inline constexpr u64 bitMask<WhiteSquares>() {
return 0x55AA55AA55AA55AA;
}
template <>
inline constexpr u64 bitMask<BlackSquares>() {
return 0xAA55AA55AA55AA55;
}
/**
* Shifts with respect to "off board" shifting
*
* bits shift<Left>(bits) shift<Up>(bits)
* 8 . . . 1 . . . . . . 1 . . . . . . . . . . . . . 8
* 7 . . . . . . . . . . . . . . . . . . . . . . . . 7
* 6 . . . . . . . . . . . . . . . . 1 . . . . 1 . . 6
* 5 1 . . . . 1 . . . . . . 1 . . . . . . . . . . . 5
* 4 . . . . . . . . . . . . . . . . . . 1 . . . . . 4
* 3 . . 1 . . . . . . 1 . . . . . . . . . . . . . . 3
* 2 . . . . . . . . . . . . . . . . . . . . . . . . 2
* 1 . . . . . . . . . . . . . . . . . . . . . . . . 1
* a b c d e f g h a b c d e f g h a b c d e f g h
*/
template <enum location>
constexpr u64 shift(u64);
template <>
inline constexpr u64 shift<Left>(u64 bits) {
return (bits >> 1) & ~bitMask<File>(h1);
}
template <>
inline constexpr u64 shift<Right>(u64 bits) {
return (bits << 1) & ~bitMask<File>(a1);
}
template <>
inline constexpr u64 shift<Down>(u64 bits) {
return bits << 8;
}
template <>
inline constexpr u64 shift<Up>(u64 bits) {
return bits >> 8;
}
template <>
inline constexpr u64 shift<RightUp>(u64 bits) {
return (bits >> 7) & ~bitMask<File>(a1);
}
template <>
inline constexpr u64 shift<RightDown>(u64 bits) {
return (bits << 9) & ~bitMask<File>(a1);
}
template <>
inline constexpr u64 shift<LeftUp>(u64 bits) {
return (bits >> 9) & ~bitMask<File>(h1);
}
template <>
inline constexpr u64 shift<LeftDown>(u64 bits) {
return (bits << 7) & ~bitMask<File>(h1);
}
/**
* Fills (including) start square till the end of the board.
*/
template <enum location>
constexpr u64 fill(u64);
/**
* bits fill<File>(bits) fill<Rank>(bits)
* 8 . . . . . . . . . . 1 . . 1 . . . . . . . . . . 8
* 7 . . . . . . . . . . 1 . . 1 . . . . . . . . . . 7
* 6 . . . . . . . . . . 1 . . 1 . . . . . . . . . . 6
* 5 . . 1 . . 1 . . . . 1 . . 1 . . 1 1 1 1 1 1 1 1 5
* 4 . . . . . . . . . . 1 . . 1 . . . . . . . . . . 4
* 3 . . 1 . . . . . . . 1 . . 1 . . 1 1 1 1 1 1 1 1 3
* 2 . . . . . . . . . . 1 . . 1 . . . . . . . . . . 2
* 1 . . . . . . . . . . 1 . . 1 . . . . . . . . . . 1
* a b c d e f g h a b c d e f g h a b c d e f g h
*/
template <>
inline constexpr u64 fill<Up>(u64 bits) {
bits |= bits >> 8;
bits |= bits >> 16;
return bits | (bits >> 32);
}
template <>
inline constexpr u64 fill<Down>(u64 bits) {
bits |= bits << 8;
bits |= bits << 16;
return bits | (bits << 32);
}
template <>
inline constexpr u64 fill<Left>(u64 bits) {
bits |= ((bits & 0xFEFEFEFEFEFEFEFE) >> 1);
bits |= ((bits & 0xFCFCFCFCFCFCFCFC) >> 2);
return bits | ((bits & 0xF0F0F0F0F0F0F0F0) >> 4);
}
template <>
inline constexpr u64 fill<Right>(u64 bits) {
bits |= ((bits & 0x7F7F7F7F7F7F7F7F) << 1);
bits |= ((bits & 0x3F3F3F3F3F3F3F3F) << 2);
return bits | ((bits & 0x0F0F0F0F0F0F0F0F) << 4);
}
template <>
inline constexpr u64 fill<File>(u64 bits) {
return fill<Up>(bits) | fill<Down>(bits);
}
template <>
inline constexpr u64 fill<Rank>(u64 bits) {
return fill<Left>(bits) | fill<Right>(bits);
}
/**
* An attack fill (excludes the attacker but includes blocking pieces)
* see: https://www.chessprogramming.org/Dumb7Fill
*/
template <enum location>
inline u64 fill7dump(u64, u64);
template <>
inline u64 fill7dump<Left>(u64 attacker, u64 empty) {
u64 flood = attacker;
empty &= ~bitMask<File>(h1); // block fill (avoid wrap)
flood |= attacker = (attacker >> 1) & empty;
flood |= attacker = (attacker >> 1) & empty;
flood |= attacker = (attacker >> 1) & empty;
flood |= attacker = (attacker >> 1) & empty;
flood |= attacker = (attacker >> 1) & empty;
flood |= attacker = (attacker >> 1) & empty;
return (flood >> 1) & ~bitMask<File>(h1);
}
template <>
inline u64 fill7dump<Right>(u64 attacker, u64 empty) {
u64 flood = attacker;
empty &= ~bitMask<File>(a1); // block fill (avoid wrap)
flood |= attacker = (attacker << 1) & empty;
flood |= attacker = (attacker << 1) & empty;
flood |= attacker = (attacker << 1) & empty;
flood |= attacker = (attacker << 1) & empty;
flood |= attacker = (attacker << 1) & empty;
flood |= attacker = (attacker << 1) & empty;
return (flood << 1) & ~bitMask<File>(a1);
}
template <>
inline u64 fill7dump<Down>(u64 attacker, u64 empty) {
u64 flood = attacker;
flood |= attacker = (attacker << 8) & empty;
flood |= attacker = (attacker << 8) & empty;
flood |= attacker = (attacker << 8) & empty;
flood |= attacker = (attacker << 8) & empty;
flood |= attacker = (attacker << 8) & empty;
flood |= attacker = (attacker << 8) & empty;
return flood << 8;
}
template <>
inline u64 fill7dump<Up>(u64 attacker, u64 empty) {
u64 flood = attacker;
flood |= attacker = (attacker >> 8) & empty;
flood |= attacker = (attacker >> 8) & empty;
flood |= attacker = (attacker >> 8) & empty;
flood |= attacker = (attacker >> 8) & empty;
flood |= attacker = (attacker >> 8) & empty;
flood |= attacker = (attacker >> 8) & empty;
return flood >> 8;
}
template <>
inline u64 fill7dump<RightUp>(u64 attacker, u64 empty) {
u64 flood = attacker;
empty &= ~bitMask<File>(a1); // block fill (avoid wrap)
flood |= attacker = (attacker >> 7) & empty;
flood |= attacker = (attacker >> 7) & empty;
flood |= attacker = (attacker >> 7) & empty;
flood |= attacker = (attacker >> 7) & empty;
flood |= attacker = (attacker >> 7) & empty;
flood |= attacker = (attacker >> 7) & empty;
return (flood >> 7) & ~bitMask<File>(a1);
}
template <>
inline u64 fill7dump<RightDown>(u64 attacker, u64 empty) {
u64 flood = attacker;
empty &= ~bitMask<File>(a1); // block fill (avoid wrap)
flood |= attacker = (attacker << 9) & empty;
flood |= attacker = (attacker << 9) & empty;
flood |= attacker = (attacker << 9) & empty;
flood |= attacker = (attacker << 9) & empty;
flood |= attacker = (attacker << 9) & empty;
flood |= attacker = (attacker << 9) & empty;
return (flood << 9) & ~bitMask<File>(a1);
}
template <>
inline u64 fill7dump<LeftUp>(u64 attacker, u64 empty) {
u64 flood = attacker;
empty &= ~bitMask<File>(h1); // block fill (avoid wrap)
flood |= attacker = (attacker >> 9) & empty;
flood |= attacker = (attacker >> 9) & empty;
flood |= attacker = (attacker >> 9) & empty;
flood |= attacker = (attacker >> 9) & empty;
flood |= attacker = (attacker >> 9) & empty;
flood |= attacker = (attacker >> 9) & empty;
return (flood >> 9) & ~bitMask<File>(h1);
}
template <>
inline u64 fill7dump<LeftDown>(u64 attacker, u64 empty) {
u64 flood = attacker;
empty &= ~bitMask<File>(h1); // block fill (avoid wrap)
flood |= attacker = (attacker << 7) & empty;
flood |= attacker = (attacker << 7) & empty;
flood |= attacker = (attacker << 7) & empty;
flood |= attacker = (attacker << 7) & empty;
flood |= attacker = (attacker << 7) & empty;
flood |= attacker = (attacker << 7) & empty;
return (flood << 7) & ~bitMask<File>(h1);
}
template <>
inline u64 fill7dump<Plus>(u64 attacker, u64 empty) {
return fill7dump<Up> (attacker, empty)
| fill7dump<Down> (attacker, empty)
| fill7dump<Left> (attacker, empty)
| fill7dump<Right>(attacker, empty);
}
template <>
inline u64 fill7dump<Cross>(u64 attacker, u64 empty) {
return fill7dump<LeftUp> (attacker, empty)
| fill7dump<LeftDown> (attacker, empty)
| fill7dump<RightUp> (attacker, empty)
| fill7dump<RightDown>(attacker, empty);
}
template <>
inline u64 fill7dump<Star>(u64 attacker, u64 empty) {
return fill7dump<Plus> (attacker, empty)
| fill7dump<Cross>(attacker, empty);
}
inline Index fileIndex(const Index squareIndex) {
assert(squareIndex < 64);
return squareIndex & 7;
}
inline Index rankIndex(const Index squareIndex) {
assert(squareIndex < 64);
return squareIndex >> 3;
}
inline Index squareIndex(const Index fileIndex, const Index rankIndex) {
assert(fileIndex < 8);
assert(rankIndex < 8);
return 8 * rankIndex + fileIndex;
}
inline u64 lineMask(Index sq1, Index sq2) {
int r1 = static_cast<int>(rankIndex(sq1));
int r2 = static_cast<int>(rankIndex(sq2));
int f1 = static_cast<int>(fileIndex(sq1));
int f2 = static_cast<int>(fileIndex(sq2));
assert((-1 < r1) && (r1 < 8));
assert((-1 < r2) && (r2 < 8));
assert((-1 < f1) && (f1 < 8));
assert((-1 < f2) && (f2 < 8));
if (r1 == r2) { return bitMask<Rank>(sq1); }
if (f1 == f2) { return bitMask<File>(sq1); }
if ((r1 - r2) == (f1 - f2)) { return bitMask<Diag>(sq1); }
if ((r1 - r2) == (f2 - f1)) { return bitMask<AntiDiag>(sq1); }
assert(false);
return bitMask<Square>(sq1) | bitMask<Square>(sq2);
}
inline u64 bitReverse(u64 x) {
x = (((x & static_cast<u64>(0xAAAAAAAAAAAAAAAA)) >> 1)
| ((x & static_cast<u64>(0x5555555555555555)) << 1));
x = (((x & static_cast<u64>(0xCCCCCCCCCCCCCCCC)) >> 2)
| ((x & static_cast<u64>(0x3333333333333333)) << 2));
x = (((x & static_cast<u64>(0xF0F0F0F0F0F0F0F0)) >> 4)
| ((x & static_cast<u64>(0x0F0F0F0F0F0F0F0F)) << 4));
#ifdef __GNUC__
return __builtin_bswap64(x);
#else
x = (((x & static_cast<u64>(0xFF00FF00FF00FF00)) >> 8)
| ((x & static_cast<u64>(0x00FF00FF00FF00FF)) << 8));
x = (((x & static_cast<u64>(0xFFFF0000FFFF0000)) >> 16)
| ((x & static_cast<u64>(0x0000FFFF0000FFFF)) << 16));
return((x >> 32) | (x << 32));
#endif
}
template <enum location>
inline u64 bitFlip(u64 x);
// Reverses byte order, aka rank 8 <-> rank 1, rank 7 <-> rank 2, ...
template <>
inline u64 bitFlip<Rank>(u64 x) {
#ifdef __GNUC__
return __builtin_bswap64(x);
#elif
x = (((x & static_cast<u64>(0xFF00FF00FF00FF00)) >> 8)
| ((x & static_cast<u64>(0x00FF00FF00FF00FF)) << 8));
x = (((x & static_cast<u64>(0xFFFF0000FFFF0000)) >> 16)
| ((x & static_cast<u64>(0x0000FFFF0000FFFF)) << 16));
return((x >> 32) | (x << 32));
#endif
}
// Reverses bits in bytes, aka file a <-> file h, file b <-> file g, ...
template <>
inline u64 bitFlip<File> (u64 x) {
constexpr u64 a = 0x5555555555555555;
constexpr u64 b = 0x3333333333333333;
constexpr u64 c = 0x0F0F0F0F0F0F0F0F;
x = ((x >> 1) & a) | ((x & a) << 1);
x = ((x >> 2) & b) | ((x & b) << 2);
x = ((x >> 4) & c) | ((x & c) << 4);
return x;
}
// see: https://chessprogramming.org/Flipping_Mirroring_and_Rotating
// Flips bits about the diagonal (transposition of the board)
template <>
inline u64 bitFlip<Diag>(u64 x) {
u64 t; // Temporary Value
constexpr u64 a = 0x5500550055005500;
constexpr u64 b = 0x3333000033330000;
constexpr u64 c = 0x0F0F0F0F00000000;
t = c & (x ^ (x << 28));
x ^= t ^ (t >> 28);
t = b & (x ^ (x << 14));
x ^= t ^ (t >> 14);
t = a & (x ^ (x << 7));
x ^= t ^ (t >> 7);
return x;
}
// Flips bits about the anti-diagonal
template <>
inline u64 bitFlip<AntiDiag>(u64 x) {
u64 t; // Temporary Value
constexpr u64 a = 0xAA00AA00AA00AA00;
constexpr u64 b = 0xCCCC0000CCCC0000;
constexpr u64 c = 0xF0F0F0F00F0F0F0F;
t = x ^ (x << 36);
x ^= c & (t ^ (x >> 36));
t = b & (x ^ (x << 18));
x ^= t ^ (t >> 18);
t = a & (x ^ (x << 9));
x ^= t ^ (t >> 9);
return x;
}
inline Index bitCount(u64 x) {
#ifdef __GNUC__
return __builtin_popcountll(x); // `POPulation COUNT` (Long Long)
#elif
Index count = 0; // counts set bits
// increment count until there are no bits set in x
for (; x; count++) {
x &= x - 1; // unset least significant bit
}
return count;
#endif
}
#ifdef __GNUC__
inline Index bitScanLS(u64 bb) {
return __builtin_ctzll(bb); // Count Trailing Zeros (Long Long)
}
#elif
/**
* `de Brujin` sequence and lookup for bitScanLS and bitScanMS.
*/
constexpr u64 debruijn64Seq = static_cast<u64>(0x03f79d71b4cb0a89);
constexpr Index debruijn64Lookup[64] = {
0, 47, 1, 56, 48, 27, 2, 60,
57, 49, 41, 37, 28, 16, 3, 61,
54, 58, 35, 52, 50, 42, 21, 44,
38, 32, 29, 23, 17, 11, 4, 62,
46, 55, 26, 59, 40, 36, 15, 53,
34, 51, 20, 43, 31, 22, 10, 45,
25, 39, 14, 33, 19, 30, 9, 24,
13, 18, 8, 12, 7, 6, 5, 63
};
/**
* Gets the least significant 1 bit `bitScanLS` and most significant
* `bitScanMS` index on a 64-bit board.
*
* Using a `de Brujin` sequence to index a 1 in a 64-bit word.
*
* @param `bb` 64-bit word (a bit board)
* @condition `bb` != 0
* @return index 0 to 63 of least significant one bit
*
* @see Original authors: `Martin Läuter (1997), Charles E. Leiserson,`
* `Harald Prokop, Keith H. Randall`
* @see https://www.chessprogramming.org/BitScan
*/
inline Index bitScanLS(u64 bb) {
assert(bb != 0);
return debruijn64Lookup[((bb ^ (bb - 1)) * debruijn64Seq) >> 58];
}
inline Index bitScanMS(u64 bb) {
assert(bb != 0);
bb |= bb >> 1;
bb |= bb >> 2;
bb |= bb >> 4;
bb |= bb >> 8;
bb |= bb >> 16;
bb |= bb >> 32;
return debruijn64Lookup[(bb * debruijn64Seq) >> 58];
}
#endif
#endif /* INCLUDE_GUARD_UTILS_H */

View File

@ -0,0 +1,4 @@
PKG_CXXFLAGS += -I'../inst/include' -pthread -DRCPP_RCOUT
SOURCES = $(wildcard *.cpp) $(wildcard ../inst/include/SchachHoernchen/*.cpp)
OBJECTS = $(SOURCES:.cpp=.o)

View File

@ -0,0 +1,196 @@
// Generated by using Rcpp::compileAttributes() -> do not edit by hand
// Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
#include "../inst/include/Rchess.h"
#include <Rcpp.h>
using namespace Rcpp;
#ifdef RCPP_USE_GLOBAL_ROSTREAM
Rcpp::Rostream<true>& Rcpp::Rcout = Rcpp::Rcpp_cout_get();
Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif
// data_gen
Rcpp::CharacterVector data_gen(const std::string& file, const int sample_size, const float score_min, const float score_max);
RcppExport SEXP _Rchess_data_gen(SEXP fileSEXP, SEXP sample_sizeSEXP, SEXP score_minSEXP, SEXP score_maxSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const std::string& >::type file(fileSEXP);
Rcpp::traits::input_parameter< const int >::type sample_size(sample_sizeSEXP);
Rcpp::traits::input_parameter< const float >::type score_min(score_minSEXP);
Rcpp::traits::input_parameter< const float >::type score_max(score_maxSEXP);
rcpp_result_gen = Rcpp::wrap(data_gen(file, sample_size, score_min, score_max));
return rcpp_result_gen;
END_RCPP
}
// fen2int
Rcpp::IntegerVector fen2int(const std::vector<Board>& boards);
RcppExport SEXP _Rchess_fen2int(SEXP boardsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::traits::input_parameter< const std::vector<Board>& >::type boards(boardsSEXP);
rcpp_result_gen = Rcpp::wrap(fen2int(boards));
return rcpp_result_gen;
END_RCPP
}
// read_cyclic
Rcpp::CharacterVector read_cyclic(const std::string& file, const int nrows, const int skip, const int start, const int line_len);
RcppExport SEXP _Rchess_read_cyclic(SEXP fileSEXP, SEXP nrowsSEXP, SEXP skipSEXP, SEXP startSEXP, SEXP line_lenSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::traits::input_parameter< const std::string& >::type file(fileSEXP);
Rcpp::traits::input_parameter< const int >::type nrows(nrowsSEXP);
Rcpp::traits::input_parameter< const int >::type skip(skipSEXP);
Rcpp::traits::input_parameter< const int >::type start(startSEXP);
Rcpp::traits::input_parameter< const int >::type line_len(line_lenSEXP);
rcpp_result_gen = Rcpp::wrap(read_cyclic(file, nrows, skip, start, line_len));
return rcpp_result_gen;
END_RCPP
}
// sample_move
Move sample_move(const Board& pos);
RcppExport SEXP _Rchess_sample_move(SEXP posSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const Board& >::type pos(posSEXP);
rcpp_result_gen = Rcpp::wrap(sample_move(pos));
return rcpp_result_gen;
END_RCPP
}
// sample_fen
std::vector<Board> sample_fen(const unsigned nr, const unsigned min_depth, const unsigned max_depth);
RcppExport SEXP _Rchess_sample_fen(SEXP nrSEXP, SEXP min_depthSEXP, SEXP max_depthSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const unsigned >::type nr(nrSEXP);
Rcpp::traits::input_parameter< const unsigned >::type min_depth(min_depthSEXP);
Rcpp::traits::input_parameter< const unsigned >::type max_depth(max_depthSEXP);
rcpp_result_gen = Rcpp::wrap(sample_fen(nr, min_depth, max_depth));
return rcpp_result_gen;
END_RCPP
}
// board
Board board(const std::string& fen);
RcppExport SEXP _Rchess_board(SEXP fenSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::traits::input_parameter< const std::string& >::type fen(fenSEXP);
rcpp_result_gen = Rcpp::wrap(board(fen));
return rcpp_result_gen;
END_RCPP
}
// print_board
void print_board(const std::string& fen);
RcppExport SEXP _Rchess_print_board(SEXP fenSEXP) {
BEGIN_RCPP
Rcpp::traits::input_parameter< const std::string& >::type fen(fenSEXP);
print_board(fen);
return R_NilValue;
END_RCPP
}
// print_moves
void print_moves(const std::string& fen);
RcppExport SEXP _Rchess_print_moves(SEXP fenSEXP) {
BEGIN_RCPP
Rcpp::traits::input_parameter< const std::string& >::type fen(fenSEXP);
print_moves(fen);
return R_NilValue;
END_RCPP
}
// print_bitboards
void print_bitboards(const std::string& fen);
RcppExport SEXP _Rchess_print_bitboards(SEXP fenSEXP) {
BEGIN_RCPP
Rcpp::traits::input_parameter< const std::string& >::type fen(fenSEXP);
print_bitboards(fen);
return R_NilValue;
END_RCPP
}
// position
Board position(const Board& pos, const std::vector<std::string>& moves, const bool san);
RcppExport SEXP _Rchess_position(SEXP posSEXP, SEXP movesSEXP, SEXP sanSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::traits::input_parameter< const Board& >::type pos(posSEXP);
Rcpp::traits::input_parameter< const std::vector<std::string>& >::type moves(movesSEXP);
Rcpp::traits::input_parameter< const bool >::type san(sanSEXP);
rcpp_result_gen = Rcpp::wrap(position(pos, moves, san));
return rcpp_result_gen;
END_RCPP
}
// perft
void perft(const int depth);
RcppExport SEXP _Rchess_perft(SEXP depthSEXP) {
BEGIN_RCPP
Rcpp::traits::input_parameter< const int >::type depth(depthSEXP);
perft(depth);
return R_NilValue;
END_RCPP
}
// go
Move go(const int depth);
RcppExport SEXP _Rchess_go(SEXP depthSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::traits::input_parameter< const int >::type depth(depthSEXP);
rcpp_result_gen = Rcpp::wrap(go(depth));
return rcpp_result_gen;
END_RCPP
}
// ucinewgame
void ucinewgame();
RcppExport SEXP _Rchess_ucinewgame() {
BEGIN_RCPP
ucinewgame();
return R_NilValue;
END_RCPP
}
// onLoad
void onLoad(const std::string& libname, const std::string& pkgname);
RcppExport SEXP _Rchess_onLoad(SEXP libnameSEXP, SEXP pkgnameSEXP) {
BEGIN_RCPP
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const std::string& >::type libname(libnameSEXP);
Rcpp::traits::input_parameter< const std::string& >::type pkgname(pkgnameSEXP);
onLoad(libname, pkgname);
return R_NilValue;
END_RCPP
}
// onUnload
void onUnload(const std::string& libpath);
RcppExport SEXP _Rchess_onUnload(SEXP libpathSEXP) {
BEGIN_RCPP
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const std::string& >::type libpath(libpathSEXP);
onUnload(libpath);
return R_NilValue;
END_RCPP
}
static const R_CallMethodDef CallEntries[] = {
{"_Rchess_data_gen", (DL_FUNC) &_Rchess_data_gen, 4},
{"_Rchess_fen2int", (DL_FUNC) &_Rchess_fen2int, 1},
{"_Rchess_read_cyclic", (DL_FUNC) &_Rchess_read_cyclic, 5},
{"_Rchess_sample_move", (DL_FUNC) &_Rchess_sample_move, 1},
{"_Rchess_sample_fen", (DL_FUNC) &_Rchess_sample_fen, 3},
{"_Rchess_board", (DL_FUNC) &_Rchess_board, 1},
{"_Rchess_print_board", (DL_FUNC) &_Rchess_print_board, 1},
{"_Rchess_print_moves", (DL_FUNC) &_Rchess_print_moves, 1},
{"_Rchess_print_bitboards", (DL_FUNC) &_Rchess_print_bitboards, 1},
{"_Rchess_position", (DL_FUNC) &_Rchess_position, 3},
{"_Rchess_perft", (DL_FUNC) &_Rchess_perft, 1},
{"_Rchess_go", (DL_FUNC) &_Rchess_go, 1},
{"_Rchess_ucinewgame", (DL_FUNC) &_Rchess_ucinewgame, 0},
{"_Rchess_onLoad", (DL_FUNC) &_Rchess_onLoad, 2},
{"_Rchess_onUnload", (DL_FUNC) &_Rchess_onUnload, 1},
{NULL, NULL, 0}
};
RcppExport void R_init_Rchess(DllInfo *dll) {
R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
R_useDynamicSymbols(dll, FALSE);
}

View File

@ -0,0 +1,129 @@
#include <iostream>
#include <iomanip>
#include <string>
#include <fstream>
#include <sstream>
#include <limits>
#include <Rcpp.h>
#include "SchachHoernchen/Board.h"
//' Specialized version of `read_cyclic.cpp` taylored to work in conjunction with
//' `gmlm_chess()` as data generator to provide random draws from a FEN data set
//' with scores filtered to be in in the range `score_min` to `score_max`.
//'
// [[Rcpp::export(name = "data.gen", rng = true)]]
Rcpp::CharacterVector data_gen(
const std::string& file,
const int sample_size,
const float score_min,
const float score_max
) {
// Check parames
if (sample_size < 1) {
Rcpp::stop("`sample_size` must be positive");
}
if (score_min >= score_max) {
Rcpp::stop("`score_min` must be strictly smaller than `score_max`");
}
// open FEN data set file
std::ifstream input(file);
if (!input) {
Rcpp::stop("Opening file '%s' failed", file);
}
// set the read from stream position to a random line
input.seekg(0, std::ios::end);
unsigned long seek = unif_rand() * input.tellg();
input.seekg(seek);
// from random position set stream position to line start (if not over shot)
if (!input.eof()) {
input.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
}
// Ensure (in any case) we are at a legal position (recycle)
if (input.eof()) {
input.seekg(0);
}
// Allocate output sample
Rcpp::CharacterVector sample(sample_size);
// Read and filter lines from FEN data base file
std::string line, fen;
float score;
Board pos;
int sample_count = 0, retry_count = 0, reject_count = 0;
while (sample_count < sample_size) {
// Check for user interupt (that is, allows from `R` to interupt execution)
R_CheckUserInterrupt();
// Read line, in case of failure retry from start of file (recycling)
if (!std::getline(input, line)) {
input.clear();
input.seekg(0);
if (!std::getline(input, line)) {
// another failur is fatal
Rcpp::stop("Recycline lines in file '%s' failed", file);
}
}
// Check for empty line, treated as a partial error which we retry a few times
if (line.empty()) {
if (++retry_count > 10) {
Rcpp::stop("Retry count exceeded after reading empty line in '%s'", file);
} else {
continue;
}
}
// Split candidat line into FEN and score
std::stringstream candidat(line);
std::getline(candidat, fen, ';');
candidat >> score;
if (candidat.fail()) {
// If this failes, the FEN data base is ill formed!
Rcpp::stop("Ill formated FEN data base file '%s'", file);
}
// parse FEN to filter only positions with white to move
bool parseError = false;
pos.init(fen, parseError);
if (parseError) {
Rcpp::stop("Retry count exceeded after illegal FEN '%s'", fen);
}
// Filter white to move positions
if (pos.sideToMove() == piece::black) {
reject_count++;
continue;
}
// filter scores out of slice
if (score < score_min || score_max <= score) {
reject_count++;
continue;
}
// Avoid infinite loop
if (reject_count > 1000 * sample_size) {
Rcpp::stop("Too many rejections, stop to avoid infinite loop");
}
// Everythings succeeded and ge got an appropriate sample in requested range
sample[sample_count++] = fen;
// skip lines (ensures independent draws based on games being independent)
if (input.eof()) {
input.seekg(0);
}
for (int s = 0; s < 256; ++s) {
input.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
if (input.eof()) {
input.seekg(0);
}
}
}
return sample;
}

View File

@ -0,0 +1,52 @@
#include <vector>
#include <Rcpp.h>
#include "SchachHoernchen/types.h"
#include "SchachHoernchen/utils.h"
#include "SchachHoernchen/Board.h"
//' Convert a legal FEN string to a 3D binary (integer with 0-1 entries) array
// [[Rcpp::export(rng = false)]]
Rcpp::IntegerVector fen2int(const std::vector<Board>& boards) {
// Initialize empty chess board as a 3D binary tensor
Rcpp::IntegerVector bitboards(8 * 8 * 12 * (int)boards.size());
// Set dimension and dimension names (required this way since `Rcpp::Dimension`
// does _not_ support 4D arrays)
auto dims = Rcpp::IntegerVector({ 8, 8, 12, (int)boards.size() });
bitboards.attr("dim") = dims;
bitboards.attr("dimnames") = Rcpp::List::create(
Rcpp::Named("rank") = Rcpp::CharacterVector::create("8", "7", "6", "5", "4", "3", "2", "1"),
Rcpp::Named("file") = Rcpp::CharacterVector::create("a", "b", "c", "d", "e", "f", "g", "h"),
Rcpp::Named("piece") = Rcpp::CharacterVector::create(
"P", "N", "B", "R", "Q", "K", // White Pieces (Upper Case)
"p", "n", "b", "r", "q", "k" // Black Pieces (Lower Case)
),
R_NilValue
);
// Index to color/piece mapping (more robust)
enum piece colorLoopup[2] = { white, black };
enum piece pieceLookup[6] = { pawn, knight, bishop, rook, queen, king };
// Set for every piece the corresponding pit positions, note the
// "transposition" of the indexing from SchachHoernchen's binary indexing
// scheme to the 3D array indexing of ranks/files.
for (int i = 0; i < boards.size(); ++i) {
const Board& pos = boards[i];
for (int color = 0; color < 2; ++color) {
for (int piece = 0; piece < 6; ++piece) {
int slice = 6 * color + piece;
u64 bb = pos.bb(colorLoopup[color]) & pos.bb(pieceLookup[piece]);
for (; bb; bb &= bb - 1) {
// Get BitBoard index
int index = bitScanLS(bb);
// Transpose to align with printing as a Chess Board
index = ((index & 7) << 3) | ((index & 56) >> 3);
bitboards[768 * i + 64 * slice + index] = 1;
}
}
}
}
return bitboards;
}

View File

@ -0,0 +1,71 @@
// #include <Rcpp.h>
// #include "SchachHoernchen/utils.h"
// #include "SchachHoernchen/Board.h"
// #include "SchachHoernchen/uci.h"
// // [[Rcpp::export(name = "print.board", rng = false)]]
// void print_board(
// const Board& board,
// const bool check = true,
// const bool attacked = false,
// const bool pinned = false,
// const bool checkers = false
// ) {
// using Rcpp::Rcout;
// // Extract some properties
// piece color = board.sideToMove();
// piece enemy = static_cast<piece>(!color);
// u64 empty = ~(board.bb(piece::white) | board.bb(piece::black));
// u64 cKing = board.bb(king) & board.bb(color);
// // Construct highlight mask
// u64 attackMask = board.attacks(enemy, empty);
// u64 mask = 0;
// mask |= check ? attackMask & board.bb(color) & board.bb(king) : 0;
// mask |= attacked ? attackMask & ~board.bb(color) : 0;
// mask |= pinned ? board.pinned(enemy, cKing, empty) : 0;
// mask |= checkers ? board.checkers(enemy, cKing, empty) : 0;
// // print the board to console
// Rcout << "FEN: " << board.fen() << '\n';
// for (Index line = 0; line < 17; line++) {
// if (line % 2) {
// Index rankIndex = line / 2;
// Rcout << static_cast<char>('8' - rankIndex);
// for (Index fileIndex = 0; fileIndex < 8; fileIndex++) {
// Index squareIndex = 8 * rankIndex + fileIndex;
// if (bitMask<Square>(squareIndex) & mask) {
// if (board.piece(squareIndex)) {
// if (board.color(squareIndex) == black) {
// // bold + italic + underline + black (blue)
// Rcout << " | \033[1;3;4;94m";
// } else {
// // bold + italic + underline (+ white)
// Rcout << " | \033[1;3;4m";
// }
// Rcout << UCI::formatPiece(board.piece(squareIndex))
// << "\033[0m";
// } else {
// Rcout << " | .";
// }
// } else if (board.color(squareIndex) == black) {
// Rcout << " | \033[1m\033[94m" // bold + blue (black)
// << UCI::formatPiece(board.piece(squareIndex))
// << "\033[0m";
// } else if (board.color(squareIndex) == white) {
// Rcout << " | \033[1m\033[97m" // bold + white
// << UCI::formatPiece(board.piece(squareIndex))
// << "\033[0m";
// } else {
// Rcout << " | ";
// }
// }
// Rcout << " |";
// } else {
// Rcout << " +---+---+---+---+---+---+---+---+";
// }
// Rcout << "\033[0K\n"; // clear rest of line (remove potential leftovers)
// }
// Rcout << " a b c d e f g h" << std::endl;
// }

View File

@ -0,0 +1,94 @@
#include <iostream>
#include <iomanip>
#include <string>
#include <fstream>
#include <sstream>
#include <limits>
#include <Rcpp.h>
//' Reads lines from a text file with recycling.
//'
// [[Rcpp::export(name = "read.cyclic", rng = false)]]
Rcpp::CharacterVector read_cyclic(
const std::string& file,
const int nrows = 1000,
const int skip = 100,
const int start = 1,
const int line_len = 64
) {
// `unsigned` is not "properly" checked by `Rcpp` to _not_ be negative.
// An implicit cast to unsigned makes negative integers gigantic, not wanted!
if (skip < 0) {
Rcpp::stop("`skip` must be non-negative.");
}
if (nrows < 1 || start < 1 || line_len < 1) {
Rcpp::stop("`start`, `nrows` and `line_len` must be positive.");
}
// Open input file
std::ifstream input(file);
if (!input) {
Rcpp::stop("Opening file '%s' failed", file);
}
// Handle different versions of start positions, if the start position is
// large, we guess a value to skip via a simple heuristic and go with it
if (1000 < start) {
// Skip (approx) `start` lines
input.seekg(0, std::ios::end); // get to end of file
unsigned long size = static_cast<unsigned long>(input.tellg());
unsigned long seek = static_cast<unsigned long>(line_len) * (
static_cast<unsigned long>(start) - 1UL
);
seek = seek % size;
// Now seek to the approx line nr. with recycling
input.seekg(seek);
// read till end of line to have a proper start of line position
if (!input.eof()) {
input.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
}
// in the occastional case of ending at a last line (recycle)
if (input.eof()) { input.seekg(0); }
} else {
// Skip (exactly) `start` lines
for (int line_nr = 1; line_nr < start; ++line_nr) {
input.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
// In case of reaching the end of file, restart (recycle)
if (input.eof()) { input.seekg(0); }
}
}
// Create character vector for output lines
Rcpp::CharacterVector lines(nrows);
// Read one line and skip multiple till `nrows` lines are included
std::string line;
for (int i = 0; i < nrows; ++i) {
// Read one line
if (std::getline(input, line)) {
lines[i] = line;
} else {
// retry from start of file, may occure in last empty line of file
input.clear();
input.seekg(0);
if (std::getline(input, line)) {
lines[i] = line;
} else {
// Another failur is fatal
Rcpp::stop("Recycling lines in file '%s' failed", file);
}
}
// recycle if at end of file (always check)
if (input.eof()) { input.seekg(0); }
// skip lines (with recycling)
for (int s = 0; s < skip; ++s) {
input.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
if (input.eof()) { input.seekg(0); }
}
}
return lines;
}

View File

@ -0,0 +1,92 @@
#include <vector>
#include <Rcpp.h>
#include "SchachHoernchen/Move.h"
#include "SchachHoernchen/Board.h"
//' Samples a legal move from a given position
// [[Rcpp::export("sample.move", rng = true)]]
Move sample_move(const Board& pos) {
// RNG for continuous uniform X ~ U[0, 1]
auto runif = Rcpp::stats::UnifGenerator(0, 1);
// RNG for discrete uniform X ~ U[0; max - 1]
auto rindex = [&runif](const std::size_t max) {
// sample random index untill in range [0; max - 1]
unsigned index;
do {
index = static_cast<unsigned>(runif() * static_cast<double>(max));
} while (index >= max);
return index;
};
// generate all legal moves
MoveList moves;
pos.moves(moves);
// check if there are any moves to sample from
if (moves.empty()) {
Rcpp::stop("Attempt to same legal move from terminal node");
}
return moves[rindex(moves.size())];
}
//' Samples a random FEN (position) by applying `ply` random moves to the start
//' position.
//'
//' @param nr number of positions to sample
//' @param min_depth minimum number of random ply's to generate random positions
//' @param max_depth maximum number of random ply's to generate random positions
// [[Rcpp::export("sample.fen", rng = true)]]
std::vector<Board> sample_fen(const unsigned nr,
const unsigned min_depth = 4, const unsigned max_depth = 20
) {
// Parameter validation
if (min_depth > max_depth) {
Rcpp::stop("max_depth must be bigger equal than min_depth");
}
if (128 < max_depth) {
Rcpp::stop("max_depth exceeded maximum value 128");
}
// RNG for continuous uniform X ~ U[0, 1]
auto runif = Rcpp::stats::UnifGenerator(0, 1);
// RNG for discrete uniform X ~ U[0; max - 1]
auto rindex = [&runif](const std::size_t max) {
// Check if max is bigger than zero, cause we can not sample from the
// empty set
if (max == 0) {
Rcpp::stop("Attempt to sample random index < 0.");
}
// sample random index untill in range [0; max - 1]
unsigned index = max;
while (index >= max) {
index = static_cast<unsigned>(runif() * static_cast<double>(max));
}
return index;
};
// Setup response vector
std::vector<Board> fens;
fens.reserve(nr);
// Sample FENs
MoveList moves;
for (unsigned i = 0; i < nr; ++i) {
Board pos; // start position
unsigned depth = static_cast<unsigned>(runif() * (max_depth - min_depth));
depth += min_depth;
for (unsigned ply = 0; ply < depth; ++ply) {
moves.clear();
pos.moves(moves);
if (moves.size()) {
pos.make(moves[rindex(moves.size())]);
} else {
break;
}
}
fens.push_back(pos);
}
return fens;
}

View File

@ -0,0 +1,146 @@
// UCI (Univeral Chess Interface) `R` binding to the `SchachHoernchen` engine
#include <string>
#include <vector>
#include <sstream>
#include <iterator>
#include <Rcpp.h>
#include "SchachHoernchen/Move.h"
#include "SchachHoernchen/Board.h"
#include "SchachHoernchen/uci.h"
#include "SchachHoernchen/search.h"
namespace UCI_State {
std::vector<Board> game(1);
};
//' Converts a FEN string to a Board (position) or return the current internal state
// [[Rcpp::export(rng = false)]]
Board board(const std::string& fen = "") {
if (fen == "") {
return UCI_State::game.back();
} else if (fen == "startpos") {
return Board();
}
Board pos;
bool parseError = false;
pos.init(fen, parseError);
if (parseError) {
Rcpp::stop("Parse parsing FEN");
}
return pos;
}
// [[Rcpp::export(name = "print.board", rng = false)]]
void print_board(const std::string& fen = "") {
UCI::printBoard(board(fen));
}
// [[Rcpp::export(name = "print.moves", rng = false)]]
void print_moves(const std::string& fen = "") {
UCI::printMoves(board(fen));
}
// [[Rcpp::export(name = "print.bitboards", rng = false)]]
void print_bitboards(const std::string& fen = "") {
UCI::printBitBoards(board(fen));
}
// [[Rcpp::export(rng = false)]]
Board position(
const Board& pos,
const std::vector<std::string>& moves,
const bool san = false
) {
// Build UCI command (without "position")
std::stringstream cmd;
cmd << "fen " << pos.fen() << " ";
if (moves.size()) {
cmd << "moves ";
std::copy(moves.begin(), moves.end(), std::ostream_iterator<std::string>(cmd, " "));
}
// set UCI internal flag to interprate move input in from-to format or SAN
UCI::readSAN = san;
// and invoke UCI position command handler on the internal game state
bool parseError = false;
UCI::position(UCI_State::game, cmd, parseError);
if (parseError) {
Rcpp::stop("Parse Error");
}
return UCI_State::game.back();
}
// [[Rcpp::export(rng = false)]]
void perft(const int depth = 6) {
// Enforce a depth limit, this is very restrictive but we only want to
// use it as a toolbox and _not_ as a strong chess engine
if (8 < depth) {
Rcpp::stop("In `R` search is limited to depth 8");
} else if (depth <= 0) {
cout << "Nodes searched: 0" << std::endl;
return;
}
// Get current set position
Board pos(UCI_State::game.back());
// Get all legal moves
MoveList moves;
pos.moves(moves);
// Setup counter for total nr. of moves
Index totalCount = 0;
Board copy(pos);
for (Move move : moves) {
// continue traversing
pos.make(move);
Index nodeCount = Search::perft_subroutine(pos, depth - 1);
totalCount += nodeCount;
pos = copy; // unmake move
// report moves of node
cout << move << ": " << nodeCount << std::endl;
}
cout << std::endl << "Nodes searched: " << totalCount << std::endl;
}
// [[Rcpp::export(rng = false)]]
Move go(
const int depth = 6
) {
// Enforce a depth limit, this is very restrictive but we only want to
// use it as a toolbox and _not_ as a strong chess engine
if (8 < depth) {
Rcpp::stop("In `R` search is limited to depth 8");
}
// Setup search configuration
Search::State config;
config.depth = depth < 1 ? 1 : depth;
// sets worker thread stop condition to false (before dispatch) which in this
// context is the main thread. Only need to ensure its running.
Search::isRunning.store(true, std::memory_order_release);
// Construct a search object
Search::PVS<Board> search(UCI_State::game, config);
// and start the search
search();
// return best move
return search.bestMove();
}
// [[Rcpp::export(rng = false)]]
void ucinewgame() {
Search::newgame();
UCI_State::game.clear();
UCI_State::game.emplace_back();
}

View File

@ -0,0 +1,21 @@
// Startup and unload routines for initialization and cleanup, see: `R/zzz.R`
#include <Rcpp.h>
#include "SchachHoernchen/search.h"
// [[Rcpp::export(".onLoad")]]
void onLoad(const std::string& libname, const std::string& pkgname) {
// Initialize search (Transposition Table)
Search::init();
// Report search initialization
Rcpp::Rcout << "info string search initialized" << std::endl;
}
// [[Rcpp::export(".onUnload")]]
void onUnload(const std::string& libpath) {
// Cleanup any outstanding or running/finished worker tasks
Search::stop();
// Report shutdown (stopping any remaining searches)
Rcpp::Rcout << "info string shutdown" << std::endl;
}

View File

@ -0,0 +1,215 @@
#' Specialized version of `gmlm_ising()`.
#'
#' Theroetically, equivalent to `gmlm_ising()` except the it uses a stochastic
#' gradient descent version of RMSprop instead of classic gradient descent.
#' Other differences are puerly of technical nature.
#'
#' @param data_gen data generator, samples from the data set conditioned on a
#' slice value `y.min` to `y.max`. Function signature
#' `function(batch.size, y.min, y.max)` with return value `X`, a
#' `8 x 8 x 12 x batch.size` 4D array.
#' @param fun_y known functions of scalar `y`, returning a 3D/4D tensor
#' @param score_breaks numeric vector of two or more unique cut points, the cut
#' points are the interval bounds specifying the slices of `y`.
#' @param nr_threads integer, nr. of threads used by `ising_m2()`
#' @param mcmc_samples integer, nr. of Monte-Carlo Chains passed to `ising_m2()`
#' @param slice_size integer, size of sub-samples generated by `data_gen` for
#' every slice. The batch size of the for every iteration is then equal to
#' `slice_size * (length(score_breaks) - 1L)`.
#' @param max_iter maximum number of iterations for gradient optimization
#' @param patience integer, break condition parameter. If the approximated loss
#' doesn't improve over `patience` iterations, then stop.
#' @param step_size numeric, meta parameter for RMSprop for gradient scaling
#' @param eps numeric, meta parameter for RMSprop avoiding divition by zero in
#' the parameter update rule of RMSprop
#' @param save_point character, file name pattern for storing and retrieving
#' optimization save points. Those save points allow to stop the method and
#' resume optimization later from the last save point.
#'
gmlm_chess <- function(
data_gen,
fun_y,
score_breaks = c(-5.0, -3.0, -2.0, -1.0, -0.5, -0.2, 0.2, 0.5, 1.0, 2.0, 3.0, 5.0),
nr_threads = 8L,
mcmc_samples = 10000L,
slice_size = 512L,
max_iter = 1000L,
patience = 25L,
step_size = 1e-3,
eps = sqrt(.Machine$double.eps),
save_point = "gmlm_chess_save_point_%s.Rdata"
) {
# build intervals from score break points
score_breaks <- sort(score_breaks)
score_min <- head(score_breaks, -1)
score_max <- tail(score_breaks, -1)
score_means <- (score_min + score_max) / 2
# build Omega constraint, that is the set of impossible combinations
# (including self interactions) due to the rules of chess
Omega_const <- local({
# One piece per square
diag_offset <- abs(.row(c(768, 768)) - .col(c(768, 768)))
Omega_const <- !diag(768) & ((diag_offset %% 64L) == 0L)
# One King per color
Omega_const <- Omega_const | kronecker(diag(1:12 %in% c(6, 12)), !diag(64), `&`)
# no pawns on rank 1 or rank 8
pawn_const <- tcrossprod(as.vector(`[<-`(matrix(0L, 8, 8), c(1, 8), , 1L)), rep(1L, 64))
pawn_const <- kronecker(`[<-`(matrix(0, 12, 12), c(1, 7), , 1), pawn_const)
which(Omega_const | (pawn_const | t(pawn_const)))
})
# Check if there is a save point (load from save)
load_point <- if (is.character(save_point)) {
sort(list.files(pattern = sprintf(save_point, ".*")), decreasing = TRUE)
} else {
character(0)
}
# It a load point is found, resume from save point, otherwise initialize
if (length(load_point)) {
load_point <- load_point[[1]]
cat(sprintf("Resuming from save point '%s'\n", load_point),
"(to restart delete/rename the save points)\n")
load(load_point)
# Fix `iter`, save after increment
iter <- iter - 1L
} else {
# draw initial sample to be passed to the normal GMLM estimator for initial `betas`
X <- Reduce(c, Map(data_gen, slice_size, score_min, score_max))
dim(X) <- c(8L, 8L, 12L, slice_size * length(score_means))
F <- fun_y(rep(score_means, each = slice_size))
# set object dimensions (`dimX` is constant, `dimF` depends on `fun_y` arg)
dimX <- c(8L, 8L, 12L)
dimF <- dim(F)[1:3]
# Initial values for `betas` are the tensor normal GMLM estimates
betas <- gmlm_tensor_normal(X, F)$betas
# and initial values for `Omegas`, based on the same first "big" sample
Omegas <- Map(function(mode) {
n <- prod(dim(X)[-mode])
prob2 <- mcrossprod(X, mode = mode) / n
prob2[prob2 == 0] <- 1 / n
prob2[prob2 == 1] <- (n - 1) / n
prob1 <- diag(prob2)
`prob1^2` <- outer(prob1, prob1)
`diag<-`(log(((1 - `prob1^2`) / `prob1^2`) * prob2 / (1 - prob2)), 0)
}, 1:3)
# Initial sample `(X, F)` no longer needed, remove
rm(X, F)
# Initialize gradients and aggregated mean squared gradients
grad2_betas <- Map(array, 0, Map(dim, betas))
grad2_Omegas <- Map(array, 0, Map(dim, Omegas))
# initialize optimization tracker for break condition
last_loss <- Inf
non_improving <- 0L
iter <- 0L
}
# main optimization loop
while ((iter <- iter + 1L) <= max_iter) {
# At beginning of every iteration, store current state in a save point.
# This allows to resume optimization from the last save point.
if (is.character(save_point)) {
suspendInterrupts(save(
dimX, dimF,
betas, Omegas,
grad2_betas, grad2_Omegas,
last_loss, non_improving, iter,
file = sprintf(save_point, sprintf("%06d", iter - 1L))))
}
# start timing for this iteration (this is precise enough)
start_time <- proc.time()[["elapsed"]]
# full Omega (with constraint elements set to zero) needed to conditional
# parameters of the Ising model to compute (approx) the second moment
Omega <- `[<-`(Reduce(kronecker, rev(Omegas)), Omega_const, 0)
# Gradient and negative log-likelihood approximation
loss <- 0 # neg. log-likelihood
grad_betas <- Map(matrix, 0, dimX, dimF) # grads for betas
R2 <- array(0, dim = c(dimX, dimX)) # residuals
# for every score slice
for (i in seq_along(score_means)) {
# function of `y` being the score slice mean (only 3D, same for all obs.)
F <- `dim<-`(fun_y(score_means[i]), dimF)
# compute parameters of (slice) conditional Ising model
params <- `diag<-`(Omega, as.vector(mlm(F, betas)))
# second moment of `X | Y = score_means[i]`
m2 <- ising_m2(params, use_MC = TRUE, nr_threads = nr_threads, nr_samples = mcmc_samples)
# draw random sample from current slice `vec(X) | Y in (score_min, score_max]`
# with columns being the vectorized observations `vec(X)`.
matX <- `dim<-`(data_gen(slice_size, score_min[i], score_max[i]), c(prod(dimX), slice_size))
# accumulate (approx) neg. log-likelihood
loss <- loss - (sum(matX * (params %*% matX)) + slice_size * attr(m2, "log_prob_0"))
# Slice residuals (second order `resid2` and actual residuals `resid1`)
resid2 <- tcrossprod(matX) - slice_size * m2
resid1 <- `dim<-`(diag(resid2), dimX)
# accumulate residuals
R2 <- R2 + as.vector(resid2)
# and the beta gradients
grad_betas <- Map(`+`, grad_betas, Map(function(mode) {
mcrossprod(resid1, mlm(slice_size * F, betas[-mode], (1:3)[-mode]), mode)
}, 1:3))
}
# finaly, finish gradient computation with gradients for `Omegas`
grad_Omegas <- Map(function(mode) {
grad <- mlm(kronperm(R2), Map(as.vector, Omegas[-mode]), (1:3)[-mode], transposed = TRUE)
`dim<-`(grad, dim(Omegas[[mode]]))
}, 1:3)
# Update tracker for break condition
non_improving <- max(0L, non_improving - 1L + 2L * (last_loss < loss))
loss_last <- loss
# check break condition
if (non_improving > patience) { break }
# accumulate root mean squared gradients
grad2_betas <- Map(function(g2, g) 0.9 * g2 + 0.1 * (g * g), grad2_betas, grad_betas)
grad2_Omegas <- Map(function(g2, g) 0.9 * g2 + 0.1 * (g * g), grad2_Omegas, grad_Omegas)
# Update Parameters
betas <- Map(function(beta, grad, m2) {
beta + (step_size / (sqrt(m2) + eps)) * grad
}, betas, grad_betas, grad2_betas)
Omegas <- Map(function(Omega, grad, m2) {
Omega + (step_size / (sqrt(m2) + eps)) * grad
}, Omegas, grad_Omegas, grad2_Omegas)
# Log progress
cat(sprintf("iter: %4d, time for iter: %d [s], loss: %f\n",
iter, round(proc.time()[["elapsed"]] - start_time), loss))
}
# Save a final (terminal) save point
if (is.character(save_point)) {
suspendInterrupts(save(
dimX, dimF,
betas, Omegas,
grad2_betas, grad2_Omegas,
last_loss, non_improving, iter,
file = sprintf(save_point, "final")))
}
structure(
list(betas = betas, Omegas = Omegas),
iter = iter, loss = loss
)
}

View File

@ -0,0 +1,144 @@
#include <iostream>
#include <iomanip>
#include <string>
#include <fstream>
#include <sstream>
#include "utils.h"
#include "Board.h"
#include "Move.h"
#include "search.h"
#include "uci.h"
static const std::string usage{"usage: pgn2fen [--scored] [<input>]"};
// Convert PGN (Portable Game Notation) input stream to single FENs
// streamed to stdout
void pgn2fen(std::istream& input, const bool only_scored) {
// Instantiate Boards, the start of every game as well as the current state
// of the Board while processing a PGN game
Board startpos, pos;
// Read input line by line
std::string line;
while (std::getline(input, line)) {
// Skip empty and metadata lines (every PGN game starts with "<nr>.")
if (line.empty() || line.front() == '[') {
continue;
}
// Reset position to the start position, every game starts here!
pos = startpos;
// Read game content (assuming one line is the entire game)
std::istringstream game(line);
std::string count, san, token, eval;
while (game >> count >> san >> token) {
// Consume/Parse PGN comments
if (only_scored) {
// consume the comment and search for an evaluation
bool has_score = false;
while (game >> token) {
// Search for evaluation token (position score _after_ the move)
if (token == "[%eval") {
game >> eval;
eval.pop_back(); // delete trailing ']'
has_score = true;
// Consume the remainder of the comment (ignore it)
std::getline(game, token, '}');
break;
} else if (token == "}") {
break;
}
}
// In case of not finding an evaluation, skip the game (_not_ an error)
if (!has_score) {
break;
}
} else {
// Consume the remainder of the comment (ignore it)
std::getline(game, token, '}');
}
// Perform move
bool parseError = false;
Move move = UCI::parseSAN(san, pos, parseError);
if (parseError) {
std::cerr << "[ Error ] Parsing '" << san << "' at position '"
<< pos.fen() << "' failed." << std::endl;
}
move = pos.isLegal(move); // validate legality and extend move info
if (move) {
pos.make(move);
} else {
std::cerr << "[ Error ] Encountered illegal move '" << san
<< " (" << move
<< ") ' at position '" << pos.fen() << "'." << std::endl;
break;
}
// Write positions
if (only_scored) {
// Ingore "check mate in" scores (not relevant for eval training)
// Do this after "make move" in situations where the check mate
// was overlooked, leading to new positions
if (eval.length() && eval[0] == '#') {
continue;
}
// Otherwise, classic eval score to be parsed in centipawns
std::cout << pos.fen() << "; " << eval << '\n';
} else {
// Write only the position FEN
std::cout << pos.fen() << '\n';
}
}
}
}
int main(int argn, char* argv[]) {
// Setup control variables
bool only_scored = false;
std::string file = "";
// Parse command arguments
switch (argn) {
case 1:
break;
case 2:
if (std::string("--scored") == argv[1]) {
only_scored = true;
} else {
file = argv[1];
}
break;
case 3:
if (std::string("--scored") != argv[1]) {
std::cout << usage << std::endl;
return 1;
}
only_scored = true;
file = argv[2];
break;
default:
std::cout << usage << std::endl;
return 1;
}
// Invoke converter ether with file input or stdin
if (file == "") {
pgn2fen(std::cin, only_scored);
} else {
// Open input file
std::ifstream input(file);
if (!input) {
std::cerr << "Error opening '" << file << "'" << std::endl;
return 1;
}
pgn2fen(input, only_scored);
}
return 0;
}

View File

@ -0,0 +1,31 @@
#!/bin/bash
# Data set name: Chess games from the Lichess Data Base for standard rated games
# in November 2023
data=lichess_db_standard_rated_2023-11
# Check if file exists and download iff not
if [ -f "${data}.fen" ]; then
echo "File '${data}.fen' already exists, assuming job already done."
echo "To rerun delete (rename) the files '${data}.pgn.zst' and/or '${data}.fen'"
else
# First, compile `png2fen`
make pgn2fen
# Download the PGN data base via `wegt` if not found.
# The flag `-q` suppresses `wget`s own output and `-O-` tells `wget` to
# stream the downloaded file to `stdout`.
# Otherwise, use the file on disk directly.
# Decompress the stream with `zstdcat` (no temporary files)
# The uncompressed PGN data is then piped into `pgn2fen` which converts
# the PGN data base into a list of FEN strings while filtering only
# positions with evaluation. The `--scored` parameter specifies to extract
# a position evaluation from the PGN and ONLY write positions with scores.
# That is, positions without a score are removed!
if [ -f "${data}.pgn.zst" ]; then
zstdcat ${data}.pgn.zst | ./pgn2fen --scored > ${data}.fen
else
wget -qO- https://database.lichess.org/standard/${data}.pgn.zst \
| zstdcat | ./pgn2fen --scored > ${data}.fen
fi
fi