mirror of
https://github.com/Noratrieb/monte-carlo-tree-search.git
synced 2026-01-14 15:25:09 +01:00
boilerplate
This commit is contained in:
parent
68abd05376
commit
bdd90a41c6
3 changed files with 101 additions and 38 deletions
|
|
@ -6,4 +6,5 @@ edition = "2021"
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
bumpalo = { version = "3.9.1", features = ["collections"] }
|
||||
rand = "0.8.4"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
use crate::Node;
|
||||
use std::collections::VecDeque;
|
||||
|
||||
struct Node<T> {
|
||||
|
|
@ -6,32 +5,38 @@ struct Node<T> {
|
|||
children: Vec<Node<T>>,
|
||||
}
|
||||
|
||||
pub fn breadth_first_search<T: Eq>(tree: &Node<T>, searched: &T) -> bool {
|
||||
#[macro_export]
|
||||
macro_rules! tree {
|
||||
($first:expr $(, $($rest:expr),*)?) => {
|
||||
$crate::basic_search::Node {
|
||||
value: $first,
|
||||
children: vec![$($($rest),*)?],
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn breadth_first_search<T: Eq>(tree: &Node<T>, searched: &T) -> bool {
|
||||
let mut candidates = VecDeque::new();
|
||||
candidates.push_back(tree);
|
||||
|
||||
loop {
|
||||
if let Some(candidate) = candidates.pop_front() {
|
||||
while let Some(candidate) = candidates.pop_front() {
|
||||
if candidate.value == *searched {
|
||||
return true;
|
||||
}
|
||||
|
||||
candidates.extend(candidate.children.iter());
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
pub fn depth_first_search<T: Eq>(tree: &Node<T>, searched: &T) -> bool {
|
||||
fn depth_first_search<T: Eq>(tree: &Node<T>, searched: &T) -> bool {
|
||||
if tree.value == *searched {
|
||||
return true;
|
||||
}
|
||||
|
||||
for child in &tree.children {
|
||||
if depth_first_search(&child, searched) {
|
||||
if depth_first_search(child, searched) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
@ -41,7 +46,7 @@ pub fn depth_first_search<T: Eq>(tree: &Node<T>, searched: &T) -> bool {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::basic_search::{self, Node};
|
||||
use crate::basic_search;
|
||||
use crate::tree;
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
105
src/lib.rs
105
src/lib.rs
|
|
@ -1,5 +1,8 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use bumpalo::collections::Vec;
|
||||
use bumpalo::Bump;
|
||||
|
||||
use rand::Rng;
|
||||
|
||||
mod basic_search;
|
||||
|
|
@ -7,38 +10,92 @@ mod basic_search;
|
|||
pub trait GameState {
|
||||
fn points(&self) -> i32;
|
||||
|
||||
fn next(&self) -> Box<dyn Iterator<Item = Self>>;
|
||||
fn next(&self) -> &[Self]
|
||||
where
|
||||
Self: Sized;
|
||||
|
||||
fn is_finished(&self) -> bool;
|
||||
}
|
||||
|
||||
impl GameState for i32 {
|
||||
fn points(&self) -> i32 {
|
||||
*self
|
||||
type Tree<'tree, S> = Node<'tree, S>;
|
||||
|
||||
struct Node<'tree, S> {
|
||||
state: S,
|
||||
visited: u64,
|
||||
won: u64,
|
||||
parent: Option<&'tree Node<'tree, S>>,
|
||||
children: Vec<'tree, Node<'tree, S>>,
|
||||
}
|
||||
|
||||
fn next(&self) -> Box<dyn Iterator<Item = Self>> {
|
||||
let child_amount = rand::thread_rng().gen_range(0..10);
|
||||
let mut i = 0;
|
||||
Box::new(std::iter::from_fn(move || {
|
||||
if i < child_amount {
|
||||
Some(rand::random())
|
||||
impl<'tree, S: GameState> Node<'tree, S> {
|
||||
fn new(state: S, alloc: &'tree Bump) -> Node<S> {
|
||||
Node {
|
||||
state,
|
||||
visited: 0,
|
||||
won: 0,
|
||||
parent: None,
|
||||
children: Vec::new_in(alloc),
|
||||
}
|
||||
}
|
||||
|
||||
fn random_child(&self) -> &Self {
|
||||
&self.children[rand::thread_rng().gen_range(0..self.children.len())]
|
||||
}
|
||||
|
||||
fn into_child_with_max_score(self) -> Self {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
mod mcts {
|
||||
use crate::{GameState, Node};
|
||||
use bumpalo::Bump;
|
||||
|
||||
const MAX_TRIES: u64 = 10000;
|
||||
|
||||
pub fn find_next_move<S: GameState>(current_state: S) -> S {
|
||||
let alloc = Bump::new();
|
||||
|
||||
let root_node = Node::new(current_state, &alloc);
|
||||
|
||||
for _ in 0..MAX_TRIES {
|
||||
let promising_node = select_promising_node(&root_node);
|
||||
|
||||
if !promising_node.state.is_finished() {
|
||||
expand_node(promising_node);
|
||||
}
|
||||
|
||||
let node_to_explore = if !node_to_explore.children.is_empty() {
|
||||
promising_node.random_child()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! tree {
|
||||
($first:expr $(, $($rest:expr),*)?) => {
|
||||
$crate::Node {
|
||||
value: $first,
|
||||
children: vec![$($($rest),*)?],
|
||||
}
|
||||
promising_node
|
||||
};
|
||||
|
||||
let playout_result = simulate_random_playout(node_to_explore);
|
||||
back_propagation(node_to_explore, playout_result);
|
||||
}
|
||||
|
||||
mod mcts {}
|
||||
let winner_node = root_node.into_child_with_max_score();
|
||||
|
||||
winner_node.state
|
||||
}
|
||||
|
||||
fn select_promising_node<'tree, S>(node: &'tree Node<'_, S>) -> &'tree Node<'tree, S> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn expand_node<S>(node: &Node<S>) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn simulate_random_playout<S>(node: &Node<'_, S>) -> u64 {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn back_propagation<S>(node: &Node<'_, S>, playout_result: u64) {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue