trying to fix this shit :(

This commit is contained in:
nora 2022-01-30 18:18:11 +01:00
parent 36128c22c8
commit feac668bb8
2 changed files with 45 additions and 25 deletions

View file

@ -6,5 +6,5 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
bumpalo = { version = "3.9.1", features = [] }
bumpalo = { version = "3.9.1" }
rand = "0.8.4"

View file

@ -11,35 +11,44 @@ pub enum Status<P> {
Winner(P),
}
pub trait PlayerState: Eq + Copy {
fn next(self) -> Self;
}
pub trait GameState: Clone + std::fmt::Debug {
type Player: Eq + Copy + std::fmt::Debug;
type Player: PlayerState + std::fmt::Debug;
fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>>;
fn status(&self) -> Status<Self::Player>;
fn toggle_player(&mut self);
fn next_random_play(&mut self);
}
mod mcts {
use crate::{GameState, Status};
use crate::{GameState, PlayerState, Status};
use bumpalo::Bump;
use rand::Rng;
use std::cell::Cell;
#[derive(Debug, Clone)]
struct Node<'tree, S> {
struct Node<'tree, S: GameState> {
state: S,
// todo: don't have this field and let the GameState handle this all
player: S::Player,
visited: Cell<u32>,
score: Cell<i32>,
parent: Option<&'tree Node<'tree, S>>,
children: Cell<&'tree [Node<'tree, S>]>,
}
impl<'tree, S> Node<'tree, S> {
fn new(state: S, alloc: &'tree Bump) -> Node<S> {
impl<'tree, S: GameState> Node<'tree, S> {
fn new(state: S, player: S::Player, alloc: &'tree Bump) -> Node<'tree, S> {
Node {
state,
player,
visited: Cell::new(0),
score: Cell::new(0),
parent: None,
@ -64,26 +73,31 @@ mod mcts {
const MAX_TRIES: u64 = 10_000;
pub fn find_next_move<S: GameState>(current_state: S, opponent: S::Player) -> S {
pub fn find_next_move<S: GameState>(current_state: S, own_player: S::Player) -> S {
let alloc = Bump::new();
let opponent = own_player.next();
let root_node = alloc.alloc(Node::new(current_state, &alloc));
let root_node = alloc.alloc(Node::new(current_state, opponent, &alloc));
for _ in 0..MAX_TRIES {
// Phase 1 - Selection
let promising_node = select_promising_node(root_node);
// Phase 2 - Expansion
if promising_node.state.status() == Status::InProgress {
expand_node(&alloc, promising_node);
}
if !promising_node.children.get().is_empty() {
let child = promising_node.random_child();
let playout_result = simulate_random_playout(child, opponent);
back_propagation(child, playout_result);
// Phase 3 - Simulation
let promising_node = if !promising_node.children.get().is_empty() {
promising_node.random_child()
} else {
let playout_result = simulate_random_playout(promising_node, opponent);
back_propagation(promising_node, playout_result);
promising_node
};
let playout_result = simulate_random_playout(promising_node, opponent);
// Phase 4 - Update
back_propagation(promising_node, playout_result);
}
let winner_node = root_node.child_with_max_score();
@ -92,7 +106,9 @@ mod mcts {
node.state.clone()
}
fn select_promising_node<'tree, S>(root_node: &'tree Node<'tree, S>) -> &'tree Node<'tree, S> {
fn select_promising_node<'tree, S: GameState>(
root_node: &'tree Node<'tree, S>,
) -> &'tree Node<'tree, S> {
let mut node = root_node;
while !node.children.get().is_empty() {
@ -107,6 +123,7 @@ mod mcts {
let new_nodes = possible_states.map(|state| Node {
state,
player: node.player.next(),
visited: Cell::new(0),
score: Cell::new(0),
parent: Some(node),
@ -148,6 +165,7 @@ mod mcts {
}
while board_status == Status::InProgress {
state.toggle_player();
state.next_random_play();
board_status = state.status();
}
@ -157,6 +175,7 @@ mod mcts {
mod uct {
use crate::mcts::Node;
use crate::GameState;
pub fn uct(total_visit: u32, win_score: i32, node_visit: i32) -> u32 {
if node_visit == 0 {
@ -170,7 +189,7 @@ mod mcts {
num as u32
}
pub(super) fn find_best_node_with_uct<'tree, S>(
pub(super) fn find_best_node_with_uct<'tree, S: GameState>(
node: &'tree Node<'tree, S>,
) -> Option<&'tree Node<'tree, S>> {
let parent_visit_count = node.visited.get();
@ -184,7 +203,7 @@ mod mcts {
}
pub mod tic_tac_toe {
use crate::{GameState, Status};
use crate::{GameState, PlayerState, Status};
use rand::Rng;
use std::fmt::{Display, Formatter, Write};
@ -194,10 +213,8 @@ pub mod tic_tac_toe {
X,
}
impl std::ops::Not for Player {
type Output = Self;
fn not(self) -> Self::Output {
impl crate::PlayerState for Player {
fn next(self) -> Self {
match self {
Self::O => Self::X,
Self::X => Self::O,
@ -265,7 +282,7 @@ pub mod tic_tac_toe {
.map(|(i, _)| {
let mut new_state = *self;
new_state.active_player = !self.active_player;
new_state.active_player = self.active_player.next();
new_state.board[i] = new_state.active_player.into();
new_state
@ -306,9 +323,11 @@ pub mod tic_tac_toe {
Status::InProgress
}
fn next_random_play(&mut self) {
self.active_player = !self.active_player;
fn toggle_player(&mut self) {
self.active_player = self.active_player.next();
}
fn next_random_play(&mut self) {
let free_fields = self.free_fields();
let random_field = rand::thread_rng().gen_range(0..free_fields);
@ -352,6 +371,7 @@ pub mod tic_tac_toe {
use std::io::Write;
const PLAYING_PLAYER: Player = Player::O;
const AI_PLAYER: Player = Player::X;
pub fn main() {
let mut board = Board::new(PLAYING_PLAYER);
@ -364,7 +384,7 @@ pub mod tic_tac_toe {
break result;
}
let ai_play = mcts::find_next_move(board, PLAYING_PLAYER);
let ai_play = mcts::find_next_move(board, AI_PLAYER);
board = ai_play;
if let Some(result) = is_finished(&board) {