mirror of
https://github.com/Noratrieb/monte-carlo-tree-search.git
synced 2026-01-14 15:25:09 +01:00
trying to fix this shit :(
This commit is contained in:
parent
36128c22c8
commit
feac668bb8
2 changed files with 45 additions and 25 deletions
|
|
@ -6,5 +6,5 @@ edition = "2021"
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
bumpalo = { version = "3.9.1", features = [] }
|
||||
bumpalo = { version = "3.9.1" }
|
||||
rand = "0.8.4"
|
||||
|
|
|
|||
68
src/lib.rs
68
src/lib.rs
|
|
@ -11,35 +11,44 @@ pub enum Status<P> {
|
|||
Winner(P),
|
||||
}
|
||||
|
||||
pub trait PlayerState: Eq + Copy {
|
||||
fn next(self) -> Self;
|
||||
}
|
||||
|
||||
pub trait GameState: Clone + std::fmt::Debug {
|
||||
type Player: Eq + Copy + std::fmt::Debug;
|
||||
type Player: PlayerState + std::fmt::Debug;
|
||||
|
||||
fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>>;
|
||||
|
||||
fn status(&self) -> Status<Self::Player>;
|
||||
|
||||
fn toggle_player(&mut self);
|
||||
|
||||
fn next_random_play(&mut self);
|
||||
}
|
||||
|
||||
mod mcts {
|
||||
use crate::{GameState, Status};
|
||||
use crate::{GameState, PlayerState, Status};
|
||||
use bumpalo::Bump;
|
||||
use rand::Rng;
|
||||
use std::cell::Cell;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Node<'tree, S> {
|
||||
struct Node<'tree, S: GameState> {
|
||||
state: S,
|
||||
// todo: don't have this field and let the GameState handle this all
|
||||
player: S::Player,
|
||||
visited: Cell<u32>,
|
||||
score: Cell<i32>,
|
||||
parent: Option<&'tree Node<'tree, S>>,
|
||||
children: Cell<&'tree [Node<'tree, S>]>,
|
||||
}
|
||||
|
||||
impl<'tree, S> Node<'tree, S> {
|
||||
fn new(state: S, alloc: &'tree Bump) -> Node<S> {
|
||||
impl<'tree, S: GameState> Node<'tree, S> {
|
||||
fn new(state: S, player: S::Player, alloc: &'tree Bump) -> Node<'tree, S> {
|
||||
Node {
|
||||
state,
|
||||
player,
|
||||
visited: Cell::new(0),
|
||||
score: Cell::new(0),
|
||||
parent: None,
|
||||
|
|
@ -64,26 +73,31 @@ mod mcts {
|
|||
|
||||
const MAX_TRIES: u64 = 10_000;
|
||||
|
||||
pub fn find_next_move<S: GameState>(current_state: S, opponent: S::Player) -> S {
|
||||
pub fn find_next_move<S: GameState>(current_state: S, own_player: S::Player) -> S {
|
||||
let alloc = Bump::new();
|
||||
let opponent = own_player.next();
|
||||
|
||||
let root_node = alloc.alloc(Node::new(current_state, &alloc));
|
||||
let root_node = alloc.alloc(Node::new(current_state, opponent, &alloc));
|
||||
|
||||
for _ in 0..MAX_TRIES {
|
||||
// Phase 1 - Selection
|
||||
let promising_node = select_promising_node(root_node);
|
||||
|
||||
// Phase 2 - Expansion
|
||||
if promising_node.state.status() == Status::InProgress {
|
||||
expand_node(&alloc, promising_node);
|
||||
}
|
||||
|
||||
if !promising_node.children.get().is_empty() {
|
||||
let child = promising_node.random_child();
|
||||
let playout_result = simulate_random_playout(child, opponent);
|
||||
back_propagation(child, playout_result);
|
||||
// Phase 3 - Simulation
|
||||
let promising_node = if !promising_node.children.get().is_empty() {
|
||||
promising_node.random_child()
|
||||
} else {
|
||||
let playout_result = simulate_random_playout(promising_node, opponent);
|
||||
back_propagation(promising_node, playout_result);
|
||||
promising_node
|
||||
};
|
||||
let playout_result = simulate_random_playout(promising_node, opponent);
|
||||
|
||||
// Phase 4 - Update
|
||||
back_propagation(promising_node, playout_result);
|
||||
}
|
||||
|
||||
let winner_node = root_node.child_with_max_score();
|
||||
|
|
@ -92,7 +106,9 @@ mod mcts {
|
|||
node.state.clone()
|
||||
}
|
||||
|
||||
fn select_promising_node<'tree, S>(root_node: &'tree Node<'tree, S>) -> &'tree Node<'tree, S> {
|
||||
fn select_promising_node<'tree, S: GameState>(
|
||||
root_node: &'tree Node<'tree, S>,
|
||||
) -> &'tree Node<'tree, S> {
|
||||
let mut node = root_node;
|
||||
|
||||
while !node.children.get().is_empty() {
|
||||
|
|
@ -107,6 +123,7 @@ mod mcts {
|
|||
|
||||
let new_nodes = possible_states.map(|state| Node {
|
||||
state,
|
||||
player: node.player.next(),
|
||||
visited: Cell::new(0),
|
||||
score: Cell::new(0),
|
||||
parent: Some(node),
|
||||
|
|
@ -148,6 +165,7 @@ mod mcts {
|
|||
}
|
||||
|
||||
while board_status == Status::InProgress {
|
||||
state.toggle_player();
|
||||
state.next_random_play();
|
||||
board_status = state.status();
|
||||
}
|
||||
|
|
@ -157,6 +175,7 @@ mod mcts {
|
|||
|
||||
mod uct {
|
||||
use crate::mcts::Node;
|
||||
use crate::GameState;
|
||||
|
||||
pub fn uct(total_visit: u32, win_score: i32, node_visit: i32) -> u32 {
|
||||
if node_visit == 0 {
|
||||
|
|
@ -170,7 +189,7 @@ mod mcts {
|
|||
num as u32
|
||||
}
|
||||
|
||||
pub(super) fn find_best_node_with_uct<'tree, S>(
|
||||
pub(super) fn find_best_node_with_uct<'tree, S: GameState>(
|
||||
node: &'tree Node<'tree, S>,
|
||||
) -> Option<&'tree Node<'tree, S>> {
|
||||
let parent_visit_count = node.visited.get();
|
||||
|
|
@ -184,7 +203,7 @@ mod mcts {
|
|||
}
|
||||
|
||||
pub mod tic_tac_toe {
|
||||
use crate::{GameState, Status};
|
||||
use crate::{GameState, PlayerState, Status};
|
||||
use rand::Rng;
|
||||
use std::fmt::{Display, Formatter, Write};
|
||||
|
||||
|
|
@ -194,10 +213,8 @@ pub mod tic_tac_toe {
|
|||
X,
|
||||
}
|
||||
|
||||
impl std::ops::Not for Player {
|
||||
type Output = Self;
|
||||
|
||||
fn not(self) -> Self::Output {
|
||||
impl crate::PlayerState for Player {
|
||||
fn next(self) -> Self {
|
||||
match self {
|
||||
Self::O => Self::X,
|
||||
Self::X => Self::O,
|
||||
|
|
@ -265,7 +282,7 @@ pub mod tic_tac_toe {
|
|||
.map(|(i, _)| {
|
||||
let mut new_state = *self;
|
||||
|
||||
new_state.active_player = !self.active_player;
|
||||
new_state.active_player = self.active_player.next();
|
||||
new_state.board[i] = new_state.active_player.into();
|
||||
|
||||
new_state
|
||||
|
|
@ -306,9 +323,11 @@ pub mod tic_tac_toe {
|
|||
Status::InProgress
|
||||
}
|
||||
|
||||
fn next_random_play(&mut self) {
|
||||
self.active_player = !self.active_player;
|
||||
fn toggle_player(&mut self) {
|
||||
self.active_player = self.active_player.next();
|
||||
}
|
||||
|
||||
fn next_random_play(&mut self) {
|
||||
let free_fields = self.free_fields();
|
||||
let random_field = rand::thread_rng().gen_range(0..free_fields);
|
||||
|
||||
|
|
@ -352,6 +371,7 @@ pub mod tic_tac_toe {
|
|||
use std::io::Write;
|
||||
|
||||
const PLAYING_PLAYER: Player = Player::O;
|
||||
const AI_PLAYER: Player = Player::X;
|
||||
|
||||
pub fn main() {
|
||||
let mut board = Board::new(PLAYING_PLAYER);
|
||||
|
|
@ -364,7 +384,7 @@ pub mod tic_tac_toe {
|
|||
break result;
|
||||
}
|
||||
|
||||
let ai_play = mcts::find_next_move(board, PLAYING_PLAYER);
|
||||
let ai_play = mcts::find_next_move(board, AI_PLAYER);
|
||||
board = ai_play;
|
||||
|
||||
if let Some(result) = is_finished(&board) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue