trying to fix this shit :(

This commit is contained in:
nora 2022-01-30 18:18:11 +01:00
parent 36128c22c8
commit feac668bb8
2 changed files with 45 additions and 25 deletions

View file

@ -6,5 +6,5 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
bumpalo = { version = "3.9.1", features = [] } bumpalo = { version = "3.9.1" }
rand = "0.8.4" rand = "0.8.4"

View file

@ -11,35 +11,44 @@ pub enum Status<P> {
Winner(P), Winner(P),
} }
pub trait PlayerState: Eq + Copy {
fn next(self) -> Self;
}
pub trait GameState: Clone + std::fmt::Debug { pub trait GameState: Clone + std::fmt::Debug {
type Player: Eq + Copy + std::fmt::Debug; type Player: PlayerState + std::fmt::Debug;
fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>>; fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>>;
fn status(&self) -> Status<Self::Player>; fn status(&self) -> Status<Self::Player>;
fn toggle_player(&mut self);
fn next_random_play(&mut self); fn next_random_play(&mut self);
} }
mod mcts { mod mcts {
use crate::{GameState, Status}; use crate::{GameState, PlayerState, Status};
use bumpalo::Bump; use bumpalo::Bump;
use rand::Rng; use rand::Rng;
use std::cell::Cell; use std::cell::Cell;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct Node<'tree, S> { struct Node<'tree, S: GameState> {
state: S, state: S,
// todo: don't have this field and let the GameState handle this all
player: S::Player,
visited: Cell<u32>, visited: Cell<u32>,
score: Cell<i32>, score: Cell<i32>,
parent: Option<&'tree Node<'tree, S>>, parent: Option<&'tree Node<'tree, S>>,
children: Cell<&'tree [Node<'tree, S>]>, children: Cell<&'tree [Node<'tree, S>]>,
} }
impl<'tree, S> Node<'tree, S> { impl<'tree, S: GameState> Node<'tree, S> {
fn new(state: S, alloc: &'tree Bump) -> Node<S> { fn new(state: S, player: S::Player, alloc: &'tree Bump) -> Node<'tree, S> {
Node { Node {
state, state,
player,
visited: Cell::new(0), visited: Cell::new(0),
score: Cell::new(0), score: Cell::new(0),
parent: None, parent: None,
@ -64,26 +73,31 @@ mod mcts {
const MAX_TRIES: u64 = 10_000; const MAX_TRIES: u64 = 10_000;
pub fn find_next_move<S: GameState>(current_state: S, opponent: S::Player) -> S { pub fn find_next_move<S: GameState>(current_state: S, own_player: S::Player) -> S {
let alloc = Bump::new(); let alloc = Bump::new();
let opponent = own_player.next();
let root_node = alloc.alloc(Node::new(current_state, &alloc)); let root_node = alloc.alloc(Node::new(current_state, opponent, &alloc));
for _ in 0..MAX_TRIES { for _ in 0..MAX_TRIES {
// Phase 1 - Selection
let promising_node = select_promising_node(root_node); let promising_node = select_promising_node(root_node);
// Phase 2 - Expansion
if promising_node.state.status() == Status::InProgress { if promising_node.state.status() == Status::InProgress {
expand_node(&alloc, promising_node); expand_node(&alloc, promising_node);
} }
if !promising_node.children.get().is_empty() { // Phase 3 - Simulation
let child = promising_node.random_child(); let promising_node = if !promising_node.children.get().is_empty() {
let playout_result = simulate_random_playout(child, opponent); promising_node.random_child()
back_propagation(child, playout_result);
} else { } else {
let playout_result = simulate_random_playout(promising_node, opponent); promising_node
back_propagation(promising_node, playout_result);
}; };
let playout_result = simulate_random_playout(promising_node, opponent);
// Phase 4 - Update
back_propagation(promising_node, playout_result);
} }
let winner_node = root_node.child_with_max_score(); let winner_node = root_node.child_with_max_score();
@ -92,7 +106,9 @@ mod mcts {
node.state.clone() node.state.clone()
} }
fn select_promising_node<'tree, S>(root_node: &'tree Node<'tree, S>) -> &'tree Node<'tree, S> { fn select_promising_node<'tree, S: GameState>(
root_node: &'tree Node<'tree, S>,
) -> &'tree Node<'tree, S> {
let mut node = root_node; let mut node = root_node;
while !node.children.get().is_empty() { while !node.children.get().is_empty() {
@ -107,6 +123,7 @@ mod mcts {
let new_nodes = possible_states.map(|state| Node { let new_nodes = possible_states.map(|state| Node {
state, state,
player: node.player.next(),
visited: Cell::new(0), visited: Cell::new(0),
score: Cell::new(0), score: Cell::new(0),
parent: Some(node), parent: Some(node),
@ -148,6 +165,7 @@ mod mcts {
} }
while board_status == Status::InProgress { while board_status == Status::InProgress {
state.toggle_player();
state.next_random_play(); state.next_random_play();
board_status = state.status(); board_status = state.status();
} }
@ -157,6 +175,7 @@ mod mcts {
mod uct { mod uct {
use crate::mcts::Node; use crate::mcts::Node;
use crate::GameState;
pub fn uct(total_visit: u32, win_score: i32, node_visit: i32) -> u32 { pub fn uct(total_visit: u32, win_score: i32, node_visit: i32) -> u32 {
if node_visit == 0 { if node_visit == 0 {
@ -170,7 +189,7 @@ mod mcts {
num as u32 num as u32
} }
pub(super) fn find_best_node_with_uct<'tree, S>( pub(super) fn find_best_node_with_uct<'tree, S: GameState>(
node: &'tree Node<'tree, S>, node: &'tree Node<'tree, S>,
) -> Option<&'tree Node<'tree, S>> { ) -> Option<&'tree Node<'tree, S>> {
let parent_visit_count = node.visited.get(); let parent_visit_count = node.visited.get();
@ -184,7 +203,7 @@ mod mcts {
} }
pub mod tic_tac_toe { pub mod tic_tac_toe {
use crate::{GameState, Status}; use crate::{GameState, PlayerState, Status};
use rand::Rng; use rand::Rng;
use std::fmt::{Display, Formatter, Write}; use std::fmt::{Display, Formatter, Write};
@ -194,10 +213,8 @@ pub mod tic_tac_toe {
X, X,
} }
impl std::ops::Not for Player { impl crate::PlayerState for Player {
type Output = Self; fn next(self) -> Self {
fn not(self) -> Self::Output {
match self { match self {
Self::O => Self::X, Self::O => Self::X,
Self::X => Self::O, Self::X => Self::O,
@ -265,7 +282,7 @@ pub mod tic_tac_toe {
.map(|(i, _)| { .map(|(i, _)| {
let mut new_state = *self; let mut new_state = *self;
new_state.active_player = !self.active_player; new_state.active_player = self.active_player.next();
new_state.board[i] = new_state.active_player.into(); new_state.board[i] = new_state.active_player.into();
new_state new_state
@ -306,9 +323,11 @@ pub mod tic_tac_toe {
Status::InProgress Status::InProgress
} }
fn next_random_play(&mut self) { fn toggle_player(&mut self) {
self.active_player = !self.active_player; self.active_player = self.active_player.next();
}
fn next_random_play(&mut self) {
let free_fields = self.free_fields(); let free_fields = self.free_fields();
let random_field = rand::thread_rng().gen_range(0..free_fields); let random_field = rand::thread_rng().gen_range(0..free_fields);
@ -352,6 +371,7 @@ pub mod tic_tac_toe {
use std::io::Write; use std::io::Write;
const PLAYING_PLAYER: Player = Player::O; const PLAYING_PLAYER: Player = Player::O;
const AI_PLAYER: Player = Player::X;
pub fn main() { pub fn main() {
let mut board = Board::new(PLAYING_PLAYER); let mut board = Board::new(PLAYING_PLAYER);
@ -364,7 +384,7 @@ pub mod tic_tac_toe {
break result; break result;
} }
let ai_play = mcts::find_next_move(board, PLAYING_PLAYER); let ai_play = mcts::find_next_move(board, AI_PLAYER);
board = ai_play; board = ai_play;
if let Some(result) = is_finished(&board) { if let Some(result) = is_finished(&board) {