Move the MCTS tree types (`Tree`, `Node`, `RefNode`) into `mod mcts` and store
the per-node counters in `Cell`s, so the search helpers take plain `&Node`
references instead of `RefCell` borrows.

Review fixes folded into the added (+) lines of this patch:

  * select_promising_node: descend from the current `node` rather than from
    `root_node`. Passing the root on every iteration re-selected the same
    child of the root each time, so the walk never went deeper than one level
    and never terminated once that child had children of its own.
  * uct::find_best_node_with_uct: pass the child's `visited` count as the
    third argument of `uct` (its `node_visit` parameter); the score was being
    passed twice.

NOTE(review): several generic argument lists in this copy of the patch look
truncated (e.g. `RefCell>`, `Cell,`, `-> Option {`, and the missing type
parameter list on `find_next_move`). They are left exactly as found; confirm
them against the repository before applying.

diff --git a/src/lib.rs b/src/lib.rs
index 1fee0e2..8703cf6 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,12 +2,6 @@
 
 #![allow(dead_code)]
 
-use bumpalo::collections::Vec;
-use bumpalo::Bump;
-use std::cell::RefCell;
-
-use rand::{random, Rng};
-
 mod basic_search;
 
 pub trait GameState {
@@ -20,107 +14,104 @@ pub trait GameState {
     fn is_finished(&self) -> bool;
 }
 
-type Tree<'tree, S> = RefCell>;
+mod mcts {
+    use crate::GameState;
+    use bumpalo::collections::Vec;
+    use bumpalo::Bump;
+    use rand::Rng;
+    use std::cell::{Cell, RefCell};
 
-#[derive(Clone)]
-struct Node<'tree, S> {
-    state: S,
-    visited: u64,
-    score: u64,
-    parent: Option<&'tree RefNode<'tree, S>>,
-    children: Vec<'tree, RefNode<'tree, S>>,
-}
+    type Tree<'tree, S> = RefCell>;
 
-type RefNode<'tree, S> = RefCell>;
+    #[derive(Clone)]
+    struct Node<'tree, S> {
+        state: S,
+        visited: Cell,
+        score: Cell,
+        parent: Option<&'tree Node<'tree, S>>,
+        children: Vec<'tree, Node<'tree, S>>,
+    }
 
-impl<'tree, S> Node<'tree, S> {
-    fn new(state: S, alloc: &'tree Bump) -> Node {
-        Node {
-            state,
-            visited: 0,
-            score: 0,
-            parent: None,
-            children: Vec::new_in(alloc),
+    type RefNode<'tree, S> = RefCell>;
+
+    impl<'tree, S> Node<'tree, S> {
+        fn new(state: S, alloc: &'tree Bump) -> Node {
+            Node {
+                state,
+                visited: Cell::new(0),
+                score: Cell::new(0),
+                parent: None,
+                children: Vec::new_in(alloc),
+            }
+        }
+
+        fn random_child(&self) -> &Self {
+            let random_index = rand::thread_rng().gen_range(0..self.children.len());
+
+            &self.children[random_index]
+        }
+
+        fn into_child_with_max_score(self) -> Option {
+            self.children
+                .into_iter()
+                .max_by_key(|node| node.score.get())
         }
     }
 
-    fn random_child(&self) -> &RefCell {
-        let random_index = rand::thread_rng().gen_range(0..self.children.len());
-
-        &self.children[random_index]
-    }
-
-    fn into_child_with_max_score(self) -> Option> {
-        self.children
-            .into_iter()
-            .max_by_key(|node| node.borrow().score)
-    }
-}
-
-mod mcts {
-    use crate::{GameState, Node, RefNode};
-    use bumpalo::Bump;
-    use std::cell::{Ref, RefCell};
-
     const MAX_TRIES: u64 = 10000;
 
     pub fn find_next_move(current_state: S) -> S {
         let alloc = Bump::new();
 
-        let root_node = alloc.alloc(RefCell::new(Node::new(current_state, &alloc)));
+        let root_node = alloc.alloc(Node::new(current_state, &alloc));
 
         for _ in 0..MAX_TRIES {
-            let promising_node = select_promising_node(&root_node);
+            let promising_node = select_promising_node(root_node);
 
-            if !promising_node.borrow().state.is_finished() {
-                expand_node(&promising_node);
+            if !promising_node.state.is_finished() {
+                expand_node(promising_node);
             }
 
-            if !promising_node.borrow().children.is_empty() {
-                let promising_node = promising_node.borrow();
+            if !promising_node.children.is_empty() {
                 let child = promising_node.random_child();
-                let playout_result = simulate_random_playout(&child);
-                back_propagation(&child, playout_result);
+                let playout_result = simulate_random_playout(child);
+                back_propagation(child, playout_result);
             } else {
-                let playout_result = simulate_random_playout(&promising_node);
-                back_propagation(&promising_node, playout_result);
+                let playout_result = simulate_random_playout(promising_node);
+                back_propagation(promising_node, playout_result);
             };
         }
 
-        let winner_node = root_node.clone().into_inner().into_child_with_max_score();
+        let winner_node = root_node.clone().into_child_with_max_score();
 
-        let state = winner_node.unwrap().into_inner().state;
-        state
+        let node = winner_node.unwrap();
+        node.state
     }
 
-    fn select_promising_node<'cell, 'tree, S>(
-        root_node: &'cell RefCell>,
-    ) -> &'cell RefNode<'tree, S> {
+    fn select_promising_node<'tree, S>(root_node: &'tree Node<'tree, S>) -> &'tree Node<'tree, S> {
         let mut node = root_node;
 
-        while node.borrow().children.len() != 0 {
-            node = uct::find_best_node_with_uct(&root_node).unwrap()
+        while !node.children.is_empty() {
+            node = uct::find_best_node_with_uct(node).unwrap()
         }
 
         node
     }
 
-    fn expand_node(_node: &RefNode) {
+    fn expand_node(_node: &Node<'_, S>) {
         todo!("next")
     }
 
-    fn simulate_random_playout(_node: &RefNode<'_, S>) -> u64 {
+    fn simulate_random_playout(_node: &Node<'_, S>) -> u64 {
         todo!()
     }
 
-    fn back_propagation(_node: &RefNode<'_, S>, _playout_result: u64) {
+    fn back_propagation(_node: &Node<'_, S>, _playout_result: u64) {
         todo!()
     }
 
     mod uct {
-        use crate::mcts::RefNode;
-        use crate::Node;
-        use std::cell::Ref;
+        use crate::mcts::Node;
 
         pub fn uct(total_visit: u64, win_score: u64, node_visit: u64) -> u64 {
             if node_visit == 0 {
@@ -134,15 +125,14 @@ mod mcts {
             num as u64
         }
 
-        pub fn find_best_node_with_uct<'cell, 'tree, S>(
-            node: Ref<'cell, Node<'tree, S>>,
-        ) -> Option<&'cell RefNode<'tree, S>> {
-            let parent_visit_count = node.visited;
+        pub(super) fn find_best_node_with_uct<'tree, S>(
+            node: &'tree Node<'tree, S>,
+        ) -> Option<&'tree Node<'tree, S>> {
+            let parent_visit_count = node.visited.get();
 
-            node.children.iter().max_by_key(|n| {
-                let n = n.borrow();
-                uct(parent_visit_count, n.score, n.visited)
-            })
+            node.children
+                .iter()
+                .max_by_key(|n| uct(parent_visit_count, n.score.get(), n.visited.get()))
         }
     }
 }