mirror of
https://github.com/Noratrieb/monte-carlo-tree-search.git
synced 2026-01-14 15:25:09 +01:00
why clion
This commit is contained in:
parent
3374c9ff35
commit
11a94fefdf
1 changed file with 42 additions and 23 deletions
65
src/lib.rs
65
src/lib.rs
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
use bumpalo::collections::Vec;
|
||||
use bumpalo::Bump;
|
||||
use std::cell::RefCell;
|
||||
|
||||
use rand::Rng;
|
||||
|
||||
|
|
@ -19,14 +20,14 @@ pub trait GameState {
|
|||
fn is_finished(&self) -> bool;
|
||||
}
|
||||
|
||||
type Tree<'tree, S> = Node<'tree, S>;
|
||||
type Tree<'tree, S> = RefCell<Node<'tree, S>>;
|
||||
|
||||
struct Node<'tree, S> {
|
||||
state: S,
|
||||
visited: u64,
|
||||
won: u64,
|
||||
parent: Option<&'tree Node<'tree, S>>,
|
||||
children: Vec<'tree, Node<'tree, S>>,
|
||||
score: u64,
|
||||
parent: Option<&'tree RefCell<Node<'tree, S>>>,
|
||||
children: Vec<'tree, RefCell<Node<'tree, S>>>,
|
||||
}
|
||||
|
||||
impl<'tree, S: GameState> Node<'tree, S> {
|
||||
|
|
@ -34,56 +35,70 @@ impl<'tree, S: GameState> Node<'tree, S> {
|
|||
Node {
|
||||
state,
|
||||
visited: 0,
|
||||
won: 0,
|
||||
score: 0,
|
||||
parent: None,
|
||||
children: Vec::new_in(alloc),
|
||||
}
|
||||
}
|
||||
|
||||
fn random_child(&self) -> &Self {
|
||||
fn random_child(&self) -> &RefCell<Self> {
|
||||
&self.children[rand::thread_rng().gen_range(0..self.children.len())]
|
||||
}
|
||||
|
||||
fn into_child_with_max_score(self) -> Self {
|
||||
todo!()
|
||||
fn into_child_with_max_score(self) -> Option<RefCell<Self>> {
|
||||
self.children
|
||||
.into_iter()
|
||||
.max_by_key(|node| node.borrow().score)
|
||||
}
|
||||
}
|
||||
|
||||
mod mcts {
|
||||
use crate::{GameState, Node};
|
||||
use bumpalo::Bump;
|
||||
use std::cell::RefCell;
|
||||
|
||||
const MAX_TRIES: u64 = 10000;
|
||||
|
||||
pub fn find_next_move<S: GameState>(current_state: S) -> S {
|
||||
let alloc = Bump::new();
|
||||
|
||||
let root_node = Node::new(current_state, &alloc);
|
||||
let root_node = RefCell::new(Node::new(current_state, &alloc));
|
||||
|
||||
for _ in 0..MAX_TRIES {
|
||||
let promising_node = select_promising_node(&root_node);
|
||||
let promising_node_cell = select_promising_node(&root_node);
|
||||
let promising_node = promising_node_cell.borrow();
|
||||
|
||||
if !promising_node.state.is_finished() {
|
||||
expand_node(promising_node);
|
||||
expand_node(&promising_node);
|
||||
}
|
||||
|
||||
let node_to_explore = if !node_to_explore.children.is_empty() {
|
||||
let node_to_explore = if !promising_node.children.is_empty() {
|
||||
promising_node.random_child()
|
||||
} else {
|
||||
promising_node
|
||||
promising_node_cell
|
||||
};
|
||||
|
||||
let playout_result = simulate_random_playout(node_to_explore);
|
||||
back_propagation(node_to_explore, playout_result);
|
||||
let playout_result = simulate_random_playout(&node_to_explore.borrow());
|
||||
back_propagation(&node_to_explore.borrow(), playout_result);
|
||||
}
|
||||
|
||||
let winner_node = root_node.into_child_with_max_score();
|
||||
let winner_node = root_node.into_inner().into_child_with_max_score();
|
||||
|
||||
winner_node.state
|
||||
let state = winner_node.unwrap().into_inner().state;
|
||||
state
|
||||
}
|
||||
|
||||
fn select_promising_node<'tree, S>(node: &'tree Node<'_, S>) -> &'tree Node<'tree, S> {
|
||||
todo!()
|
||||
fn select_promising_node<'tree, S>(
|
||||
root_node: &'tree RefCell<Node<'tree, S>>,
|
||||
) -> &'tree RefCell<Node<'tree, S>> {
|
||||
let mut node = root_node;
|
||||
|
||||
let borrowed_node = node.borrow();
|
||||
while borrowed_node.children.len() != 0 {
|
||||
node = uct::find_best_node_with_uct(&borrowed_node).unwrap()
|
||||
}
|
||||
|
||||
node
|
||||
}
|
||||
|
||||
fn expand_node<S>(node: &Node<S>) {
|
||||
|
|
@ -100,6 +115,7 @@ mod mcts {
|
|||
|
||||
mod uct {
|
||||
use crate::Node;
|
||||
use std::cell::RefCell;
|
||||
|
||||
pub fn uct(total_visit: u64, win_score: u64, node_visit: u64) -> u64 {
|
||||
if node_visit == 0 {
|
||||
|
|
@ -113,12 +129,15 @@ mod mcts {
|
|||
num as u64
|
||||
}
|
||||
|
||||
pub fn find_best_node_with_uct<S>(node: &Node<S>) -> Option<&Node<S>> {
|
||||
pub fn find_best_node_with_uct<'tree, S>(
|
||||
node: &'tree Node<'tree, S>,
|
||||
) -> Option<&'tree RefCell<Node<'tree, S>>> {
|
||||
let parent_visit_count = node.visited;
|
||||
|
||||
node.children
|
||||
.iter()
|
||||
.max_by_key(|n| uct(parent_visit_count, n.won, n.visited))
|
||||
node.children.iter().max_by_key(|n| {
|
||||
let n = n.borrow();
|
||||
uct(parent_visit_count, n.score, n.visited)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue