finally, cells!

This commit is contained in:
nora 2022-01-25 15:48:20 +01:00
parent add6c7ec73
commit dff9ca7c2d

View file

@ -7,14 +7,13 @@ pub use mcts::find_next_move;
pub trait GameState {
fn points(&self) -> i32;
fn next_states(&self) -> Box<dyn Iterator<Item = Self>>;
fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>>;
fn is_finished(&self) -> bool;
}
mod mcts {
use crate::GameState;
use bumpalo::collections::Vec as BumpVec;
use bumpalo::Bump;
use rand::Rng;
use std::cell::Cell;
@ -25,7 +24,7 @@ mod mcts {
visited: Cell<u64>,
score: Cell<u64>,
parent: Option<&'tree Node<'tree, S>>,
children: BumpVec<'tree, Node<'tree, S>>,
children: Cell<&'tree [Node<'tree, S>]>,
}
impl<'tree, S> Node<'tree, S> {
@ -35,19 +34,21 @@ mod mcts {
visited: Cell::new(0),
score: Cell::new(0),
parent: None,
children: BumpVec::new_in(alloc),
children: Cell::new(alloc.alloc([])),
}
}
fn random_child(&self) -> &Self {
let random_index = rand::thread_rng().gen_range(0..self.children.len());
let children = self.children.get();
let random_index = rand::thread_rng().gen_range(0..children.len());
&self.children[random_index]
&children[random_index]
}
fn into_child_with_max_score(self) -> Option<Self> {
fn child_with_max_score(&self) -> Option<&Self> {
self.children
.into_iter()
.get()
.iter()
.max_by_key(|node| node.score.get())
}
}
@ -66,7 +67,7 @@ mod mcts {
expand_node(&alloc, promising_node);
}
if !promising_node.children.is_empty() {
if !promising_node.children.get().is_empty() {
let child = promising_node.random_child();
let playout_result = simulate_random_playout(child);
back_propagation(child, playout_result);
@ -76,42 +77,36 @@ mod mcts {
};
}
let winner_node = root_node.clone().into_child_with_max_score();
let winner_node = root_node.child_with_max_score();
let node = winner_node.unwrap();
node.state
node.state.clone()
}
fn select_promising_node<'tree, S>(root_node: &'tree Node<'tree, S>) -> &'tree Node<'tree, S> {
let mut node = root_node;
while !node.children.is_empty() {
while !node.children.get().is_empty() {
node = uct::find_best_node_with_uct(root_node).unwrap()
}
node
}
fn expand_node<S: GameState>(alloc: &Bump, node: &Node<'_, S>) {
/*
List<State> possibleStates = node.getState().getAllPossibleStates();
possibleStates.forEach(state -> {
Node newNode = new Node(state);
newNode.setParent(node);
newNode.getState().setPlayerNo(node.getState().getOpponent());
node.getChildArray().add(newNode);
});
*/
fn expand_node<'tree, S: GameState>(alloc: &'tree Bump, node: &'tree Node<'tree, S>) {
let possible_states = node.state.next_states();
for state in possible_states {
let child = Node {
let new_nodes = possible_states.map(|state| Node {
state,
visited: Cell::new(0),
score: Cell::new(0),
parent: Some(node),
children: BumpVec::new_in(alloc),
};
}
children: Cell::new(alloc.alloc([])),
});
let children = alloc.alloc_slice_fill_iter(new_nodes);
node.children.set(children);
}
fn simulate_random_playout<S>(_node: &Node<'_, S>) -> u64 {
@ -143,6 +138,7 @@ mod mcts {
let parent_visit_count = node.visited.get();
node.children
.get()
.iter()
.max_by_key(|n| uct(parent_visit_count, n.score.get(), n.score.get()))
}