From feac668bb875f995f48b8fd4a0bf1d1da3aec9b9 Mon Sep 17 00:00:00 2001
From: Nilstrieb <48135649+Nilstrieb@users.noreply.github.com>
Date: Sun, 30 Jan 2022 18:18:11 +0100
Subject: [PATCH] trying to fix this shit :(
---
Cargo.toml | 2 +-
src/lib.rs | 68 +++++++++++++++++++++++++++++++++++-------------------
2 files changed, 45 insertions(+), 25 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
index 5bc7823..a7251b0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -6,5 +6,5 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
-bumpalo = { version = "3.9.1", features = [] }
+bumpalo = { version = "3.9.1" }
rand = "0.8.4"
diff --git a/src/lib.rs b/src/lib.rs
index 9ce4e36..23c3f6a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -11,35 +11,44 @@ pub enum Status
{
Winner(P),
}
+pub trait PlayerState: Eq + Copy {
+ fn next(self) -> Self;
+}
+
pub trait GameState: Clone + std::fmt::Debug {
- type Player: Eq + Copy + std::fmt::Debug;
+ type Player: PlayerState + std::fmt::Debug;
fn next_states(&self) -> Box>;
fn status(&self) -> Status;
+ fn toggle_player(&mut self);
+
fn next_random_play(&mut self);
}
mod mcts {
- use crate::{GameState, Status};
+ use crate::{GameState, PlayerState, Status};
use bumpalo::Bump;
use rand::Rng;
use std::cell::Cell;
#[derive(Debug, Clone)]
- struct Node<'tree, S> {
+ struct Node<'tree, S: GameState> {
state: S,
+ // todo: don't have this field and let the GameState handle this all
+ player: S::Player,
visited: Cell,
score: Cell,
parent: Option<&'tree Node<'tree, S>>,
children: Cell<&'tree [Node<'tree, S>]>,
}
- impl<'tree, S> Node<'tree, S> {
- fn new(state: S, alloc: &'tree Bump) -> Node {
+ impl<'tree, S: GameState> Node<'tree, S> {
+ fn new(state: S, player: S::Player, alloc: &'tree Bump) -> Node<'tree, S> {
Node {
state,
+ player,
visited: Cell::new(0),
score: Cell::new(0),
parent: None,
@@ -64,26 +73,31 @@ mod mcts {
const MAX_TRIES: u64 = 10_000;
- pub fn find_next_move(current_state: S, opponent: S::Player) -> S {
+ pub fn find_next_move(current_state: S, own_player: S::Player) -> S {
let alloc = Bump::new();
+ let opponent = own_player.next();
- let root_node = alloc.alloc(Node::new(current_state, &alloc));
+ let root_node = alloc.alloc(Node::new(current_state, opponent, &alloc));
for _ in 0..MAX_TRIES {
+ // Phase 1 - Selection
let promising_node = select_promising_node(root_node);
+ // Phase 2 - Expansion
if promising_node.state.status() == Status::InProgress {
expand_node(&alloc, promising_node);
}
- if !promising_node.children.get().is_empty() {
- let child = promising_node.random_child();
- let playout_result = simulate_random_playout(child, opponent);
- back_propagation(child, playout_result);
+ // Phase 3 - Simulation
+ let promising_node = if !promising_node.children.get().is_empty() {
+ promising_node.random_child()
} else {
- let playout_result = simulate_random_playout(promising_node, opponent);
- back_propagation(promising_node, playout_result);
+ promising_node
};
+ let playout_result = simulate_random_playout(promising_node, opponent);
+
+ // Phase 4 - Update
+ back_propagation(promising_node, playout_result);
}
let winner_node = root_node.child_with_max_score();
@@ -92,7 +106,9 @@ mod mcts {
node.state.clone()
}
- fn select_promising_node<'tree, S>(root_node: &'tree Node<'tree, S>) -> &'tree Node<'tree, S> {
+ fn select_promising_node<'tree, S: GameState>(
+ root_node: &'tree Node<'tree, S>,
+ ) -> &'tree Node<'tree, S> {
let mut node = root_node;
while !node.children.get().is_empty() {
@@ -107,6 +123,7 @@ mod mcts {
let new_nodes = possible_states.map(|state| Node {
state,
+ player: node.player.next(),
visited: Cell::new(0),
score: Cell::new(0),
parent: Some(node),
@@ -148,6 +165,7 @@ mod mcts {
}
while board_status == Status::InProgress {
+ state.toggle_player();
state.next_random_play();
board_status = state.status();
}
@@ -157,6 +175,7 @@ mod mcts {
mod uct {
use crate::mcts::Node;
+ use crate::GameState;
pub fn uct(total_visit: u32, win_score: i32, node_visit: i32) -> u32 {
if node_visit == 0 {
@@ -170,7 +189,7 @@ mod mcts {
num as u32
}
- pub(super) fn find_best_node_with_uct<'tree, S>(
+ pub(super) fn find_best_node_with_uct<'tree, S: GameState>(
node: &'tree Node<'tree, S>,
) -> Option<&'tree Node<'tree, S>> {
let parent_visit_count = node.visited.get();
@@ -184,7 +203,7 @@ mod mcts {
}
pub mod tic_tac_toe {
- use crate::{GameState, Status};
+ use crate::{GameState, PlayerState, Status};
use rand::Rng;
use std::fmt::{Display, Formatter, Write};
@@ -194,10 +213,8 @@ pub mod tic_tac_toe {
X,
}
- impl std::ops::Not for Player {
- type Output = Self;
-
- fn not(self) -> Self::Output {
+ impl crate::PlayerState for Player {
+ fn next(self) -> Self {
match self {
Self::O => Self::X,
Self::X => Self::O,
@@ -265,7 +282,7 @@ pub mod tic_tac_toe {
.map(|(i, _)| {
let mut new_state = *self;
- new_state.active_player = !self.active_player;
+ new_state.active_player = self.active_player.next();
new_state.board[i] = new_state.active_player.into();
new_state
@@ -306,9 +323,11 @@ pub mod tic_tac_toe {
Status::InProgress
}
- fn next_random_play(&mut self) {
- self.active_player = !self.active_player;
+ fn toggle_player(&mut self) {
+ self.active_player = self.active_player.next();
+ }
+ fn next_random_play(&mut self) {
let free_fields = self.free_fields();
let random_field = rand::thread_rng().gen_range(0..free_fields);
@@ -352,6 +371,7 @@ pub mod tic_tac_toe {
use std::io::Write;
const PLAYING_PLAYER: Player = Player::O;
+ const AI_PLAYER: Player = Player::X;
pub fn main() {
let mut board = Board::new(PLAYING_PLAYER);
@@ -364,7 +384,7 @@ pub mod tic_tac_toe {
break result;
}
- let ai_play = mcts::find_next_move(board, PLAYING_PLAYER);
+ let ai_play = mcts::find_next_move(board, AI_PLAYER);
board = ai_play;
if let Some(result) = is_finished(&board) {