finally, cells!

This commit is contained in:
nora 2022-01-25 15:48:20 +01:00
parent add6c7ec73
commit dff9ca7c2d

View file

@ -7,14 +7,13 @@ pub use mcts::find_next_move;
pub trait GameState { pub trait GameState {
fn points(&self) -> i32; fn points(&self) -> i32;
fn next_states(&self) -> Box<dyn Iterator<Item = Self>>; fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>>;
fn is_finished(&self) -> bool; fn is_finished(&self) -> bool;
} }
mod mcts { mod mcts {
use crate::GameState; use crate::GameState;
use bumpalo::collections::Vec as BumpVec;
use bumpalo::Bump; use bumpalo::Bump;
use rand::Rng; use rand::Rng;
use std::cell::Cell; use std::cell::Cell;
@ -25,7 +24,7 @@ mod mcts {
visited: Cell<u64>, visited: Cell<u64>,
score: Cell<u64>, score: Cell<u64>,
parent: Option<&'tree Node<'tree, S>>, parent: Option<&'tree Node<'tree, S>>,
children: BumpVec<'tree, Node<'tree, S>>, children: Cell<&'tree [Node<'tree, S>]>,
} }
impl<'tree, S> Node<'tree, S> { impl<'tree, S> Node<'tree, S> {
@ -35,19 +34,21 @@ mod mcts {
visited: Cell::new(0), visited: Cell::new(0),
score: Cell::new(0), score: Cell::new(0),
parent: None, parent: None,
children: BumpVec::new_in(alloc), children: Cell::new(alloc.alloc([])),
} }
} }
fn random_child(&self) -> &Self { fn random_child(&self) -> &Self {
let random_index = rand::thread_rng().gen_range(0..self.children.len()); let children = self.children.get();
let random_index = rand::thread_rng().gen_range(0..children.len());
&self.children[random_index] &children[random_index]
} }
fn into_child_with_max_score(self) -> Option<Self> { fn child_with_max_score(&self) -> Option<&Self> {
self.children self.children
.into_iter() .get()
.iter()
.max_by_key(|node| node.score.get()) .max_by_key(|node| node.score.get())
} }
} }
@ -66,7 +67,7 @@ mod mcts {
expand_node(&alloc, promising_node); expand_node(&alloc, promising_node);
} }
if !promising_node.children.is_empty() { if !promising_node.children.get().is_empty() {
let child = promising_node.random_child(); let child = promising_node.random_child();
let playout_result = simulate_random_playout(child); let playout_result = simulate_random_playout(child);
back_propagation(child, playout_result); back_propagation(child, playout_result);
@ -76,42 +77,36 @@ mod mcts {
}; };
} }
let winner_node = root_node.clone().into_child_with_max_score(); let winner_node = root_node.child_with_max_score();
let node = winner_node.unwrap(); let node = winner_node.unwrap();
node.state node.state.clone()
} }
fn select_promising_node<'tree, S>(root_node: &'tree Node<'tree, S>) -> &'tree Node<'tree, S> { fn select_promising_node<'tree, S>(root_node: &'tree Node<'tree, S>) -> &'tree Node<'tree, S> {
let mut node = root_node; let mut node = root_node;
while !node.children.is_empty() { while !node.children.get().is_empty() {
node = uct::find_best_node_with_uct(root_node).unwrap() node = uct::find_best_node_with_uct(root_node).unwrap()
} }
node node
} }
fn expand_node<S: GameState>(alloc: &Bump, node: &Node<'_, S>) { fn expand_node<'tree, S: GameState>(alloc: &'tree Bump, node: &'tree Node<'tree, S>) {
/*
List<State> possibleStates = node.getState().getAllPossibleStates();
possibleStates.forEach(state -> {
Node newNode = new Node(state);
newNode.setParent(node);
newNode.getState().setPlayerNo(node.getState().getOpponent());
node.getChildArray().add(newNode);
});
*/
let possible_states = node.state.next_states(); let possible_states = node.state.next_states();
for state in possible_states {
let child = Node { let new_nodes = possible_states.map(|state| Node {
state, state,
visited: Cell::new(0), visited: Cell::new(0),
score: Cell::new(0), score: Cell::new(0),
parent: Some(node), parent: Some(node),
children: BumpVec::new_in(alloc), children: Cell::new(alloc.alloc([])),
}; });
}
let children = alloc.alloc_slice_fill_iter(new_nodes);
node.children.set(children);
} }
fn simulate_random_playout<S>(_node: &Node<'_, S>) -> u64 { fn simulate_random_playout<S>(_node: &Node<'_, S>) -> u64 {
@ -143,6 +138,7 @@ mod mcts {
let parent_visit_count = node.visited.get(); let parent_visit_count = node.visited.get();
node.children node.children
.get()
.iter() .iter()
.max_by_key(|n| uct(parent_visit_count, n.score.get(), n.score.get())) .max_by_key(|n| uct(parent_visit_count, n.score.get(), n.score.get()))
} }