diff --git a/minmax-rs/src/lib.rs b/minmax-rs/src/lib.rs index a3aff71..8c5f868 100644 --- a/minmax-rs/src/lib.rs +++ b/minmax-rs/src/lib.rs @@ -12,7 +12,10 @@ pub mod tic_tac_toe; pub mod player; -use std::{fmt::Display, ops::Neg}; +use std::{ + fmt::{Debug, Display}, + ops::Neg, +}; pub use self::minmax::PerfectPlayer; pub use player::{Player, State}; @@ -78,7 +81,7 @@ pub trait Game: Display { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct Score(i32); impl Score { @@ -106,3 +109,35 @@ impl Neg for Score { Self(-self.0) } } + +impl Debug for Score { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match *self { + Self::WON => f.write_str("WON"), + Self::LOST => f.write_str("LOST"), + Self(other) => Debug::fmt(&other, f), + } + } +} + +#[cfg(test)] +fn assert_win_ratio, O: GamePlayer>( + runs: u64, + x_win_ratio: f64, + x: impl Fn() -> X, + o: impl Fn() -> O, +) { + let mut results = [0u64, 0, 0]; + + for _ in 0..runs { + let result = G::empty().play::(&mut x(), &mut o()); + let idx = Player::as_u8(result); + results[idx as usize] += 1; + } + + let total = results.iter().copied().sum::(); + + let ratio = (total as f64) / (results[0] as f64); + println!("{ratio} >= {x_win_ratio}"); + assert!(ratio >= x_win_ratio); +} diff --git a/minmax-rs/src/main.rs b/minmax-rs/src/main.rs index 8d6004a..2213da6 100644 --- a/minmax-rs/src/main.rs +++ b/minmax-rs/src/main.rs @@ -87,8 +87,8 @@ fn main() { } }; - let player_a = get_player(args.o); - let player_b = get_player(args.x); + let player_a = get_player(args.x); + let player_b = get_player(args.o); play_with_players(player_a, player_b); } @@ -105,8 +105,8 @@ fn main() { } }; - let player_a = get_player(args.o); - let player_b = get_player(args.x); + let player_a = get_player(args.x); + let player_b = get_player(args.o); play_with_players(player_a, player_b); } diff --git a/minmax-rs/src/minmax.rs b/minmax-rs/src/minmax.rs index 806a5ee..72a928d 100644 --- a/minmax-rs/src/minmax.rs +++ b/minmax-rs/src/minmax.rs @@ -61,31 +61,37 @@ impl PerfectPlayer { for pos in board.possible_moves() { board.make_move(pos, maximizing_player); - let value = - -self.minmax(board, maximizing_player.opponent(), -beta, -max_value, depth + 1); + let value = -self.minmax( + board, + maximizing_player.opponent(), + -beta, + -max_value, + depth + 1, + ); board.undo_move(pos); - if value > max_value { + if value >= max_value { max_value = value; if depth == 0 { self.best_move = Some(pos); } - } - // Imagine a game tree like this - // P( ) - // / \ - // A(10) B( ) <- we are here in the loop for the first child that returned 11. - // / \ - // C(11) D( ) - // - // Our beta parameter is 10, because that's the current max_value of our parent. - // If P plays B, we know that B will pick something _at least_ as good as C. This means - // that B will be -11 or worse. -11 is definitly worse than -10, so playing B is definitly - // a very bad idea, no matter the value of D. So don't even bother calculating the value of D - // and just break out. - if max_value >= beta { - break; + + // Imagine a game tree like this + // P( ) + // / \ + // A(10) B( ) <- we are here in the loop for the first child that returned 11. + // / \ + // C(11) D( ) + // + // Our beta parameter is 10, because that's the current max_value of our parent. + // If P plays B, we know that B will pick something _at least_ as good as C. This means + // that B will be -11 or worse. -11 is definitly worse than -10, so playing B is definitly + // a very bad idea, no matter the value of D. So don't even bother calculating the value of D + // and just break out. + if max_value >= beta { + break; + } } } @@ -109,3 +115,56 @@ impl GamePlayer for PerfectPlayer { } } } + +#[cfg(test)] +mod tests { + use crate::assert_win_ratio; + use crate::connect4::board::Connect4; + use crate::minmax::PerfectPlayer; + + use crate::player::{GreedyPlayer, RandomPlayer}; + use crate::tic_tac_toe::TicTacToe; + + #[test] + fn perfect_always_beats_greedy() { + assert_win_ratio::(1, 1.0, || PerfectPlayer::new(false), || GreedyPlayer); + assert_win_ratio::( + 1, + 1.0, + || PerfectPlayer::new(false).with_max_depth(Some(8)), + || GreedyPlayer, + ); + } + + #[test] + fn perfect_beats_random() { + assert_win_ratio::( + 10, + 0.95, + || PerfectPlayer::new(false), + || RandomPlayer, + ); + assert_win_ratio::( + 5, + 0.95, + || PerfectPlayer::new(false).with_max_depth(Some(7)), + || RandomPlayer, + ); + } + + #[test] + fn good_beat_bad() { + assert_win_ratio::( + 1, + 1.0, + || PerfectPlayer::new(false).with_max_depth(Some(7)), + || PerfectPlayer::new(false).with_max_depth(Some(5)), + ); + assert_win_ratio::( + 1, + 1.0, + || PerfectPlayer::new(false).with_max_depth(Some(7)), + || PerfectPlayer::new(false).with_max_depth(Some(5)), + ); + } +} diff --git a/minmax-rs/src/tic_tac_toe/board.rs b/minmax-rs/src/tic_tac_toe/board.rs index 8857494..aaa98ba 100644 --- a/minmax-rs/src/tic_tac_toe/board.rs +++ b/minmax-rs/src/tic_tac_toe/board.rs @@ -138,8 +138,13 @@ impl Game for TicTacToe { TicTacToe::result(self) } - fn rate(&self, _: Player) -> Score { - unimplemented!("we always finish the board") + fn rate(&self, player: Player) -> Score { + match self.result() { + State::Winner(winner) if player == winner => Score::WON, + State::Winner(_) => Score::LOST, + State::InProgress => Score::TIE, + State::Draw => Score::TIE, + } } fn make_move(&mut self, position: Self::Move, player: Player) { diff --git a/minmax-rs/src/tic_tac_toe/mod.rs b/minmax-rs/src/tic_tac_toe/mod.rs index 4af6e6d..9fbdbb0 100644 --- a/minmax-rs/src/tic_tac_toe/mod.rs +++ b/minmax-rs/src/tic_tac_toe/mod.rs @@ -3,41 +3,3 @@ mod game; mod player; pub use {board::TicTacToe, player::*}; - -#[cfg(test)] -mod tests { - use crate::{minmax::PerfectPlayer, tic_tac_toe::board::TicTacToe, GamePlayer, Player}; - - use crate::player::{GreedyPlayer, RandomPlayer}; - - fn assert_win_ratio, O: GamePlayer>( - runs: u64, - x_win_ratio: f64, - x: impl Fn() -> X, - o: impl Fn() -> O, - ) { - let mut results = [0u64, 0, 0]; - - for _ in 0..runs { - let result = TicTacToe::empty().play::(&mut x(), &mut o()); - let idx = Player::as_u8(result); - results[idx as usize] += 1; - } - - let total = results.iter().copied().sum::(); - - let ratio = (total as f64) / (results[0] as f64); - println!("{ratio} >= {x_win_ratio}"); - assert!(ratio >= x_win_ratio); - } - - #[test] - fn perfect_always_beats_greedy() { - assert_win_ratio(1, 1.0, || PerfectPlayer::new(false), || GreedyPlayer); - } - - #[test] - fn perfect_beats_random() { - assert_win_ratio(10, 0.95, || PerfectPlayer::new(false), || RandomPlayer); - } -}