From e9ad50b96e65a9e77799f508eb3616fd7f918f3c Mon Sep 17 00:00:00 2001 From: Nilstrieb <48135649+Nilstrieb@users.noreply.github.com> Date: Mon, 24 Jan 2022 22:31:20 +0100 Subject: [PATCH] RefNode --- src/lib.rs | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d6f2f31..b5d1260 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,11 +27,13 @@ struct Node<'tree, S> { state: S, visited: u64, score: u64, - parent: Option<&'tree RefCell>>, - children: Vec<'tree, RefCell>>, + parent: Option<&'tree RefNode<'tree, S>>, + children: Vec<'tree, RefNode<'tree, S>>, } -impl<'tree, S: GameState> Node<'tree, S> { +type RefNode<'tree, S> = RefCell>; + +impl<'tree, S> Node<'tree, S> { fn new(state: S, alloc: &'tree Bump) -> Node { Node { state, @@ -54,7 +56,7 @@ impl<'tree, S: GameState> Node<'tree, S> { } mod mcts { - use crate::{GameState, Node}; + use crate::{GameState, Node, RefNode}; use bumpalo::Bump; use std::cell::RefCell; @@ -90,8 +92,8 @@ mod mcts { } fn select_promising_node<'tree, S>( - root_node: &'tree RefCell>, - ) -> &'tree RefCell> { + root_node: &'tree RefNode<'tree, S>, + ) -> &'tree RefNode<'tree, S> { let mut node = root_node; let borrowed_node = node.borrow(); @@ -102,21 +104,21 @@ mod mcts { node } - fn expand_node(_node: &Node) { + fn expand_node(_node: &RefNode) { todo!("next") } - fn simulate_random_playout(_node: &Node<'_, S>) -> u64 { + fn simulate_random_playout(_node: &RefNode<'_, S>) -> u64 { todo!() } - fn back_propagation(_node: &Node<'_, S>, _playout_result: u64) { + fn back_propagation(_node: &RefNode<'_, S>, _playout_result: u64) { todo!() } mod uct { + use crate::mcts::RefNode; use crate::Node; - use std::cell::RefCell; pub fn uct(total_visit: u64, win_score: u64, node_visit: u64) -> u64 { if node_visit == 0 { @@ -131,8 +133,8 @@ mod mcts { } pub fn find_best_node_with_uct<'cell, 'tree, S>( - node: &'tree Node<'tree, S>, - ) -> Option<&'tree RefCell>> { + node: &'tree RefNode<'tree, S>, + ) -> Option<&'tree RefNode<'tree, S>> { let parent_visit_count = node.visited; node.children.iter().max_by_key(|n| {