Files
zeal/src/board/zobrist.rs
2024-10-06 03:31:33 +03:00

249 lines
7.5 KiB
Rust

use crate::movegen::r#move::Promote;
use super::{
bitboard::lsb,
board::{self, Board, Color, PieceType},
square::{self, square_to_file, Square},
state::Castle,
};
use rand::{rngs::SmallRng, RngCore, SeedableRng};
use std::sync::LazyLock;
/// Shared Zobrist key table, built by [`ZobristKeys::new`] the first time it is accessed.
static ZOBRIST_KEYS: LazyLock<ZobristKeys> = LazyLock::new(ZobristKeys::new);
pub fn zobrist_keys() -> &'static ZobristKeys {
&ZOBRIST_KEYS
}
/// Tables of 64-bit keys that are XOR-ed together to form a position hash.
pub struct ZobristKeys {
/// One key per occupied-square feature, indexed `[square][piece index][color index]`
/// (color index 0 is used for white pieces, 1 for black pieces).
piece_square_color: [[[u64; 2]; 6]; 64],
/// One key per file, applied when an en passant square is set.
en_passant: [u64; 8],
/// Indexed `[Castle::idx()][slot]`, where `slot` 0/1 pairs with entry 0/1 of the
/// state's `castling_ability` array (presumably white/black; confirm against `State`).
castling_ability: [[u64; 2]; 4],
/// Applied when Black is the current player.
side_to_move: u64,
}
impl ZobristKeys {
pub fn new() -> Self {
let mut keys = ZobristKeys::default();
let mut state = SmallRng::seed_from_u64(1804289383);
for square in Square::A1..=Square::H8 {
for piece_idx in 0..6 {
keys.piece_square_color[square][piece_idx][0] = state.next_u64();
keys.piece_square_color[square][piece_idx][1] = state.next_u64();
}
}
for file in 0..8 {
keys.en_passant[file] = state.next_u64();
}
for rights in 0..4 {
keys.castling_ability[rights][0] = state.next_u64();
keys.castling_ability[rights][1] = state.next_u64();
}
keys.side_to_move = state.next_u64();
Self {
piece_square_color: keys.piece_square_color,
en_passant: keys.en_passant,
castling_ability: keys.castling_ability,
side_to_move: keys.side_to_move,
}
}
pub fn calculate_hash(&self, board: &Board) -> ZobristHash {
let mut hash = 0;
let white_pieces = &board.white_pieces;
let black_pieces = &board.black_pieces;
for (idx, piece) in white_pieces.iter().enumerate() {
let mut bb = piece.bitboard;
while bb != 0 {
let square = lsb(bb);
hash ^= self.piece_square_color[square][idx][0];
bb &= bb - 1;
}
}
for (idx, piece) in black_pieces.iter().enumerate() {
let mut bb = piece.bitboard;
while bb != 0 {
let square = lsb(bb);
hash ^= self.piece_square_color[square][idx][1];
bb &= bb - 1;
}
}
if board.state.current_player().eq(&board::Color::Black) {
hash ^= self.side_to_move
}
hash ^= self.castling_ability[board.state.castling_ability[0].idx()][0];
hash ^= self.castling_ability[board.state.castling_ability[1].idx()][1];
if let Some(ep) = board.state.en_passant_square {
hash ^= self.en_passant[square_to_file(ep)];
}
ZobristHash::new(hash)
}
}
impl Default for ZobristKeys {
fn default() -> Self {
Self {
piece_square_color: [[[0; 2]; 6]; 64],
en_passant: [0; 8],
castling_ability: [[0; 2]; 4],
side_to_move: 0,
}
}
}
/// A position hash that is updated incrementally by XOR-ing keys from the shared
/// [`ZobristKeys`] table.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub struct ZobristHash {
/// Raw 64-bit hash value.
hash: u64,
}
impl ZobristHash {
pub const fn new(hash: u64) -> Self {
Self { hash }
}
pub fn update_side_to_move_key(&mut self) {
let keys = zobrist_keys();
self.hash ^= keys.side_to_move
}
pub fn update_en_passant_keys(&mut self, old_en_passant: usize) {
let keys = zobrist_keys();
self.hash ^= keys.en_passant[square::square_to_file(old_en_passant)]
}
pub fn update_castling_ability_keys(
&mut self,
old_castling_ability: [Castle; 2],
new_castling_ability: [Castle; 2],
) {
let keys = zobrist_keys();
self.hash ^= keys.castling_ability[old_castling_ability[0].idx()][0];
self.hash ^= keys.castling_ability[old_castling_ability[1].idx()][1];
self.hash ^= keys.castling_ability[new_castling_ability[0].idx()][0];
self.hash ^= keys.castling_ability[new_castling_ability[1].idx()][1]
}
pub fn update_quiet(&mut self, src: usize, dst: usize, piece_at_src: PieceType, color: Color) {
let keys = zobrist_keys();
self.hash ^= keys.piece_square_color[src][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_src.idx()][color.idx()];
}
pub fn update_capture(
&mut self,
src: usize,
dst: usize,
piece_at_src: PieceType,
piece_at_dst: PieceType,
color: Color,
) {
let keys = zobrist_keys();
self.hash ^= keys.piece_square_color[src][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_dst.idx()][color.opponent().idx()]
}
pub fn update_en_passant(
&mut self,
src: usize,
dst: usize,
piece_at_src: PieceType,
piece_at_capture: PieceType,
color: Color,
) {
let keys = zobrist_keys();
self.hash ^= keys.piece_square_color[src][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_capture.idx()][color.opponent().idx()]
}
pub fn update_double_push(
&mut self,
src: usize,
dst: usize,
piece_at_src: PieceType,
color: Color,
new_en_passant_target: Option<usize>,
) {
let keys = zobrist_keys();
self.hash ^= keys.piece_square_color[src][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_src.idx()][color.idx()];
if let Some(new_en_passant) = new_en_passant_target {
self.hash ^= keys.en_passant[square::square_to_file(new_en_passant)];
}
}
pub fn update_promotion(
&mut self,
src: usize,
dst: usize,
piece_at_src: PieceType,
promote: &Promote,
color: Color,
) {
let keys = zobrist_keys();
self.hash ^= keys.piece_square_color[src][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][promote.into_piece_type().idx()][color.idx()];
}
pub fn update_promotion_capture(
&mut self,
src: usize,
dst: usize,
piece_at_src: PieceType,
piece_at_dst: PieceType,
promote: &Promote,
color: Color,
) {
let keys = zobrist_keys();
self.hash ^= keys.piece_square_color[src][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_dst.idx()][color.opponent().idx()];
self.hash ^= keys.piece_square_color[dst][promote.into_piece_type().idx()][color.idx()];
}
pub fn update_castle(
&mut self,
src: usize,
dst: usize,
piece_at_src: PieceType,
rook_src: usize,
rook_dst: usize,
color: Color,
) {
let keys = zobrist_keys();
self.hash ^= keys.piece_square_color[src][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[dst][piece_at_src.idx()][color.idx()];
self.hash ^= keys.piece_square_color[rook_src][PieceType::Rook.idx()][color.idx()];
self.hash ^= keys.piece_square_color[rook_dst][PieceType::Rook.idx()][color.idx()];
}
}
#[cfg(test)]
mod tests {
    use super::{zobrist_keys, Color, PieceType, ZobristHash, ZobristKeys};

    /// The generator is seeded with a constant, so two tables must be identical.
    #[test]
    fn test_keys_are_deterministic() {
        let (first, second) = (ZobristKeys::new(), ZobristKeys::new());
        assert_eq!(first.piece_square_color, second.piece_square_color);
        assert_eq!(first.en_passant, second.en_passant);
        assert_eq!(first.castling_ability, second.castling_ability);
        assert_eq!(first.side_to_move, second.side_to_move);
    }

    #[test]
    fn test_calculate_hash() -> Result<(), String> {
        // TODO: needs a `Board` fixture; compare `calculate_hash` on known positions.
        Ok(())
    }

    /// Every incremental update is an XOR, so applying it twice must restore the hash.
    #[test]
    fn test_update_hash() -> Result<(), String> {
        // TODO: also test that an incrementally updated position hashes the same as
        // the position calculated from scratch (needs a `Board` fixture).
        let start = ZobristHash::new(0x1234_5678_9abc_def0);
        let mut hash = start;

        hash.update_side_to_move_key();
        if hash == start {
            return Err("side-to-move toggle left the hash unchanged".to_string());
        }
        hash.update_side_to_move_key();
        if hash != start {
            return Err("side-to-move toggle is not self-inverse".to_string());
        }

        hash.update_quiet(12, 28, PieceType::Rook, Color::Black);
        if hash == start {
            return Err("quiet move left the hash unchanged".to_string());
        }
        hash.update_quiet(28, 12, PieceType::Rook, Color::Black);
        if hash != start {
            return Err("moving a piece back did not restore the hash".to_string());
        }

        Ok(())
    }

    /// A capture must equal the quiet move plus removal of the victim on `dst`.
    #[test]
    fn test_capture_matches_quiet_plus_removed_victim() {
        let (src, dst) = (12, 28);
        let victim_side = Color::Black.opponent().idx();
        let victim_key = zobrist_keys().piece_square_color[dst][PieceType::Rook.idx()][victim_side];

        let mut captured = ZobristHash::new(0);
        captured.update_capture(src, dst, PieceType::Rook, PieceType::Rook, Color::Black);

        let mut quiet = ZobristHash::new(0);
        quiet.update_quiet(src, dst, PieceType::Rook, Color::Black);

        assert_eq!(captured.hash, quiet.hash ^ victim_key);
    }
}