From 11a94fefdf478f0481ec038c7caa8f0d73b49249 Mon Sep 17 00:00:00 2001 From: Nilstrieb <48135649+Nilstrieb@users.noreply.github.com> Date: Mon, 24 Jan 2022 22:04:41 +0100 Subject: [PATCH] why clion --- src/lib.rs | 65 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 0741f1e..c592c10 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ use bumpalo::collections::Vec; use bumpalo::Bump; +use std::cell::RefCell; use rand::Rng; @@ -19,14 +20,14 @@ pub trait GameState { fn is_finished(&self) -> bool; } -type Tree<'tree, S> = Node<'tree, S>; +type Tree<'tree, S> = RefCell>; struct Node<'tree, S> { state: S, visited: u64, - won: u64, - parent: Option<&'tree Node<'tree, S>>, - children: Vec<'tree, Node<'tree, S>>, + score: u64, + parent: Option<&'tree RefCell>>, + children: Vec<'tree, RefCell>>, } impl<'tree, S: GameState> Node<'tree, S> { @@ -34,56 +35,70 @@ impl<'tree, S: GameState> Node<'tree, S> { Node { state, visited: 0, - won: 0, + score: 0, parent: None, children: Vec::new_in(alloc), } } - fn random_child(&self) -> &Self { + fn random_child(&self) -> &RefCell { &self.children[rand::thread_rng().gen_range(0..self.children.len())] } - fn into_child_with_max_score(self) -> Self { - todo!() + fn into_child_with_max_score(self) -> Option> { + self.children + .into_iter() + .max_by_key(|node| node.borrow().score) } } mod mcts { use crate::{GameState, Node}; use bumpalo::Bump; + use std::cell::RefCell; const MAX_TRIES: u64 = 10000; pub fn find_next_move(current_state: S) -> S { let alloc = Bump::new(); - let root_node = Node::new(current_state, &alloc); + let root_node = RefCell::new(Node::new(current_state, &alloc)); for _ in 0..MAX_TRIES { - let promising_node = select_promising_node(&root_node); + let promising_node_cell = select_promising_node(&root_node); + let promising_node = promising_node_cell.borrow(); if !promising_node.state.is_finished() { - expand_node(promising_node); + expand_node(&promising_node); } - let node_to_explore = if !node_to_explore.children.is_empty() { + let node_to_explore = if !promising_node.children.is_empty() { promising_node.random_child() } else { - promising_node + promising_node_cell }; - let playout_result = simulate_random_playout(node_to_explore); - back_propagation(node_to_explore, playout_result); + let playout_result = simulate_random_playout(&node_to_explore.borrow()); + back_propagation(&node_to_explore.borrow(), playout_result); } - let winner_node = root_node.into_child_with_max_score(); + let winner_node = root_node.into_inner().into_child_with_max_score(); - winner_node.state + let state = winner_node.unwrap().into_inner().state; + state } - fn select_promising_node<'tree, S>(node: &'tree Node<'_, S>) -> &'tree Node<'tree, S> { - todo!() + fn select_promising_node<'tree, S>( + root_node: &'tree RefCell>, + ) -> &'tree RefCell> { + let mut node = root_node; + + let borrowed_node = node.borrow(); + while borrowed_node.children.len() != 0 { + node = uct::find_best_node_with_uct(&borrowed_node).unwrap() + } + + node } fn expand_node(node: &Node) { @@ -100,6 +115,7 @@ mod mcts { mod uct { use crate::Node; + use std::cell::RefCell; pub fn uct(total_visit: u64, win_score: u64, node_visit: u64) -> u64 { if node_visit == 0 { @@ -113,12 +129,15 @@ mod mcts { num as u64 } - pub fn find_best_node_with_uct(node: &Node) -> Option<&Node> { + pub fn find_best_node_with_uct<'tree, S>( + node: &'tree Node<'tree, S>, + ) -> Option<&'tree RefCell>> { let parent_visit_count = node.visited; - node.children - .iter() - .max_by_key(|n| uct(parent_visit_count, n.won, n.visited)) + node.children.iter().max_by_key(|n| { + let n = n.borrow(); + uct(parent_visit_count, n.score, n.visited) + }) } } }