From a9171b7724d2bdaf06573a1b7b2115686f8d99d6 Mon Sep 17 00:00:00 2001 From: Nilstrieb <48135649+Nilstrieb@users.noreply.github.com> Date: Mon, 24 Jan 2022 22:41:10 +0100 Subject: [PATCH] pain --- src/lib.rs | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b5d1260..1fee0e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,7 @@ use bumpalo::collections::Vec; use bumpalo::Bump; use std::cell::RefCell; -use rand::Rng; +use rand::{random, Rng}; mod basic_search; @@ -45,7 +45,9 @@ impl<'tree, S> Node<'tree, S> { } fn random_child(&self) -> &RefCell { - &self.children[rand::thread_rng().gen_range(0..self.children.len())] + let random_index = rand::thread_rng().gen_range(0..self.children.len()); + + &self.children[random_index] } fn into_child_with_max_score(self) -> Option> { @@ -58,7 +60,7 @@ impl<'tree, S> Node<'tree, S> { mod mcts { use crate::{GameState, Node, RefNode}; use bumpalo::Bump; - use std::cell::RefCell; + use std::cell::{Ref, RefCell}; const MAX_TRIES: u64 = 10000; @@ -68,15 +70,15 @@ mod mcts { let root_node = alloc.alloc(RefCell::new(Node::new(current_state, &alloc))); for _ in 0..MAX_TRIES { - let promising_node_cell = select_promising_node(&root_node); - let promising_node = promising_node_cell.borrow(); + let promising_node = select_promising_node(&root_node); - if !promising_node.state.is_finished() { + if !promising_node.borrow().state.is_finished() { expand_node(&promising_node); } - if !promising_node.children.is_empty() { - let child = promising_node.random_child().borrow(); + if !promising_node.borrow().children.is_empty() { + let promising_node = promising_node.borrow(); + let child = promising_node.random_child(); let playout_result = simulate_random_playout(&child); back_propagation(&child, playout_result); } else { @@ -91,14 +93,13 @@ mod mcts { state } - fn select_promising_node<'tree, S>( - root_node: &'tree RefNode<'tree, S>, - ) -> &'tree RefNode<'tree, S> { + fn select_promising_node<'cell, 'tree, S>( + root_node: &'cell RefCell>, + ) -> &'cell RefNode<'tree, S> { let mut node = root_node; - let borrowed_node = node.borrow(); - while borrowed_node.children.len() != 0 { - node = uct::find_best_node_with_uct(&borrowed_node).unwrap() + while node.borrow().children.len() != 0 { + node = uct::find_best_node_with_uct(&root_node).unwrap() } node @@ -119,6 +120,7 @@ mod mcts { mod uct { use crate::mcts::RefNode; use crate::Node; + use std::cell::Ref; pub fn uct(total_visit: u64, win_score: u64, node_visit: u64) -> u64 { if node_visit == 0 { @@ -133,8 +135,8 @@ mod mcts { } pub fn find_best_node_with_uct<'cell, 'tree, S>( - node: &'tree RefNode<'tree, S>, - ) -> Option<&'tree RefNode<'tree, S>> { + node: Ref<'cell, Node<'tree, S>>, + ) -> Option<&'cell RefNode<'tree, S>> { let parent_visit_count = node.visited; node.children.iter().max_by_key(|n| {