mirror of
https://github.com/Noratrieb/monte-carlo-tree-search.git
synced 2026-01-14 07:15:07 +01:00
finally, cells!
This commit is contained in:
parent
add6c7ec73
commit
dff9ca7c2d
1 changed file with 27 additions and 31 deletions
58
src/lib.rs
58
src/lib.rs
|
|
@ -7,14 +7,13 @@ pub use mcts::find_next_move;
|
|||
/// A game position that the tree search in `mcts` can explore.
///
/// The search only interacts with a game through this trait.
pub trait GameState {
    /// Returns an integer score for this state.
    ///
    /// NOTE(review): the perspective (which player the points belong to) and
    /// the valid range are not visible from this trait; confirm with the
    /// implementors.
    fn points(&self) -> i32;

    /// Returns every state that can directly follow this one.
    ///
    /// The iterator must report its exact length: `mcts::expand_node` feeds
    /// it to the bump allocator's `alloc_slice_fill_iter`, which sizes the
    /// child slice from that length up front.
    fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>>;

    /// Returns `true` once this state is final.
    ///
    /// NOTE(review): presumably a finished state yields no `next_states`;
    /// confirm with the implementors.
    fn is_finished(&self) -> bool;
}
|
||||
|
||||
mod mcts {
|
||||
use crate::GameState;
|
||||
use bumpalo::collections::Vec as BumpVec;
|
||||
use bumpalo::Bump;
|
||||
use rand::Rng;
|
||||
use std::cell::Cell;
|
||||
|
|
@ -25,7 +24,7 @@ mod mcts {
|
|||
visited: Cell<u64>,
|
||||
score: Cell<u64>,
|
||||
parent: Option<&'tree Node<'tree, S>>,
|
||||
children: BumpVec<'tree, Node<'tree, S>>,
|
||||
children: Cell<&'tree [Node<'tree, S>]>,
|
||||
}
|
||||
|
||||
impl<'tree, S> Node<'tree, S> {
|
||||
|
|
@ -35,19 +34,21 @@ mod mcts {
|
|||
visited: Cell::new(0),
|
||||
score: Cell::new(0),
|
||||
parent: None,
|
||||
children: BumpVec::new_in(alloc),
|
||||
children: Cell::new(alloc.alloc([])),
|
||||
}
|
||||
}
|
||||
|
||||
fn random_child(&self) -> &Self {
|
||||
let random_index = rand::thread_rng().gen_range(0..self.children.len());
|
||||
let children = self.children.get();
|
||||
let random_index = rand::thread_rng().gen_range(0..children.len());
|
||||
|
||||
&self.children[random_index]
|
||||
&children[random_index]
|
||||
}
|
||||
|
||||
fn into_child_with_max_score(self) -> Option<Self> {
|
||||
fn child_with_max_score(&self) -> Option<&Self> {
|
||||
self.children
|
||||
.into_iter()
|
||||
.get()
|
||||
.iter()
|
||||
.max_by_key(|node| node.score.get())
|
||||
}
|
||||
}
|
||||
|
|
@ -66,7 +67,7 @@ mod mcts {
|
|||
expand_node(&alloc, promising_node);
|
||||
}
|
||||
|
||||
if !promising_node.children.is_empty() {
|
||||
if !promising_node.children.get().is_empty() {
|
||||
let child = promising_node.random_child();
|
||||
let playout_result = simulate_random_playout(child);
|
||||
back_propagation(child, playout_result);
|
||||
|
|
@ -76,42 +77,36 @@ mod mcts {
|
|||
};
|
||||
}
|
||||
|
||||
let winner_node = root_node.clone().into_child_with_max_score();
|
||||
let winner_node = root_node.child_with_max_score();
|
||||
|
||||
let node = winner_node.unwrap();
|
||||
node.state
|
||||
node.state.clone()
|
||||
}
|
||||
|
||||
fn select_promising_node<'tree, S>(root_node: &'tree Node<'tree, S>) -> &'tree Node<'tree, S> {
|
||||
let mut node = root_node;
|
||||
|
||||
while !node.children.is_empty() {
|
||||
while !node.children.get().is_empty() {
|
||||
node = uct::find_best_node_with_uct(root_node).unwrap()
|
||||
}
|
||||
|
||||
node
|
||||
}
|
||||
|
||||
fn expand_node<S: GameState>(alloc: &Bump, node: &Node<'_, S>) {
|
||||
/*
|
||||
List<State> possibleStates = node.getState().getAllPossibleStates();
|
||||
possibleStates.forEach(state -> {
|
||||
Node newNode = new Node(state);
|
||||
newNode.setParent(node);
|
||||
newNode.getState().setPlayerNo(node.getState().getOpponent());
|
||||
node.getChildArray().add(newNode);
|
||||
});
|
||||
*/
|
||||
fn expand_node<'tree, S: GameState>(alloc: &'tree Bump, node: &'tree Node<'tree, S>) {
|
||||
let possible_states = node.state.next_states();
|
||||
for state in possible_states {
|
||||
let child = Node {
|
||||
state,
|
||||
visited: Cell::new(0),
|
||||
score: Cell::new(0),
|
||||
parent: Some(node),
|
||||
children: BumpVec::new_in(alloc),
|
||||
};
|
||||
}
|
||||
|
||||
let new_nodes = possible_states.map(|state| Node {
|
||||
state,
|
||||
visited: Cell::new(0),
|
||||
score: Cell::new(0),
|
||||
parent: Some(node),
|
||||
children: Cell::new(alloc.alloc([])),
|
||||
});
|
||||
|
||||
let children = alloc.alloc_slice_fill_iter(new_nodes);
|
||||
|
||||
node.children.set(children);
|
||||
}
|
||||
|
||||
fn simulate_random_playout<S>(_node: &Node<'_, S>) -> u64 {
|
||||
|
|
@ -143,6 +138,7 @@ mod mcts {
|
|||
let parent_visit_count = node.visited.get();
|
||||
|
||||
node.children
|
||||
.get()
|
||||
.iter()
|
||||
.max_by_key(|n| uct(parent_visit_count, n.score.get(), n.score.get()))
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue