This commit is contained in:
nora 2022-01-25 13:48:21 +01:00
parent a9171b7724
commit 236827dfe5

View file

@ -2,12 +2,6 @@
#![allow(dead_code)] #![allow(dead_code)]
use bumpalo::collections::Vec;
use bumpalo::Bump;
use std::cell::RefCell;
use rand::{random, Rng};
mod basic_search; mod basic_search;
pub trait GameState { pub trait GameState {
@ -20,15 +14,22 @@ pub trait GameState {
fn is_finished(&self) -> bool; fn is_finished(&self) -> bool;
} }
mod mcts {
use crate::GameState;
use bumpalo::collections::Vec;
use bumpalo::Bump;
use rand::Rng;
use std::cell::{Cell, RefCell};
// NOTE(review): these lines appear to be a side-by-side diff capture: the old
// text is on the left and the new text on the right of each line, so the span
// is not compilable as written.
// `Tree` is an alias for a `RefCell`-wrapped `Node`; it is the same in both columns.
type Tree<'tree, S> = RefCell<Node<'tree, S>>; type Tree<'tree, S> = RefCell<Node<'tree, S>>;
// A search-tree node carrying a game state `S`, two counters, an optional
// parent reference and a list of children.
#[derive(Clone)] #[derive(Clone)]
struct Node<'tree, S> { struct Node<'tree, S> {
state: S, state: S,
// The counters change from plain `u64` (old) to `Cell<u64>` (new), which lets
// them be updated through a shared `&Node`.
visited: u64, visited: Cell<u64>,
score: u64, score: Cell<u64>,
// Parent and child links change from `RefNode` (a `RefCell<Node>`) in the old
// column to bare `Node` in the new column.
parent: Option<&'tree RefNode<'tree, S>>, parent: Option<&'tree Node<'tree, S>>,
// `Vec` takes a lifetime parameter here, i.e. it is the bumpalo collection
// imported by the surrounding `use bumpalo::collections::Vec;` line.
children: Vec<'tree, RefNode<'tree, S>>, children: Vec<'tree, Node<'tree, S>>,
} }
// `RefNode` is still declared in the new column even though the new-column code
// in this view no longer refers to it. NOTE(review): possibly dead after this
// change; confirm against the rest of the file.
type RefNode<'tree, S> = RefCell<Node<'tree, S>>; type RefNode<'tree, S> = RefCell<Node<'tree, S>>;
@ -37,90 +38,80 @@ impl<'tree, S> Node<'tree, S> {
// Creates a node for `state` with no parent, no children and both counters at
// zero. The child vector is created inside the bump allocator `alloc`.
// Old column: counters are plain `0`; new column: counters are `Cell::new(0)`.
fn new(state: S, alloc: &'tree Bump) -> Node<S> { fn new(state: S, alloc: &'tree Bump) -> Node<S> {
Node { Node {
state, state,
visited: 0, visited: Cell::new(0),
score: 0, score: Cell::new(0),
parent: None, parent: None,
children: Vec::new_in(alloc), children: Vec::new_in(alloc),
} }
} }
// Returns a reference to one child chosen by a random index drawn from
// `rand::thread_rng()`. The old column returns `&RefCell<Self>`, the new
// column `&Self`.
// NOTE(review): with no children the range `0..0` is empty; rand documents
// `gen_range` as panicking on an empty range, so callers must check
// `children.is_empty()` first.
fn random_child(&self) -> &RefCell<Self> { fn random_child(&self) -> &Self {
let random_index = rand::thread_rng().gen_range(0..self.children.len()); let random_index = rand::thread_rng().gen_range(0..self.children.len());
&self.children[random_index] &self.children[random_index]
} }
// Consumes the node and returns the child with the largest `score`, or `None`
// when there are no children. The old column reads the score through
// `borrow()`, the new column through `Cell::get()`.
fn into_child_with_max_score(self) -> Option<RefCell<Self>> { fn into_child_with_max_score(self) -> Option<Self> {
self.children self.children
.into_iter() .into_iter()
.max_by_key(|node| node.borrow().score) .max_by_key(|node| node.score.get())
} }
} }
mod mcts {
use crate::{GameState, Node, RefNode};
use bumpalo::Bump;
use std::cell::{Ref, RefCell};
// Number of search iterations performed by `find_next_move`.
const MAX_TRIES: u64 = 10000; const MAX_TRIES: u64 = 10000;
// Runs `MAX_TRIES` rounds of select / expand / playout / back-propagate on a
// tree rooted at `current_state`, then returns the state of the root's
// highest-scoring child.
// The tree lives in a `Bump` arena created here. The old column wraps the root
// in a `RefCell`, the new column allocates a bare `Node`.
// NOTE(review): `expand_node`, `simulate_random_playout` and
// `back_propagation` are `todo!()` stubs in this view, so the first iteration
// would panic.
pub fn find_next_move<S: GameState + Clone>(current_state: S) -> S { pub fn find_next_move<S: GameState + Clone>(current_state: S) -> S {
let alloc = Bump::new(); let alloc = Bump::new();
let root_node = alloc.alloc(RefCell::new(Node::new(current_state, &alloc))); let root_node = alloc.alloc(Node::new(current_state, &alloc));
for _ in 0..MAX_TRIES { for _ in 0..MAX_TRIES {
let promising_node = select_promising_node(&root_node); let promising_node = select_promising_node(root_node);
// Only a node whose game is not finished gets expanded.
if !promising_node.borrow().state.is_finished() { if !promising_node.state.is_finished() {
expand_node(&promising_node); expand_node(promising_node);
} }
// If the node has children, play out from a random child; otherwise play out
// from the node itself.
if !promising_node.borrow().children.is_empty() { if !promising_node.children.is_empty() {
// Old column only: holds a `Ref` borrow for the rest of this branch.
let promising_node = promising_node.borrow();
let child = promising_node.random_child(); let child = promising_node.random_child();
let playout_result = simulate_random_playout(&child); let playout_result = simulate_random_playout(child);
back_propagation(&child, playout_result); back_propagation(child, playout_result);
} else { } else {
let playout_result = simulate_random_playout(&promising_node); let playout_result = simulate_random_playout(promising_node);
back_propagation(&promising_node, playout_result); back_propagation(promising_node, playout_result);
}; };
} }
// The root is cloned because `into_child_with_max_score` takes `self` by
// value. NOTE(review): this clones the whole node, including its children.
let winner_node = root_node.clone().into_inner().into_child_with_max_score(); let winner_node = root_node.clone().into_child_with_max_score();
// NOTE(review): `unwrap` panics when the root ended up with no children.
let state = winner_node.unwrap().into_inner().state; let node = winner_node.unwrap();
state node.state
} }
// Walks from `root_node` towards a node without children and returns it.
// The first line below is the new single-line signature; the next two lines
// are the remainder of the old multi-line signature (old column only).
// NOTE(review): in both columns the loop body passes `root_node`, not the
// current `node`, to `find_best_node_with_uct`. Every iteration therefore
// re-selects among the root's children and never descends further. If the
// chosen child itself has children, the loop condition never becomes false.
// This looks like it should pass `node`; confirm.
fn select_promising_node<'tree, S>(root_node: &'tree Node<'tree, S>) -> &'tree Node<'tree, S> {
root_node: &'cell RefCell<Node<'tree, S>>,
) -> &'cell RefNode<'tree, S> {
let mut node = root_node; let mut node = root_node;
while node.borrow().children.len() != 0 { while !node.children.is_empty() {
node = uct::find_best_node_with_uct(&root_node).unwrap() node = uct::find_best_node_with_uct(root_node).unwrap()
} }
node node
} }
// Unimplemented stub: calling it panics via `todo!("next")`. The parameter
// type changes from `&RefNode<S>` (old) to `&Node<'_, S>` (new).
fn expand_node<S>(_node: &RefNode<S>) { fn expand_node<S>(_node: &Node<'_, S>) {
todo!("next") todo!("next")
} }
// Unimplemented stub declared to return a `u64` playout result; calling it
// panics via `todo!()`. The parameter type changes from `&RefNode` (old) to
// `&Node` (new).
fn simulate_random_playout<S>(_node: &RefNode<'_, S>) -> u64 { fn simulate_random_playout<S>(_node: &Node<'_, S>) -> u64 {
todo!() todo!()
} }
// Unimplemented stub taking a node and a `u64` playout result; calling it
// panics via `todo!()`. The parameter type changes from `&RefNode` (old) to
// `&Node` (new).
fn back_propagation<S>(_node: &RefNode<'_, S>, _playout_result: u64) { fn back_propagation<S>(_node: &Node<'_, S>, _playout_result: u64) {
todo!() todo!()
} }
mod uct { mod uct {
use crate::mcts::RefNode; use crate::mcts::Node;
use crate::Node;
use std::cell::Ref;
pub fn uct(total_visit: u64, win_score: u64, node_visit: u64) -> u64 { pub fn uct(total_visit: u64, win_score: u64, node_visit: u64) -> u64 {
if node_visit == 0 { if node_visit == 0 {
@ -134,15 +125,14 @@ mod mcts {
num as u64 num as u64
} }
// Returns the child of `node` with the largest `uct(...)` value, or `None`
// when `node` has no children.
// The old column takes a `Ref<Node>` and is `pub`; the new column takes
// `&Node` and is `pub(super)`.
// The columns are not aligned statement-for-statement below: the old closure
// spans four lines, while the new iterator chain spans three.
// NOTE(review): the new column passes `n.score.get()` as both the second and
// third argument of `uct`. The third parameter is `node_visit`, and the old
// column passed `n.visited` there, so this looks like it should be
// `n.visited.get()`; confirm.
pub fn find_best_node_with_uct<'cell, 'tree, S>( pub(super) fn find_best_node_with_uct<'tree, S>(
node: Ref<'cell, Node<'tree, S>>, node: &'tree Node<'tree, S>,
) -> Option<&'cell RefNode<'tree, S>> { ) -> Option<&'tree Node<'tree, S>> {
let parent_visit_count = node.visited; let parent_visit_count = node.visited.get();
node.children.iter().max_by_key(|n| { node.children
let n = n.borrow(); .iter()
uct(parent_visit_count, n.score, n.visited) .max_by_key(|n| uct(parent_visit_count, n.score.get(), n.score.get()))
})
} }
} }
} }