From bdd90a41c601c760fa1e7e18f243115b86815811 Mon Sep 17 00:00:00 2001 From: Nilstrieb <48135649+Nilstrieb@users.noreply.github.com> Date: Sun, 23 Jan 2022 21:53:17 +0100 Subject: [PATCH] boilerplate --- Cargo.toml | 1 + src/basic_search.rs | 33 ++++++++------ src/lib.rs | 105 ++++++++++++++++++++++++++++++++++---------- 3 files changed, 101 insertions(+), 38 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1ebbdb6..c71cee9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,4 +6,5 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +bumpalo = { version = "3.9.1", features = ["collections"] } rand = "0.8.4" diff --git a/src/basic_search.rs b/src/basic_search.rs index a1c3f12..3311ee8 100644 --- a/src/basic_search.rs +++ b/src/basic_search.rs @@ -1,4 +1,3 @@ -use crate::Node; use std::collections::VecDeque; struct Node { @@ -6,32 +5,38 @@ struct Node { children: Vec>, } -pub fn breadth_first_search(tree: &Node, searched: &T) -> bool { +#[macro_export] +macro_rules! tree { + ($first:expr $(, $($rest:expr),*)?) => { + $crate::basic_search::Node { + value: $first, + children: vec![$($($rest),*)?], + } + }; +} + +fn breadth_first_search(tree: &Node, searched: &T) -> bool { let mut candidates = VecDeque::new(); candidates.push_back(tree); - loop { - if let Some(candidate) = candidates.pop_front() { - if candidate.value == *searched { - return true; - } - - candidates.extend(candidate.children.iter()); - } else { - break; + while let Some(candidate) = candidates.pop_front() { + if candidate.value == *searched { + return true; } + + candidates.extend(candidate.children.iter()); } false } -pub fn depth_first_search(tree: &Node, searched: &T) -> bool { +fn depth_first_search(tree: &Node, searched: &T) -> bool { if tree.value == *searched { return true; } for child in &tree.children { - if depth_first_search(&child, searched) { + if depth_first_search(child, searched) { return true; } } @@ -41,7 +46,7 @@ pub fn depth_first_search(tree: &Node, searched: &T) -> bool { #[cfg(test)] mod tests { - use crate::basic_search::{self, Node}; + use crate::basic_search; use crate::tree; #[test] diff --git a/src/lib.rs b/src/lib.rs index f724af7..83d2737 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,8 @@ #![allow(dead_code)] +use bumpalo::collections::Vec; +use bumpalo::Bump; + use rand::Rng; mod basic_search; @@ -7,38 +10,92 @@ mod basic_search; pub trait GameState { fn points(&self) -> i32; - fn next(&self) -> Box>; + fn next(&self) -> &[Self] + where + Self: Sized; + + fn is_finished(&self) -> bool; } -impl GameState for i32 { - fn points(&self) -> i32 { - *self - } +type Tree<'tree, S> = Node<'tree, S>; - fn next(&self) -> Box> { - let child_amount = rand::thread_rng().gen_range(0..10); - let mut i = 0; - Box::new(std::iter::from_fn(move || { - if i < child_amount { - Some(rand::random()) - } else { - None - } - })) - } +struct Node<'tree, S> { + state: S, + visited: u64, + won: u64, + parent: Option<&'tree Node<'tree, S>>, + children: Vec<'tree, Node<'tree, S>>, } -#[macro_export] -macro_rules! tree { - ($first:expr $(, $($rest:expr),*)?) => { - $crate::Node { - value: $first, - children: vec![$($($rest),*)?], +impl<'tree, S: GameState> Node<'tree, S> { + fn new(state: S, alloc: &'tree Bump) -> Node { + Node { + state, + visited: 0, + won: 0, + parent: None, + children: Vec::new_in(alloc), } - }; + } + + fn random_child(&self) -> &Self { + &self.children[rand::thread_rng().gen_range(0..self.children.len())] + } + + fn into_child_with_max_score(self) -> Self { + todo!() + } } -mod mcts {} +mod mcts { + use crate::{GameState, Node}; + use bumpalo::Bump; + + const MAX_TRIES: u64 = 10000; + + pub fn find_next_move(current_state: S) -> S { + let alloc = Bump::new(); + + let root_node = Node::new(current_state, &alloc); + + for _ in 0..MAX_TRIES { + let promising_node = select_promising_node(&root_node); + + if !promising_node.state.is_finished() { + expand_node(promising_node); + } + + let node_to_explore = if !node_to_explore.children.is_empty() { + promising_node.random_child() + } else { + promising_node + }; + + let playout_result = simulate_random_playout(node_to_explore); + back_propagation(node_to_explore, playout_result); + } + + let winner_node = root_node.into_child_with_max_score(); + + winner_node.state + } + + fn select_promising_node<'tree, S>(node: &'tree Node<'_, S>) -> &'tree Node<'tree, S> { + todo!() + } + + fn expand_node(node: &Node) { + todo!() + } + + fn simulate_random_playout(node: &Node<'_, S>) -> u64 { + todo!() + } + + fn back_propagation(node: &Node<'_, S>, playout_result: u64) { + todo!() + } +} #[cfg(test)] mod test {