Add perft function and fix move generator bugs using perftree

This commit is contained in:
2024-07-01 23:29:56 +03:00
parent 941c14c199
commit 2a1a3224cb
7 changed files with 348 additions and 40 deletions

3
perft.sh Executable file
View File

@@ -0,0 +1,3 @@
#!/bin/bash
./target/release/ippos perft

View File

@@ -173,8 +173,8 @@ pub fn bishop_attacks_on_the_fly(bitboard: Bitboard, blocker: Bitboard) -> Bitbo
}
}
(rank, file) = (target_rank.saturating_sub(1), target_file + 1);
for (rank, file) in (0..=rank).rev().zip(file..=7) {
file = target_file + 1;
for (rank, file) in (0..target_rank).rev().zip(file..=7) {
let attack_square = 1_u64 << (rank * 8 + file);
attacks |= attack_square;
if attack_square & blocker != 0 {
@@ -182,8 +182,8 @@ pub fn bishop_attacks_on_the_fly(bitboard: Bitboard, blocker: Bitboard) -> Bitbo
}
}
(rank, file) = (target_rank + 1, target_file.saturating_sub(1));
for (rank, file) in (rank..=7).zip((0..=file).rev()) {
rank = target_rank + 1;
for (rank, file) in (rank..=7).zip((0..target_file).rev()) {
let attack_square = 1_u64 << (rank * 8 + file);
attacks |= attack_square;
if attack_square & blocker != 0 {
@@ -191,8 +191,7 @@ pub fn bishop_attacks_on_the_fly(bitboard: Bitboard, blocker: Bitboard) -> Bitbo
}
}
(rank, file) = (target_rank.saturating_sub(1), target_file.saturating_sub(1));
for (rank, file) in (0..=rank).rev().zip((0..=file).rev()) {
for (rank, file) in (0..target_rank).rev().zip((0..target_file).rev()) {
let attack_square = 1_u64 << (rank * 8 + file);
attacks |= attack_square;
if attack_square & blocker != 0 {
@@ -233,7 +232,6 @@ pub fn rook_attacks_on_the_fly(bitboard: Bitboard, blocker: Bitboard) -> Bitboar
}
(rank, file) = (target_rank.saturating_sub(1), target_file.saturating_sub(1));
for rank in (0..=rank).rev() {
let attack_square = 1_u64 << (rank * 8 + target_file);
attacks |= attack_square;
@@ -257,9 +255,9 @@ pub fn set_occupancy(index: u64, bits_in_mask: u8, mut attack_mask: Bitboard) ->
for count in 0..bits_in_mask {
let square = attack_mask.trailing_zeros() as u8;
attack_mask &= !1_u64 << (square);
if index & (1_u64 << (count)) != 0 {
occupancy |= 1_u64 << (square);
attack_mask &= !(1_u64 << square);
if index & (1_u64 << count) != 0 {
occupancy |= 1_u64 << square;
}
}
@@ -332,11 +330,10 @@ pub fn init_bishop_attacks() {
for sq in 0..64 {
bishop_masks[sq] = mask_bishop_attacks(1_u64 << (sq));
let attack_mask = bishop_masks[sq];
let relevant_bits_count = attack_mask.count_ones() as u8;
let occupancy_indices = 1_u64 << (relevant_bits_count);
let occupancy_indices = 1_u64 << BISHOP_RELEVANT_BITS[sq];
for idx in 0..occupancy_indices {
let occupancy = set_occupancy(idx, relevant_bits_count, attack_mask);
let occupancy = set_occupancy(idx, BISHOP_RELEVANT_BITS[sq], attack_mask);
unsafe {
let magic_index =
occupancy.wrapping_mul(BISHOP_MAGIC[sq]) >> (64 - BISHOP_RELEVANT_BITS[sq]);
@@ -470,6 +467,11 @@ mod tests {
let attacks = bishop_attacks_on_the_fly(bishop_d4.bitboard, blocker_c5);
assert_eq!(attacks, 0x8040201400142241);
let bishop_a1 = Piece::new(0x0, Kind::Bishop, Color::White);
let blocker_none = 0x0_u64;
let attacks = bishop_attacks_on_the_fly(bishop_a1.bitboard, blocker_none);
assert_eq!(attacks, 0x8040201008040200);
Ok(())
}

View File

@@ -115,7 +115,6 @@ impl Board {
moves.extend(self.pseudo_moves(color, Kind::King));
moves
}
pub fn pseudo_moves(&self, color: Color, kind: Kind) -> Vec<Move> {
let all_occupancies = self.get_all_occupancies();
let (pieces, enemy_occupancies, own_occupancies) = match color {
@@ -147,20 +146,18 @@ impl Board {
}
}
pub fn make_move_and_reset(&mut self, mv: Move, color: Color) -> Vec<Board> {
let mut board_variants = vec![];
pub fn make_move_and_reset(&mut self, mv: Move, color: Color) -> bool {
let original_board = self.clone();
if self.make_move(mv, color) {
board_variants.push(self.clone());
}
let result = self.make_move(mv, color);
*self = original_board;
board_variants
result
}
pub fn make_move(&mut self, mv: Move, color: Color) -> bool {
self.update_board_state(&mv, color);
self.state.update_castling_state(mv.source as u8, color);
self.state.next_turn();
self.state.change_side();
let pieces = match color {
Color::White => &self.white_pieces,
@@ -180,10 +177,12 @@ impl Board {
match &mv.move_type {
MoveType::Quiet => {
Board::move_piece(mv.source as u8, mv.target as u8, own_pieces);
self.state.set_en_passant_target_square(None);
}
MoveType::Capture => {
Board::move_piece(mv.source as u8, mv.target as u8, own_pieces);
Board::remove_piece(mv.target as u8, opponent_pieces);
self.state.set_en_passant_target_square(None);
}
MoveType::EnPassant => {
Board::move_piece(mv.source as u8, mv.target as u8, own_pieces);
@@ -198,15 +197,18 @@ impl Board {
MoveType::Promotion(promote) => {
Board::remove_piece(mv.source as u8, own_pieces);
Board::promote_piece(mv.target as u8, own_pieces, *promote);
self.state.set_en_passant_target_square(None);
}
MoveType::PromotionCapture(promote) => {
Board::remove_piece(mv.source as u8, own_pieces);
Board::remove_piece(mv.target as u8, opponent_pieces);
Board::promote_piece(mv.target as u8, own_pieces, *promote);
self.state.set_en_passant_target_square(None);
}
MoveType::Castle => {
Board::move_piece(mv.source as u8, mv.target as u8, own_pieces);
Board::move_rook_castle(mv.target as u8, own_pieces);
self.state.set_en_passant_target_square(None);
}
}
}
@@ -465,4 +467,25 @@ mod tests {
Ok(())
}
#[test]
fn test_perftree_snapshots() -> Result<(), String> {
init_attacks();
let mut game = from_fen("rnbqkbnr/ppppp1pp/5p2/7Q/8/4P3/PPPP1PPP/RNB1KBNR w KQkq - 0 1")?;
for mv in game.board.pseudo_moves_all(Color::Black) {
if game.board.make_move_and_reset(mv, Color::Black) {
println!("{:?}", mv)
}
}
let mut game = from_fen("rnbqkbnr/1ppppppp/p7/8/Q7/2P5/PP1PPPPP/RNB1KBNR w KQkq - 0 1")?;
for mv in game.board.pseudo_moves_all(Color::Black) {
if game.board.make_move_and_reset(mv, Color::Black) {
println!("{:?}", mv)
}
}
Ok(())
}
}

View File

@@ -100,13 +100,17 @@ impl State {
}
}
pub fn next_turn(&mut self) {
pub fn change_side(&mut self) {
self.side_to_move = match self.side_to_move {
Color::White => Color::Black,
Color::Black => Color::White,
}
}
pub fn next_turn(&self) -> Color {
self.side_to_move
}
pub fn set_en_passant_target_square(&mut self, sq: Option<u8>) {
self.en_passant_target_square = sq;
}

View File

@@ -5,10 +5,13 @@ pub mod game;
pub mod magic;
pub mod r#move;
pub mod movegen;
use game::Game;
pub mod perft;
fn main() {
let game = Game::new();
game.run();
attack::init_attacks();
let args: Vec<String> = std::env::args().collect();
if args.len() == 2 && args[1] == "perft" {
perft::perftree_script()
}
}

View File

@@ -151,11 +151,13 @@ fn white_pawn_capture_moves(
if let Some(en_passant_square) = en_passant_square {
let attacked_from = get_pawn_attacks(en_passant_square as usize, Color::Black);
let result = attacked_from & pawns;
moves.push(Move::new_with_type(
result.trailing_zeros(),
en_passant_square as u32,
MoveType::EnPassant,
));
if result != 0 {
moves.push(Move::new_with_type(
result.trailing_zeros(),
en_passant_square as u32,
MoveType::EnPassant,
));
}
};
moves
}
@@ -204,11 +206,13 @@ fn black_pawn_capture_moves(
if let Some(en_passant_square) = en_passant_square {
let attacked_from = get_pawn_attacks(en_passant_square as usize, Color::White);
let result = attacked_from & pawns;
moves.push(Move::new_with_type(
result.trailing_zeros(),
en_passant_square as u32,
MoveType::EnPassant,
));
if result != 0 {
moves.push(Move::new_with_type(
result.trailing_zeros(),
en_passant_square as u32,
MoveType::EnPassant,
));
}
};
moves
}
@@ -351,12 +355,12 @@ pub fn king_pseudo_moves(
fn king_castling_moves(board: &Board, color: Color, all_occupancies: Bitboard) -> Vec<Move> {
let mut moves = vec![];
let (king_from, king_to_short, king_to_long, path_short, path_long) = match color {
Color::White => (4, 2, 6, 0x60_u64, 0xe_u64),
Color::Black => (60, 58, 62, 0x6000000000000000_u64, 0xe00000000000000_u64),
Color::White => (4, 6, 2, 0x60_u64, 0xe_u64),
Color::Black => (60, 62, 58, 0x6000000000000000_u64, 0xe00000000000000_u64),
};
let mut add_move_if_empty_path = |path_mask, king_to| {
if all_occupancies & path_mask != 0 {
if all_occupancies & path_mask == 0 {
moves.push(Move::new_with_type(king_from, king_to, MoveType::Castle))
}
};
@@ -491,7 +495,7 @@ mod tests {
];
let mut actual = new_game.board.pseudo_moves(Color::White, Kind::Bishop);
actual.sort();
assert_eq!(expected, actual);
assert_eq!(expected, actual);
Ok(())
}
@@ -590,4 +594,16 @@ mod tests {
Ok(())
}
#[test]
fn test_random() -> Result<(), String> {
init_attacks();
let game = from_fen("rnbqkbnr/ppppp1pp/5p2/7Q/8/4P3/PPPP1PPP/RNB1KBNR w KQkq - 0 1")?;
for mv in game.board.pseudo_moves_all(Color::Black) {
println!("{:?}", mv);
}
Ok(())
}
}

257
src/perft.rs Normal file
View File

@@ -0,0 +1,257 @@
use crate::game::Game;
pub fn perft_driver(game: &mut Game, nodes: &mut u64, depth: u8) {
if depth == 0 {
*nodes += 1;
return;
}
let pseudo_moves =
game.board.pseudo_moves_all(game.board.state.next_turn());
for mv in pseudo_moves {
let original_board = game.board.clone();
if !game.board.make_move(mv, game.board.state.next_turn()) {
game.board = original_board;
continue;
}
// print_perftree(mv.source, mv.target, depth, nodes);
perft_driver(game, nodes, depth - 1);
game.board = original_board;
}
#[allow(dead_code)]
fn print_perftree(source: u32, target: u32, depth: u8, nodes: &mut u64) {
println!(
"{}{} - depth: {}",
square_to_notation(source as u8),
square_to_notation(target as u8),
depth
);
if depth == MAX_DEPTH {
println!(
"{}{} {}",
square_to_notation(source as u8),
square_to_notation(target as u8),
nodes
);
}
}
}
const MAX_DEPTH: u8 = 4;
pub fn perftree_script() {
let fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
let mut game = crate::fen::from_fen(fen).unwrap();
let (mut nodes, depth): (u64, u8) = (0, MAX_DEPTH);
perft_driver(&mut game, &mut nodes, depth);
println!();
println!("{}", nodes)
}
use std::collections::HashMap;
#[rustfmt::skip]
pub fn square_to_notation(square: u8) -> &'static str {
let mut map: HashMap<u8, &str> = HashMap::new();
map.insert(0, "a1"); map.insert(1, "b1"); map.insert(2, "c1"); map.insert(3, "d1");
map.insert(4, "e1"); map.insert(5, "f1"); map.insert(6, "g1"); map.insert(7, "h1");
map.insert(8, "a2"); map.insert(9, "b2"); map.insert(10, "c2"); map.insert(11, "d2");
map.insert(12, "e2"); map.insert(13, "f2"); map.insert(14, "g2"); map.insert(15, "h2");
map.insert(16, "a3"); map.insert(17, "b3"); map.insert(18, "c3"); map.insert(19, "d3");
map.insert(20, "e3"); map.insert(21, "f3"); map.insert(22, "g3"); map.insert(23, "h3");
map.insert(24, "a4"); map.insert(25, "b4"); map.insert(26, "c4"); map.insert(27, "d4");
map.insert(28, "e4"); map.insert(29, "f4"); map.insert(30, "g4"); map.insert(31, "h4");
map.insert(32, "a5"); map.insert(33, "b5"); map.insert(34, "c5"); map.insert(35, "d5");
map.insert(36, "e5"); map.insert(37, "f5"); map.insert(38, "g5"); map.insert(39, "h5");
map.insert(40, "a6"); map.insert(41, "b6"); map.insert(42, "c6"); map.insert(43, "d6");
map.insert(44, "e6"); map.insert(45, "f6"); map.insert(46, "g6"); map.insert(47, "h6");
map.insert(48, "a7"); map.insert(49, "b7"); map.insert(50, "c7"); map.insert(51, "d7");
map.insert(52, "e7"); map.insert(53, "f7"); map.insert(54, "g7"); map.insert(55, "h7");
map.insert(56, "a8"); map.insert(57, "b8"); map.insert(58, "c8"); map.insert(59, "d8");
map.insert(60, "e8"); map.insert(61, "f8"); map.insert(62, "g8"); map.insert(63, "h8");
map.get(&square).unwrap()
}
#[cfg(test)]
mod tests {
use crate::{attack::init_attacks, fen::from_fen};
use super::perft_driver;
// Examples from https://www.chessprogramming.org/Perft_Results
const FEN_PERFT: [&str; 6] = [
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
"r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 1",
"8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 1",
"r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1",
"rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8",
"r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10",
];
#[test]
fn test_perft_driver_1() -> Result<(), String> {
init_attacks();
let mut game = from_fen(FEN_PERFT[0])?;
let (mut nodes, depth): (u64, u8) = (0, 1);
perft_driver(&mut game, &mut nodes, depth);
assert_eq!(nodes , 20);
let mut game = from_fen(FEN_PERFT[0])?;
let (mut nodes, depth): (u64, u8) = (0, 2);
perft_driver(&mut game, &mut nodes, depth);
assert_eq!(nodes, 400);
let mut game = from_fen(FEN_PERFT[0])?;
let (mut nodes, depth): (u64, u8) = (0, 3);
perft_driver(&mut game, &mut nodes, depth);
assert_eq!(nodes, 8902);
let mut game = from_fen(FEN_PERFT[0])?;
let (mut nodes, depth): (u64, u8) = (0, 4);
perft_driver(&mut game, &mut nodes, depth);
assert_eq!(nodes, 197281);
Ok(())
}
#[test]
fn test_perft_driver_2() -> Result<(), String> {
init_attacks();
let mut game = from_fen(FEN_PERFT[1])?;
let (mut nodes, depth): (u64, u8) = (0, 1);
perft_driver(&mut game, &mut nodes, depth);
assert_eq!(nodes, 48);
let mut game = from_fen(FEN_PERFT[1])?;
let (mut nodes, depth): (u64, u8) = (0, 2);
perft_driver(&mut game, &mut nodes, depth);
assert_eq!(nodes, 2039);
// let mut game = from_fen(FEN_PERFT[1])?;
// let (mut nodes, depth): (u64, u8) = (0, 3);
// perft_driver(&mut game, &mut nodes, depth);
// assert_eq!(nodes, 97862);
// let mut game = from_fen(FEN_PERFT[1])?;
// let (mut nodes, depth): (u64, u8) = (0, 4);
// perft_driver(&mut game, &mut nodes, depth);
// assert_eq!(nodes, 4085603);
Ok(())
}
#[test]
fn test_perft_driver_3() -> Result<(), String> {
init_attacks();
// let mut game = from_fen(FEN_PERFT[2])?;
// let (mut nodes, depth): (u64, u8) = (0, 1);
// perft_driver(&mut game, &mut nodes, depth);
// assert_eq!(nodes, 14);
// let mut game = from_fen(FEN_PERFT[2])?;
// let (mut nodes, depth): (u64, u8) = (0, 2);
// perft_driver(&mut game, &mut nodes, depth);
// assert_eq!(nodes, 191);
// let mut game = from_fen(FEN_PERFT[2])?;
// let (mut nodes, depth): (u64, u8) = (0, 3);
// perft_driver(&mut game, &mut nodes, depth);
// assert_eq!(nodes, 2812);
// let mut game = from_fen(FEN_PERFT[2])?;
// let (mut nodes, depth): (u64, u8) = (0, 4);
// perft_driver(&mut game, &mut nodes, depth);
// assert_eq!(nodes, 43238);
Ok(())
}
#[test]
fn test_perft_driver_4() -> Result<(), String> {
init_attacks();
let mut game = from_fen(FEN_PERFT[3])?;
let (mut nodes, depth): (u64, u8) = (0, 1);
perft_driver(&mut game, &mut nodes, depth);
assert_eq!(nodes , 6);
// let mut game = from_fen(FEN_PERFT[3])?;
// let (mut nodes, depth): (u64, u8) = (0, 2);
// perft_driver(&mut game, &mut nodes, depth);
// assert_eq!(nodes, 264);
// let mut game = from_fen(FEN_PERFT[3])?;
// let (mut nodes, depth): (u64, u8) = (0, 3);
// perft_driver(&mut game, &mut nodes, depth);
// assert_eq!(nodes, 9467);
// let mut game = from_fen(FEN_PERFT[3])?;
// let (mut nodes, depth): (u64, u8) = (0, 4);
// perft_driver(&mut game, &mut nodes, depth);
// assert_eq!(nodes, 422333);
Ok(())
}
#[test]
fn test_perft_driver_5() -> Result<(), String> {
init_attacks();
// let mut game = from_fen(FEN_PERFT[4])?;
// let (mut nodes, depth): (u64, u8) = (0, 1);
// perft_driver(&mut game, &mut nodes, depth);
// assert_eq!(nodes , 44);
// let mut game = from_fen(FEN_PERFT[4])?;
// let (mut nodes, depth): (u64, u8) = (0, 2);
// perft_driver(&mut game, &mut nodes, depth);
// assert_eq!(nodes, 1486);
// let mut game = from_fen(FEN_PERFT[4])?;
// let (mut nodes, depth): (u64, u8) = (0, 3);
// perft_driver(&mut game, &mut nodes, depth);
// assert_eq!(nodes, 62379);
// let mut game = from_fen(FEN_PERFT[4])?;
// let (mut nodes, depth): (u64, u8) = (0, 4);
// perft_driver(&mut game, &mut nodes, depth);
// assert_eq!(nodes, 2103487);
Ok(())
}
#[test]
fn test_perft_driver_6() -> Result<(), String> {
init_attacks();
let mut game = from_fen(FEN_PERFT[5])?;
let (mut nodes, depth): (u64, u8) = (0, 1);
perft_driver(&mut game, &mut nodes, depth);
assert_eq!(nodes , 46);
let mut game = from_fen(FEN_PERFT[5])?;
let (mut nodes, depth): (u64, u8) = (0, 2);
perft_driver(&mut game, &mut nodes, depth);
assert_eq!(nodes, 2079);
let mut game = from_fen(FEN_PERFT[5])?;
let (mut nodes, depth): (u64, u8) = (0, 3);
perft_driver(&mut game, &mut nodes, depth);
assert_eq!(nodes, 89890);
let mut game = from_fen(FEN_PERFT[5])?;
let (mut nodes, depth): (u64, u8) = (0, 4);
perft_driver(&mut game, &mut nodes, depth);
assert_eq!(nodes, 3894594);
Ok(())
}
}