mirror of
https://github.com/Noratrieb/monte-carlo-tree-search.git
synced 2026-01-14 15:25:09 +01:00
finally, cells!
This commit is contained in:
parent
add6c7ec73
commit
dff9ca7c2d
1 changed files with 27 additions and 31 deletions
50
src/lib.rs
50
src/lib.rs
|
|
@ -7,14 +7,13 @@ pub use mcts::find_next_move;
|
||||||
pub trait GameState {
|
pub trait GameState {
|
||||||
fn points(&self) -> i32;
|
fn points(&self) -> i32;
|
||||||
|
|
||||||
fn next_states(&self) -> Box<dyn Iterator<Item = Self>>;
|
fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>>;
|
||||||
|
|
||||||
fn is_finished(&self) -> bool;
|
fn is_finished(&self) -> bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
mod mcts {
|
mod mcts {
|
||||||
use crate::GameState;
|
use crate::GameState;
|
||||||
use bumpalo::collections::Vec as BumpVec;
|
|
||||||
use bumpalo::Bump;
|
use bumpalo::Bump;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use std::cell::Cell;
|
use std::cell::Cell;
|
||||||
|
|
@ -25,7 +24,7 @@ mod mcts {
|
||||||
visited: Cell<u64>,
|
visited: Cell<u64>,
|
||||||
score: Cell<u64>,
|
score: Cell<u64>,
|
||||||
parent: Option<&'tree Node<'tree, S>>,
|
parent: Option<&'tree Node<'tree, S>>,
|
||||||
children: BumpVec<'tree, Node<'tree, S>>,
|
children: Cell<&'tree [Node<'tree, S>]>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'tree, S> Node<'tree, S> {
|
impl<'tree, S> Node<'tree, S> {
|
||||||
|
|
@ -35,19 +34,21 @@ mod mcts {
|
||||||
visited: Cell::new(0),
|
visited: Cell::new(0),
|
||||||
score: Cell::new(0),
|
score: Cell::new(0),
|
||||||
parent: None,
|
parent: None,
|
||||||
children: BumpVec::new_in(alloc),
|
children: Cell::new(alloc.alloc([])),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn random_child(&self) -> &Self {
|
fn random_child(&self) -> &Self {
|
||||||
let random_index = rand::thread_rng().gen_range(0..self.children.len());
|
let children = self.children.get();
|
||||||
|
let random_index = rand::thread_rng().gen_range(0..children.len());
|
||||||
|
|
||||||
&self.children[random_index]
|
&children[random_index]
|
||||||
}
|
}
|
||||||
|
|
||||||
fn into_child_with_max_score(self) -> Option<Self> {
|
fn child_with_max_score(&self) -> Option<&Self> {
|
||||||
self.children
|
self.children
|
||||||
.into_iter()
|
.get()
|
||||||
|
.iter()
|
||||||
.max_by_key(|node| node.score.get())
|
.max_by_key(|node| node.score.get())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -66,7 +67,7 @@ mod mcts {
|
||||||
expand_node(&alloc, promising_node);
|
expand_node(&alloc, promising_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
if !promising_node.children.is_empty() {
|
if !promising_node.children.get().is_empty() {
|
||||||
let child = promising_node.random_child();
|
let child = promising_node.random_child();
|
||||||
let playout_result = simulate_random_playout(child);
|
let playout_result = simulate_random_playout(child);
|
||||||
back_propagation(child, playout_result);
|
back_propagation(child, playout_result);
|
||||||
|
|
@ -76,42 +77,36 @@ mod mcts {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
let winner_node = root_node.clone().into_child_with_max_score();
|
let winner_node = root_node.child_with_max_score();
|
||||||
|
|
||||||
let node = winner_node.unwrap();
|
let node = winner_node.unwrap();
|
||||||
node.state
|
node.state.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn select_promising_node<'tree, S>(root_node: &'tree Node<'tree, S>) -> &'tree Node<'tree, S> {
|
fn select_promising_node<'tree, S>(root_node: &'tree Node<'tree, S>) -> &'tree Node<'tree, S> {
|
||||||
let mut node = root_node;
|
let mut node = root_node;
|
||||||
|
|
||||||
while !node.children.is_empty() {
|
while !node.children.get().is_empty() {
|
||||||
node = uct::find_best_node_with_uct(root_node).unwrap()
|
node = uct::find_best_node_with_uct(root_node).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
node
|
node
|
||||||
}
|
}
|
||||||
|
|
||||||
fn expand_node<S: GameState>(alloc: &Bump, node: &Node<'_, S>) {
|
fn expand_node<'tree, S: GameState>(alloc: &'tree Bump, node: &'tree Node<'tree, S>) {
|
||||||
/*
|
|
||||||
List<State> possibleStates = node.getState().getAllPossibleStates();
|
|
||||||
possibleStates.forEach(state -> {
|
|
||||||
Node newNode = new Node(state);
|
|
||||||
newNode.setParent(node);
|
|
||||||
newNode.getState().setPlayerNo(node.getState().getOpponent());
|
|
||||||
node.getChildArray().add(newNode);
|
|
||||||
});
|
|
||||||
*/
|
|
||||||
let possible_states = node.state.next_states();
|
let possible_states = node.state.next_states();
|
||||||
for state in possible_states {
|
|
||||||
let child = Node {
|
let new_nodes = possible_states.map(|state| Node {
|
||||||
state,
|
state,
|
||||||
visited: Cell::new(0),
|
visited: Cell::new(0),
|
||||||
score: Cell::new(0),
|
score: Cell::new(0),
|
||||||
parent: Some(node),
|
parent: Some(node),
|
||||||
children: BumpVec::new_in(alloc),
|
children: Cell::new(alloc.alloc([])),
|
||||||
};
|
});
|
||||||
}
|
|
||||||
|
let children = alloc.alloc_slice_fill_iter(new_nodes);
|
||||||
|
|
||||||
|
node.children.set(children);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn simulate_random_playout<S>(_node: &Node<'_, S>) -> u64 {
|
fn simulate_random_playout<S>(_node: &Node<'_, S>) -> u64 {
|
||||||
|
|
@ -143,6 +138,7 @@ mod mcts {
|
||||||
let parent_visit_count = node.visited.get();
|
let parent_visit_count = node.visited.get();
|
||||||
|
|
||||||
node.children
|
node.children
|
||||||
|
.get()
|
||||||
.iter()
|
.iter()
|
||||||
.max_by_key(|n| uct(parent_visit_count, n.score.get(), n.score.get()))
|
.max_by_key(|n| uct(parent_visit_count, n.score.get(), n.score.get()))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue