Add Zobrist hashing

This commit is contained in:
stefiosif
2024-10-06 03:31:33 +03:00
parent 70ba30ba1a
commit 28409d203a
8 changed files with 340 additions and 20 deletions

View File

@@ -290,6 +290,10 @@ pub enum Color {
}
impl Color {
pub const fn idx(self) -> usize {
self as usize
}
pub const fn opponent(self) -> Self {
match self {
Self::White => Self::Black,

View File

@@ -27,11 +27,13 @@ pub fn from_fen(fen: &str) -> Result<Game, FenError> {
));
let mailbox = Mailbox::new_from_board(&board);
let hash = zobrist_keys().calculate_hash(&board);
Ok(Game {
board,
history: History::new(),
mailbox,
hash,
})
}
@@ -132,6 +134,7 @@ use std::collections::HashMap;
use super::history::History;
use super::mailbox::Mailbox;
use super::zobrist::zobrist_keys;
fn en_passant_square(square: &str) -> Result<Option<usize>, FenError> {
let mut sqr = square.chars();

View File

@@ -10,6 +10,7 @@ use super::{
mailbox::Mailbox,
square::Square,
state::Castle,
zobrist::{zobrist_keys, ZobristHash},
};
impl PartialEq for Game {
@@ -25,6 +26,7 @@ pub struct Game {
pub board: Board,
pub history: History,
pub mailbox: Mailbox,
pub hash: ZobristHash,
}
impl Game {
@@ -33,6 +35,7 @@ impl Game {
board: Board::new(),
history: History::new(),
mailbox: Mailbox::new_from_board(&Board::new()),
hash: zobrist_keys().calculate_hash(&Board::new()),
}
}
@@ -49,32 +52,43 @@ impl Game {
.push_move_parameters(MoveParameters::build(self, mv));
let board = &mut self.board;
let hash = &mut self.hash;
let mailbox = &mut self.mailbox;
let color = board.state.current_player();
let pawn_move = board.is_pawn_move(mv.src);
let mut en_passant_square = None;
let capture_square = match color {
Color::White => mv.dst - 8,
Color::Black => mv.dst + 8,
};
let old_castling_ability = board.state.castling_ability;
let piece_at_src = mailbox.find_piece_at(mv.src).expect("Expected set piece");
let piece_at_src = mailbox
.find_piece_at(mv.src)
.expect("Expected piece at: {mv.src}");
let piece_at_dst = mailbox.find_piece_at(mv.dst);
match &mv.move_type {
MoveType::Quiet => {
board.move_piece(mv.src, mv.dst, piece_at_src);
mailbox.set_piece_at(mv.dst, mailbox.find_piece_at(mv.src));
hash.update_quiet(mv.src, mv.dst, piece_at_src, color);
mailbox.set_piece_at(mv.dst, Some(piece_at_src));
}
MoveType::Capture => {
let piece_at_dst = piece_at_dst.expect("Expected piece at: {mv.dst}");
board.move_piece(mv.src, mv.dst, piece_at_src);
board.remove_opponent_piece(mv.dst, piece_at_dst.expect("Expected set piece"));
mailbox.set_piece_at(mv.dst, mailbox.find_piece_at(mv.src));
board.remove_opponent_piece(mv.dst, piece_at_dst);
hash.update_capture(mv.src, mv.dst, piece_at_src, piece_at_dst, color);
mailbox.set_piece_at(mv.dst, Some(piece_at_src));
}
MoveType::EnPassant => {
board.move_piece(mv.src, mv.dst, piece_at_src);
let piece_to_remove_sq = match color {
Color::White => mv.dst - 8,
Color::Black => mv.dst + 8,
};
board.remove_opponent_piece(piece_to_remove_sq, PieceType::Pawn);
mailbox.set_piece_at(mv.dst, mailbox.find_piece_at(mv.src));
mailbox.set_piece_at(piece_to_remove_sq, None);
let piece_at_capture = mailbox
.find_piece_at(capture_square)
.expect("Expected piece at: {capture_square}");
board.remove_opponent_piece(capture_square, PieceType::Pawn);
hash.update_en_passant(mv.src, mv.dst, piece_at_src, piece_at_capture, color);
mailbox.set_piece_at(mv.dst, Some(piece_at_src));
mailbox.set_piece_at(capture_square, None);
}
MoveType::DoublePush => {
board.move_piece(mv.src, mv.dst, piece_at_src);
@@ -82,17 +96,36 @@ impl Game {
Color::White => Some(mv.src + 8),
Color::Black => Some(mv.src.saturating_sub(8)),
};
mailbox.set_piece_at(mv.dst, mailbox.find_piece_at(mv.src));
hash.update_double_push(
mv.src,
mv.dst,
piece_at_src,
color,
board.state.en_passant_square,
);
mailbox.set_piece_at(mv.dst, Some(piece_at_src));
}
MoveType::Promotion(promote) => {
board.remove_own_piece(mv.src, piece_at_src);
board.promote_piece(mv.dst, promote);
hash.update_promotion(mv.src, mv.dst, piece_at_src, promote, color);
mailbox.set_piece_at(mv.dst, Some(promote.into_piece_type()));
}
MoveType::PromotionCapture(promote) => {
board.remove_own_piece(mv.src, piece_at_src);
board.remove_opponent_piece(mv.dst, piece_at_dst.expect("Expected set piece"));
board.remove_opponent_piece(
mv.dst,
piece_at_dst.expect("Expected piece at dst: {mv.dst}"),
);
board.promote_piece(mv.dst, promote);
hash.update_promotion_capture(
mv.src,
mv.dst,
piece_at_src,
piece_at_dst.unwrap(),
promote,
color,
);
mailbox.set_piece_at(mv.dst, Some(promote.into_piece_type()));
}
MoveType::Castle => {
@@ -104,7 +137,8 @@ impl Game {
};
board.move_piece(rook_src, rook_dst, PieceType::Rook);
board.state.set_castling_ability(color, Castle::None);
mailbox.set_piece_at(mv.dst, mailbox.find_piece_at(mv.src));
hash.update_castle(mv.src, mv.dst, piece_at_src, rook_src, rook_dst, color);
mailbox.set_piece_at(mv.dst, Some(piece_at_src));
mailbox.set_piece_at(rook_src, None);
mailbox.set_piece_at(rook_dst, Some(PieceType::Rook));
}
@@ -114,6 +148,14 @@ impl Game {
board
.state
.update_game_state(mv, color, pawn_move, en_passant_square);
hash.update_side_to_move_key();
if let Some(old_en_passant) = board.state.en_passant_square {
hash.update_en_passant_keys(old_en_passant);
}
hash.update_castling_ability_keys(old_castling_ability, board.state.castling_ability);
}
pub fn unmake_move(&mut self) {
@@ -127,6 +169,10 @@ impl Game {
board.state.revert_full_move(color_before_move);
board.state.en_passant_square = move_parameters.en_passant_square;
if let Some(hash) = move_parameters.zobrist_hash {
self.hash = hash;
}
if let Some(new_castling_ability) = move_parameters.castling_ability {
board.state.castling_ability = new_castling_ability;
}

View File

@@ -5,6 +5,7 @@ use super::{
game::Game,
mailbox::Mailbox,
state::{Castle, State},
zobrist::ZobristHash,
};
#[derive(Debug, Clone, PartialEq, Eq)]
@@ -37,6 +38,7 @@ pub struct MoveParameters {
pub castling_ability: Option<[Castle; 2]>,
pub en_passant_square: Option<usize>,
pub halfmove_clock: Option<u8>,
pub zobrist_hash: Option<ZobristHash>,
}
impl MoveParameters {
@@ -49,16 +51,18 @@ impl MoveParameters {
castling_ability: None,
en_passant_square: None,
halfmove_clock: None,
zobrist_hash: None,
}
}
pub fn build(game: &Game, mv: &Move) -> Self {
let mut move_parameters = Self::new();
move_parameters.add_move(*mv);
move_parameters.add_irreversible_parameters(game.board.state);
move_parameters.add_irreversible_parameters(&game.board.state);
move_parameters.add_moved_piece(&game.mailbox, mv);
move_parameters.add_captured_piece(&game.mailbox, mv);
move_parameters.add_promoted_piece(mv);
move_parameters.add_zobrist_hash(&game.hash);
move_parameters
}
@@ -67,6 +71,12 @@ impl MoveParameters {
self.mv = Some(mv);
}
fn add_irreversible_parameters(&mut self, state: &State) {
self.castling_ability = Some(state.castling_ability);
self.en_passant_square = state.en_passant_square;
self.halfmove_clock = Some(state.halfmove_clock);
}
fn add_moved_piece(&mut self, mailbox: &Mailbox, mv: &Move) {
self.moved_piece = mailbox.find_piece_at(mv.src);
}
@@ -83,10 +93,8 @@ impl MoveParameters {
}
}
fn add_irreversible_parameters(&mut self, state: State) {
self.castling_ability = Some(state.castling_ability);
self.en_passant_square = state.en_passant_square;
self.halfmove_clock = Some(state.halfmove_clock);
fn add_zobrist_hash(&mut self, zobrist_hash: &ZobristHash) {
self.zobrist_hash = Some(zobrist_hash.to_owned())
}
}

View File

@@ -6,3 +6,4 @@ pub mod history;
pub mod mailbox;
pub mod square;
pub mod state;
pub mod zobrist;

View File

@@ -78,6 +78,10 @@ pub const fn coords_to_square(rank: usize, file: usize) -> usize {
rank * 8 + file
}
pub const fn square_to_file(square: usize) -> usize {
square % 8
}
pub fn square_to_algebraic(square: usize) -> String {
let file = (square % 8) as u8;
let rank = (square / 8) as u8;

View File

@@ -3,7 +3,7 @@ use crate::{
movegen::r#move::{Move, MoveType},
};
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct State {
side_to_move: Color,
pub castling_ability: [Castle; 2],
@@ -158,3 +158,9 @@ pub enum Castle {
Both,
None,
}
impl Castle {
pub const fn idx(self) -> usize {
self as usize
}
}

248
src/board/zobrist.rs Normal file
View File

@@ -0,0 +1,248 @@
use crate::movegen::r#move::Promote;
use super::{
bitboard::lsb,
board::{self, Board, Color, PieceType},
square::{self, square_to_file, Square},
state::Castle,
};
use rand::{rngs::SmallRng, RngCore, SeedableRng};
use std::sync::LazyLock;
static ZOBRIST_KEYS: LazyLock<ZobristKeys> = LazyLock::new(ZobristKeys::new);
pub fn zobrist_keys() -> &'static ZobristKeys {
&ZOBRIST_KEYS
}
pub struct ZobristKeys {
piece_square_color: [[[u64; 2]; 6]; 64],
en_passant: [u64; 8],
castling_ability: [[u64; 2]; 4],
side_to_move: u64,
}
impl ZobristKeys {
pub fn new() -> Self {
let mut keys = ZobristKeys::default();
let mut state = SmallRng::seed_from_u64(1804289383);
for square in Square::A1..=Square::H8 {
for piece_idx in 0..6 {
keys.piece_square_color[square][piece_idx][0] = state.next_u64();
keys.piece_square_color[square][piece_idx][1] = state.next_u64();
}
}
for file in 0..8 {
keys.en_passant[file] = state.next_u64();
}
for rights in 0..4 {
keys.castling_ability[rights][0] = state.next_u64();
keys.castling_ability[rights][1] = state.next_u64();
}
keys.side_to_move = state.next_u64();
Self {
piece_square_color: keys.piece_square_color,
en_passant: keys.en_passant,
castling_ability: keys.castling_ability,
side_to_move: keys.side_to_move,
}
}
pub fn calculate_hash(&self, board: &Board) -> ZobristHash {
let mut hash = 0;
let white_pieces = &board.white_pieces;
let black_pieces = &board.black_pieces;
for (idx, piece) in white_pieces.iter().enumerate() {
let mut bb = piece.bitboard;
while bb != 0 {
let square = lsb(bb);
hash ^= self.piece_square_color[square][idx][0];
bb &= bb - 1;
}
}
for (idx, piece) in black_pieces.iter().enumerate() {
let mut bb = piece.bitboard;
while bb != 0 {
let square = lsb(bb);
hash ^= self.piece_square_color[square][idx][1];
bb &= bb - 1;
}
}
if board.state.current_player().eq(&board::Color::Black) {
hash ^= self.side_to_move
}
hash ^= self.castling_ability[board.state.castling_ability[0].idx()][0];
hash ^= self.castling_ability[board.state.castling_ability[1].idx()][1];
if let Some(ep) = board.state.en_passant_square {
hash ^= self.en_passant[square_to_file(ep)];
}
ZobristHash::new(hash)
}
}
impl Default for ZobristKeys {
fn default() -> Self {
Self {
piece_square_color: [[[0; 2]; 6]; 64],
en_passant: [0; 8],
castling_ability: [[0; 2]; 4],
side_to_move: 0,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub struct ZobristHash {
hash: u64,
}
impl ZobristHash {
pub const fn new(hash: u64) -> Self {
Self { hash }
}
pub fn update_side_to_move_key(&mut self) {
let keys = zobrist_keys();
self.hash ^= keys.side_to_move
}
pub fn update_en_passant_keys(&mut self, old_en_passant: usize) {
let keys = zobrist_keys();
self.hash ^= keys.en_passant[square::square_to_file(old_en_passant)]
}
pub fn update_castling_ability_keys(
&mut self,
old_castling_ability: [Castle; 2],
new_castling_ability: [Castle; 2],
) {
let keys = zobrist_keys();
self.hash ^= keys.castling_ability[old_castling_ability[0].idx()][0];
self.hash ^= keys.castling_ability[old_castling_ability[1].idx()][1];
self.hash ^= keys.castling_ability[new_castling_ability[0].idx()][0];
self.hash ^= keys.castling_ability[new_castling_ability[1].idx()][1]
}
pub fn update_quiet(&mut self, src: usize, dst: usize, piece_at_src: PieceType, color: Color) {
let keys = zobrist_keys();
self.hash ^= keys.piece_square_color[src][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_src.idx()][color.idx()];
}
pub fn update_capture(
&mut self,
src: usize,
dst: usize,
piece_at_src: PieceType,
piece_at_dst: PieceType,
color: Color,
) {
let keys = zobrist_keys();
self.hash ^= keys.piece_square_color[src][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_dst.idx()][color.opponent().idx()]
}
pub fn update_en_passant(
&mut self,
src: usize,
dst: usize,
piece_at_src: PieceType,
piece_at_capture: PieceType,
color: Color,
) {
let keys = zobrist_keys();
self.hash ^= keys.piece_square_color[src][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_capture.idx()][color.opponent().idx()]
}
pub fn update_double_push(
&mut self,
src: usize,
dst: usize,
piece_at_src: PieceType,
color: Color,
new_en_passant_target: Option<usize>,
) {
let keys = zobrist_keys();
self.hash ^= keys.piece_square_color[src][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_src.idx()][color.idx()];
if let Some(new_en_passant) = new_en_passant_target {
self.hash ^= keys.en_passant[square::square_to_file(new_en_passant)];
}
}
pub fn update_promotion(
&mut self,
src: usize,
dst: usize,
piece_at_src: PieceType,
promote: &Promote,
color: Color,
) {
let keys = zobrist_keys();
self.hash ^= keys.piece_square_color[src][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][promote.into_piece_type().idx()][color.idx()];
}
pub fn update_promotion_capture(
&mut self,
src: usize,
dst: usize,
piece_at_src: PieceType,
piece_at_dst: PieceType,
promote: &Promote,
color: Color,
) {
let keys = zobrist_keys();
self.hash ^= keys.piece_square_color[src][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_dst.idx()][color.opponent().idx()];
self.hash ^= keys.piece_square_color[dst][promote.into_piece_type().idx()][color.idx()];
}
pub fn update_castle(
&mut self,
src: usize,
dst: usize,
piece_at_src: PieceType,
rook_src: usize,
rook_dst: usize,
color: Color,
) {
let keys = zobrist_keys();
self.hash ^= keys.piece_square_color[src][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[rook_src][PieceType::Rook.idx()][color.idx()];
self.hash ^= keys.piece_square_color[rook_dst][PieceType::Rook.idx()][color.idx()];
}
}
#[cfg(test)]
mod tests {
#[test]
fn test_calculate_hash() -> Result<(), String> {
Ok(())
}
#[test]
fn test_update_hash() -> Result<(), String> {
//TODO: how to test
// test if an incremental position is the same as if it would be calculated from scratch
Ok(())
}
}