From feac668bb875f995f48b8fd4a0bf1d1da3aec9b9 Mon Sep 17 00:00:00 2001 From: Nilstrieb <48135649+Nilstrieb@users.noreply.github.com> Date: Sun, 30 Jan 2022 18:18:11 +0100 Subject: [PATCH] trying to fix this shit :( --- Cargo.toml | 2 +- src/lib.rs | 68 +++++++++++++++++++++++++++++++++++------------------- 2 files changed, 45 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5bc7823..a7251b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,5 +6,5 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -bumpalo = { version = "3.9.1", features = [] } +bumpalo = { version = "3.9.1" } rand = "0.8.4" diff --git a/src/lib.rs b/src/lib.rs index 9ce4e36..23c3f6a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,35 +11,44 @@ pub enum Status

{ Winner(P), } +pub trait PlayerState: Eq + Copy { + fn next(self) -> Self; +} + pub trait GameState: Clone + std::fmt::Debug { - type Player: Eq + Copy + std::fmt::Debug; + type Player: PlayerState + std::fmt::Debug; fn next_states(&self) -> Box>; fn status(&self) -> Status; + fn toggle_player(&mut self); + fn next_random_play(&mut self); } mod mcts { - use crate::{GameState, Status}; + use crate::{GameState, PlayerState, Status}; use bumpalo::Bump; use rand::Rng; use std::cell::Cell; #[derive(Debug, Clone)] - struct Node<'tree, S> { + struct Node<'tree, S: GameState> { state: S, + // todo: don't have this field and let the GameState handle this all + player: S::Player, visited: Cell, score: Cell, parent: Option<&'tree Node<'tree, S>>, children: Cell<&'tree [Node<'tree, S>]>, } - impl<'tree, S> Node<'tree, S> { - fn new(state: S, alloc: &'tree Bump) -> Node { + impl<'tree, S: GameState> Node<'tree, S> { + fn new(state: S, player: S::Player, alloc: &'tree Bump) -> Node<'tree, S> { Node { state, + player, visited: Cell::new(0), score: Cell::new(0), parent: None, @@ -64,26 +73,31 @@ mod mcts { const MAX_TRIES: u64 = 10_000; - pub fn find_next_move(current_state: S, opponent: S::Player) -> S { + pub fn find_next_move(current_state: S, own_player: S::Player) -> S { let alloc = Bump::new(); + let opponent = own_player.next(); - let root_node = alloc.alloc(Node::new(current_state, &alloc)); + let root_node = alloc.alloc(Node::new(current_state, opponent, &alloc)); for _ in 0..MAX_TRIES { + // Phase 1 - Selection let promising_node = select_promising_node(root_node); + // Phase 2 - Expansion if promising_node.state.status() == Status::InProgress { expand_node(&alloc, promising_node); } - if !promising_node.children.get().is_empty() { - let child = promising_node.random_child(); - let playout_result = simulate_random_playout(child, opponent); - back_propagation(child, playout_result); + // Phase 3 - Simulation + let promising_node = if !promising_node.children.get().is_empty() { + promising_node.random_child() } else { - let playout_result = simulate_random_playout(promising_node, opponent); - back_propagation(promising_node, playout_result); + promising_node }; + let playout_result = simulate_random_playout(promising_node, opponent); + + // Phase 4 - Update + back_propagation(promising_node, playout_result); } let winner_node = root_node.child_with_max_score(); @@ -92,7 +106,9 @@ mod mcts { node.state.clone() } - fn select_promising_node<'tree, S>(root_node: &'tree Node<'tree, S>) -> &'tree Node<'tree, S> { + fn select_promising_node<'tree, S: GameState>( + root_node: &'tree Node<'tree, S>, + ) -> &'tree Node<'tree, S> { let mut node = root_node; while !node.children.get().is_empty() { @@ -107,6 +123,7 @@ mod mcts { let new_nodes = possible_states.map(|state| Node { state, + player: node.player.next(), visited: Cell::new(0), score: Cell::new(0), parent: Some(node), @@ -148,6 +165,7 @@ mod mcts { } while board_status == Status::InProgress { + state.toggle_player(); state.next_random_play(); board_status = state.status(); } @@ -157,6 +175,7 @@ mod mcts { mod uct { use crate::mcts::Node; + use crate::GameState; pub fn uct(total_visit: u32, win_score: i32, node_visit: i32) -> u32 { if node_visit == 0 { @@ -170,7 +189,7 @@ mod mcts { num as u32 } - pub(super) fn find_best_node_with_uct<'tree, S>( + pub(super) fn find_best_node_with_uct<'tree, S: GameState>( node: &'tree Node<'tree, S>, ) -> Option<&'tree Node<'tree, S>> { let parent_visit_count = node.visited.get(); @@ -184,7 +203,7 @@ mod mcts { } pub mod tic_tac_toe { - use crate::{GameState, Status}; + use crate::{GameState, PlayerState, Status}; use rand::Rng; use std::fmt::{Display, Formatter, Write}; @@ -194,10 +213,8 @@ pub mod tic_tac_toe { X, } - impl std::ops::Not for Player { - type Output = Self; - - fn not(self) -> Self::Output { + impl crate::PlayerState for Player { + fn next(self) -> Self { match self { Self::O => Self::X, Self::X => Self::O, @@ -265,7 +282,7 @@ pub mod tic_tac_toe { .map(|(i, _)| { let mut new_state = *self; - new_state.active_player = !self.active_player; + new_state.active_player = self.active_player.next(); new_state.board[i] = new_state.active_player.into(); new_state @@ -306,9 +323,11 @@ pub mod tic_tac_toe { Status::InProgress } - fn next_random_play(&mut self) { - self.active_player = !self.active_player; + fn toggle_player(&mut self) { + self.active_player = self.active_player.next(); + } + fn next_random_play(&mut self) { let free_fields = self.free_fields(); let random_field = rand::thread_rng().gen_range(0..free_fields); @@ -352,6 +371,7 @@ pub mod tic_tac_toe { use std::io::Write; const PLAYING_PLAYER: Player = Player::O; + const AI_PLAYER: Player = Player::X; pub fn main() { let mut board = Board::new(PLAYING_PLAYER); @@ -364,7 +384,7 @@ pub mod tic_tac_toe { break result; } - let ai_play = mcts::find_next_move(board, PLAYING_PLAYER); + let ai_play = mcts::find_next_move(board, AI_PLAYER); board = ai_play; if let Some(result) = is_finished(&board) {