This commit is contained in:
nora 2022-01-25 14:36:33 +01:00
parent 236827dfe5
commit add6c7ec73
2 changed files with 30 additions and 16 deletions

View file

@ -1,3 +1,5 @@
#![allow(dead_code)]
use std::collections::VecDeque; use std::collections::VecDeque;
struct Node<T> { struct Node<T> {

View file

@ -1,27 +1,23 @@
//! https://www.baeldung.com/java-monte-carlo-tree-search //! https://www.baeldung.com/java-monte-carlo-tree-search
#![allow(dead_code)]
mod basic_search; mod basic_search;
pub use mcts::find_next_move;
pub trait GameState { pub trait GameState {
fn points(&self) -> i32; fn points(&self) -> i32;
fn next(&self) -> &[Self] fn next_states(&self) -> Box<dyn Iterator<Item = Self>>;
where
Self: Sized;
fn is_finished(&self) -> bool; fn is_finished(&self) -> bool;
} }
mod mcts { mod mcts {
use crate::GameState; use crate::GameState;
use bumpalo::collections::Vec; use bumpalo::collections::Vec as BumpVec;
use bumpalo::Bump; use bumpalo::Bump;
use rand::Rng; use rand::Rng;
use std::cell::{Cell, RefCell}; use std::cell::Cell;
type Tree<'tree, S> = RefCell<Node<'tree, S>>;
#[derive(Clone)] #[derive(Clone)]
struct Node<'tree, S> { struct Node<'tree, S> {
@ -29,11 +25,9 @@ mod mcts {
visited: Cell<u64>, visited: Cell<u64>,
score: Cell<u64>, score: Cell<u64>,
parent: Option<&'tree Node<'tree, S>>, parent: Option<&'tree Node<'tree, S>>,
children: Vec<'tree, Node<'tree, S>>, children: BumpVec<'tree, Node<'tree, S>>,
} }
type RefNode<'tree, S> = RefCell<Node<'tree, S>>;
impl<'tree, S> Node<'tree, S> { impl<'tree, S> Node<'tree, S> {
fn new(state: S, alloc: &'tree Bump) -> Node<S> { fn new(state: S, alloc: &'tree Bump) -> Node<S> {
Node { Node {
@ -41,7 +35,7 @@ mod mcts {
visited: Cell::new(0), visited: Cell::new(0),
score: Cell::new(0), score: Cell::new(0),
parent: None, parent: None,
children: Vec::new_in(alloc), children: BumpVec::new_in(alloc),
} }
} }
@ -69,7 +63,7 @@ mod mcts {
let promising_node = select_promising_node(root_node); let promising_node = select_promising_node(root_node);
if !promising_node.state.is_finished() { if !promising_node.state.is_finished() {
expand_node(promising_node); expand_node(&alloc, promising_node);
} }
if !promising_node.children.is_empty() { if !promising_node.children.is_empty() {
@ -98,8 +92,26 @@ mod mcts {
node node
} }
fn expand_node<S>(_node: &Node<'_, S>) { fn expand_node<S: GameState>(alloc: &Bump, node: &Node<'_, S>) {
todo!("next") /*
List<State> possibleStates = node.getState().getAllPossibleStates();
possibleStates.forEach(state -> {
Node newNode = new Node(state);
newNode.setParent(node);
newNode.getState().setPlayerNo(node.getState().getOpponent());
node.getChildArray().add(newNode);
});
*/
let possible_states = node.state.next_states();
for state in possible_states {
let child = Node {
state,
visited: Cell::new(0),
score: Cell::new(0),
parent: Some(node),
children: BumpVec::new_in(alloc),
};
}
} }
fn simulate_random_playout<S>(_node: &Node<'_, S>) -> u64 { fn simulate_random_playout<S>(_node: &Node<'_, S>) -> u64 {