why clion

This commit is contained in:
nora 2022-01-24 22:04:41 +01:00
parent 3374c9ff35
commit 11a94fefdf

View file

@ -4,6 +4,7 @@
use bumpalo::collections::Vec; use bumpalo::collections::Vec;
use bumpalo::Bump; use bumpalo::Bump;
use std::cell::RefCell;
use rand::Rng; use rand::Rng;
@ -19,14 +20,14 @@ pub trait GameState {
fn is_finished(&self) -> bool; fn is_finished(&self) -> bool;
} }
/// Alias for a tree whose root is a `Node`. The right-hand (post-commit) column wraps the
/// root in a `RefCell`; the left-hand (pre-commit) column aliased the bare `Node`.
type Tree<'tree, S> = Node<'tree, S>; type Tree<'tree, S> = RefCell<Node<'tree, S>>;
/// A single node of the search tree, generic over the game state type `S`.
///
/// In the right-hand (post-commit) column the parent link and the children are stored
/// behind `RefCell`; the left-hand column held plain `Node` values and references.
struct Node<'tree, S> { struct Node<'tree, S> {
// State stored in this node; `find_next_move` queries it with `is_finished()`.
state: S, state: S,
// Visit counter; passed to `uct` by `find_best_node_with_uct`.
visited: u64, visited: u64,
// Renamed from `won` to `score` by this commit; it is the key that
// `into_child_with_max_score` maximises over.
won: u64, score: u64,
// Optional reference to the parent node.
parent: Option<&'tree Node<'tree, S>>, parent: Option<&'tree RefCell<Node<'tree, S>>>,
// Child nodes, kept in a `bumpalo::collections::Vec`.
children: Vec<'tree, Node<'tree, S>>, children: Vec<'tree, RefCell<Node<'tree, S>>>,
} }
impl<'tree, S: GameState> Node<'tree, S> { impl<'tree, S: GameState> Node<'tree, S> {
@ -34,56 +35,70 @@ impl<'tree, S: GameState> Node<'tree, S> {
Node { Node {
state, state,
visited: 0, visited: 0,
won: 0, score: 0,
parent: None, parent: None,
children: Vec::new_in(alloc), children: Vec::new_in(alloc),
} }
} }
/// Returns a reference to one child, picked at a random index in `0..self.children.len()`.
///
/// The right-hand (post-commit) column returns the child's `RefCell`; the left-hand column
/// returned the child node directly.
// NOTE(review): with no children the index range is empty, so there is nothing valid to
// index — presumably this panics. The caller visible in `find_next_move` checks
// `children.is_empty()` first; confirm every caller does.
fn random_child(&self) -> &Self { fn random_child(&self) -> &RefCell<Self> {
&self.children[rand::thread_rng().gen_range(0..self.children.len())] &self.children[rand::thread_rng().gen_range(0..self.children.len())]
} }
/// Consumes the node and returns the child with the largest `score`.
///
/// Right-hand (post-commit) column: iterates the owned `children` and returns `None` when
/// there are none. The left-hand (pre-commit) column was a `todo!()` stub returning `Self`.
fn into_child_with_max_score(self) -> Self { fn into_child_with_max_score(self) -> Option<RefCell<Self>> {
todo!() self.children
.into_iter()
// Each child is borrowed only long enough to read its `score` as the comparison key.
.max_by_key(|node| node.borrow().score)
} }
} }
mod mcts { mod mcts {
use crate::{GameState, Node}; use crate::{GameState, Node};
use bumpalo::Bump; use bumpalo::Bump;
use std::cell::RefCell;
/// Number of iterations the search loop in `find_next_move` runs.
const MAX_TRIES: u64 = 10000; const MAX_TRIES: u64 = 10000;
/// Runs `MAX_TRIES` search iterations starting from `current_state` and returns the state
/// of the root's child with the highest `score`.
///
/// Each iteration selects a node with `select_promising_node`, calls `expand_node` on it
/// unless its state is finished, picks a random child of it (or the node itself when it has
/// no children), runs `simulate_random_playout` on that node and hands the result to
/// `back_propagation`.
///
/// # Panics
///
/// In the right-hand (post-commit) column the final `unwrap()` panics when
/// `into_child_with_max_score` returns `None`, i.e. when the root has no children.
pub fn find_next_move<S: GameState>(current_state: S) -> S { pub fn find_next_move<S: GameState>(current_state: S) -> S {
// All tree nodes are allocated from this bump arena.
let alloc = Bump::new(); let alloc = Bump::new();
let root_node = Node::new(current_state, &alloc); let root_node = RefCell::new(Node::new(current_state, &alloc));
for _ in 0..MAX_TRIES { for _ in 0..MAX_TRIES {
let promising_node = select_promising_node(&root_node); let promising_node_cell = select_promising_node(&root_node);
// Post-commit only: keep the cell and a shared borrow of its contents side by side.
let promising_node = promising_node_cell.borrow();
if !promising_node.state.is_finished() { if !promising_node.state.is_finished() {
expand_node(promising_node); expand_node(&promising_node);
} }
// NOTE(review): the left-hand column tested `node_to_explore` before it was defined; the
// right-hand column tests `promising_node` instead.
let node_to_explore = if !node_to_explore.children.is_empty() { let node_to_explore = if !promising_node.children.is_empty() {
promising_node.random_child() promising_node.random_child()
} else { } else {
promising_node promising_node_cell
}; };
let playout_result = simulate_random_playout(node_to_explore); let playout_result = simulate_random_playout(&node_to_explore.borrow());
back_propagation(node_to_explore, playout_result); back_propagation(&node_to_explore.borrow(), playout_result);
} }
// NOTE(review): `select_promising_node` takes `&'tree RefCell<Node<'tree, S>>`, which ties
// the borrow of `root_node` to `'tree`; verify the borrow checker accepts moving the root
// out with `into_inner()` here.
let winner_node = root_node.into_child_with_max_score(); let winner_node = root_node.into_inner().into_child_with_max_score();
winner_node.state let state = winner_node.unwrap().into_inner().state;
state
} }
/// Starting at `root_node`, repeatedly moves to the child chosen by
/// `uct::find_best_node_with_uct` and returns the cell it ends on (right-hand, post-commit
/// column). The left-hand (pre-commit) column was a `todo!()` stub.
// NOTE(review): `borrowed_node` is taken from the root once, before the loop, and is never
// reassigned. The loop condition and the argument to `find_best_node_with_uct` therefore
// never change, so once the root has children this loop looks like it cannot terminate.
// Re-borrow `node` on every iteration.
// NOTE(review): `children.len() != 0` could be `!children.is_empty()`, and the `unwrap()`
// relies on `find_best_node_with_uct` returning `Some` whenever children exist.
fn select_promising_node<'tree, S>(node: &'tree Node<'_, S>) -> &'tree Node<'tree, S> { fn select_promising_node<'tree, S>(
todo!() root_node: &'tree RefCell<Node<'tree, S>>,
) -> &'tree RefCell<Node<'tree, S>> {
let mut node = root_node;
let borrowed_node = node.borrow();
while borrowed_node.children.len() != 0 {
node = uct::find_best_node_with_uct(&borrowed_node).unwrap()
}
node
} }
fn expand_node<S>(node: &Node<S>) { fn expand_node<S>(node: &Node<S>) {
@ -100,6 +115,7 @@ mod mcts {
mod uct { mod uct {
use crate::Node; use crate::Node;
use std::cell::RefCell;
pub fn uct(total_visit: u64, win_score: u64, node_visit: u64) -> u64 { pub fn uct(total_visit: u64, win_score: u64, node_visit: u64) -> u64 {
if node_visit == 0 { if node_visit == 0 {
@ -113,12 +129,15 @@ mod mcts {
num as u64 num as u64
} }
/// Returns the child of `node` with the highest `uct` value, or `None` when `node` has no
/// children.
///
/// The value for each child is `uct(parent visits, child score, child visits)`. The
/// right-hand (post-commit) column returns the child's `RefCell` and reads `score`; the
/// left-hand column returned a plain `&Node` and read `won`.
pub fn find_best_node_with_uct<S>(node: &Node<S>) -> Option<&Node<S>> { pub fn find_best_node_with_uct<'tree, S>(
node: &'tree Node<'tree, S>,
) -> Option<&'tree RefCell<Node<'tree, S>>> {
let parent_visit_count = node.visited; let parent_visit_count = node.visited;
// Post-commit: the closure shadows `n` with a shared borrow of the child cell so its
// counters can be read.
node.children node.children.iter().max_by_key(|n| {
.iter() let n = n.borrow();
.max_by_key(|n| uct(parent_visit_count, n.won, n.visited)) uct(parent_visit_count, n.score, n.visited)
})
} }
} }
} }