mirror of
https://github.com/Noratrieb/monte-carlo-tree-search.git
synced 2026-01-14 15:25:09 +01:00
cell!
This commit is contained in:
parent
a9171b7724
commit
236827dfe5
1 changed files with 64 additions and 74 deletions
94
src/lib.rs
94
src/lib.rs
|
|
@ -2,12 +2,6 @@
|
||||||
|
|
||||||
#![allow(dead_code)]
|
#![allow(dead_code)]
|
||||||
|
|
||||||
use bumpalo::collections::Vec;
|
|
||||||
use bumpalo::Bump;
|
|
||||||
use std::cell::RefCell;
|
|
||||||
|
|
||||||
use rand::{random, Rng};
|
|
||||||
|
|
||||||
mod basic_search;
|
mod basic_search;
|
||||||
|
|
||||||
pub trait GameState {
|
pub trait GameState {
|
||||||
|
|
@ -20,15 +14,22 @@ pub trait GameState {
|
||||||
fn is_finished(&self) -> bool;
|
fn is_finished(&self) -> bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mod mcts {
|
||||||
|
use crate::GameState;
|
||||||
|
use bumpalo::collections::Vec;
|
||||||
|
use bumpalo::Bump;
|
||||||
|
use rand::Rng;
|
||||||
|
use std::cell::{Cell, RefCell};
|
||||||
|
|
||||||
type Tree<'tree, S> = RefCell<Node<'tree, S>>;
|
type Tree<'tree, S> = RefCell<Node<'tree, S>>;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct Node<'tree, S> {
|
struct Node<'tree, S> {
|
||||||
state: S,
|
state: S,
|
||||||
visited: u64,
|
visited: Cell<u64>,
|
||||||
score: u64,
|
score: Cell<u64>,
|
||||||
parent: Option<&'tree RefNode<'tree, S>>,
|
parent: Option<&'tree Node<'tree, S>>,
|
||||||
children: Vec<'tree, RefNode<'tree, S>>,
|
children: Vec<'tree, Node<'tree, S>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
type RefNode<'tree, S> = RefCell<Node<'tree, S>>;
|
type RefNode<'tree, S> = RefCell<Node<'tree, S>>;
|
||||||
|
|
@ -37,90 +38,80 @@ impl<'tree, S> Node<'tree, S> {
|
||||||
fn new(state: S, alloc: &'tree Bump) -> Node<S> {
|
fn new(state: S, alloc: &'tree Bump) -> Node<S> {
|
||||||
Node {
|
Node {
|
||||||
state,
|
state,
|
||||||
visited: 0,
|
visited: Cell::new(0),
|
||||||
score: 0,
|
score: Cell::new(0),
|
||||||
parent: None,
|
parent: None,
|
||||||
children: Vec::new_in(alloc),
|
children: Vec::new_in(alloc),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn random_child(&self) -> &RefCell<Self> {
|
fn random_child(&self) -> &Self {
|
||||||
let random_index = rand::thread_rng().gen_range(0..self.children.len());
|
let random_index = rand::thread_rng().gen_range(0..self.children.len());
|
||||||
|
|
||||||
&self.children[random_index]
|
&self.children[random_index]
|
||||||
}
|
}
|
||||||
|
|
||||||
fn into_child_with_max_score(self) -> Option<RefCell<Self>> {
|
fn into_child_with_max_score(self) -> Option<Self> {
|
||||||
self.children
|
self.children
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.max_by_key(|node| node.borrow().score)
|
.max_by_key(|node| node.score.get())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mod mcts {
|
|
||||||
use crate::{GameState, Node, RefNode};
|
|
||||||
use bumpalo::Bump;
|
|
||||||
use std::cell::{Ref, RefCell};
|
|
||||||
|
|
||||||
const MAX_TRIES: u64 = 10000;
|
const MAX_TRIES: u64 = 10000;
|
||||||
|
|
||||||
pub fn find_next_move<S: GameState + Clone>(current_state: S) -> S {
|
pub fn find_next_move<S: GameState + Clone>(current_state: S) -> S {
|
||||||
let alloc = Bump::new();
|
let alloc = Bump::new();
|
||||||
|
|
||||||
let root_node = alloc.alloc(RefCell::new(Node::new(current_state, &alloc)));
|
let root_node = alloc.alloc(Node::new(current_state, &alloc));
|
||||||
|
|
||||||
for _ in 0..MAX_TRIES {
|
for _ in 0..MAX_TRIES {
|
||||||
let promising_node = select_promising_node(&root_node);
|
let promising_node = select_promising_node(root_node);
|
||||||
|
|
||||||
if !promising_node.borrow().state.is_finished() {
|
if !promising_node.state.is_finished() {
|
||||||
expand_node(&promising_node);
|
expand_node(promising_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
if !promising_node.borrow().children.is_empty() {
|
if !promising_node.children.is_empty() {
|
||||||
let promising_node = promising_node.borrow();
|
|
||||||
let child = promising_node.random_child();
|
let child = promising_node.random_child();
|
||||||
let playout_result = simulate_random_playout(&child);
|
let playout_result = simulate_random_playout(child);
|
||||||
back_propagation(&child, playout_result);
|
back_propagation(child, playout_result);
|
||||||
} else {
|
} else {
|
||||||
let playout_result = simulate_random_playout(&promising_node);
|
let playout_result = simulate_random_playout(promising_node);
|
||||||
back_propagation(&promising_node, playout_result);
|
back_propagation(promising_node, playout_result);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
let winner_node = root_node.clone().into_inner().into_child_with_max_score();
|
let winner_node = root_node.clone().into_child_with_max_score();
|
||||||
|
|
||||||
let state = winner_node.unwrap().into_inner().state;
|
let node = winner_node.unwrap();
|
||||||
state
|
node.state
|
||||||
}
|
}
|
||||||
|
|
||||||
fn select_promising_node<'cell, 'tree, S>(
|
fn select_promising_node<'tree, S>(root_node: &'tree Node<'tree, S>) -> &'tree Node<'tree, S> {
|
||||||
root_node: &'cell RefCell<Node<'tree, S>>,
|
|
||||||
) -> &'cell RefNode<'tree, S> {
|
|
||||||
let mut node = root_node;
|
let mut node = root_node;
|
||||||
|
|
||||||
while node.borrow().children.len() != 0 {
|
while !node.children.is_empty() {
|
||||||
node = uct::find_best_node_with_uct(&root_node).unwrap()
|
node = uct::find_best_node_with_uct(root_node).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
node
|
node
|
||||||
}
|
}
|
||||||
|
|
||||||
fn expand_node<S>(_node: &RefNode<S>) {
|
fn expand_node<S>(_node: &Node<'_, S>) {
|
||||||
todo!("next")
|
todo!("next")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn simulate_random_playout<S>(_node: &RefNode<'_, S>) -> u64 {
|
fn simulate_random_playout<S>(_node: &Node<'_, S>) -> u64 {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn back_propagation<S>(_node: &RefNode<'_, S>, _playout_result: u64) {
|
fn back_propagation<S>(_node: &Node<'_, S>, _playout_result: u64) {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
mod uct {
|
mod uct {
|
||||||
use crate::mcts::RefNode;
|
use crate::mcts::Node;
|
||||||
use crate::Node;
|
|
||||||
use std::cell::Ref;
|
|
||||||
|
|
||||||
pub fn uct(total_visit: u64, win_score: u64, node_visit: u64) -> u64 {
|
pub fn uct(total_visit: u64, win_score: u64, node_visit: u64) -> u64 {
|
||||||
if node_visit == 0 {
|
if node_visit == 0 {
|
||||||
|
|
@ -134,15 +125,14 @@ mod mcts {
|
||||||
num as u64
|
num as u64
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn find_best_node_with_uct<'cell, 'tree, S>(
|
pub(super) fn find_best_node_with_uct<'tree, S>(
|
||||||
node: Ref<'cell, Node<'tree, S>>,
|
node: &'tree Node<'tree, S>,
|
||||||
) -> Option<&'cell RefNode<'tree, S>> {
|
) -> Option<&'tree Node<'tree, S>> {
|
||||||
let parent_visit_count = node.visited;
|
let parent_visit_count = node.visited.get();
|
||||||
|
|
||||||
node.children.iter().max_by_key(|n| {
|
node.children
|
||||||
let n = n.borrow();
|
.iter()
|
||||||
uct(parent_visit_count, n.score, n.visited)
|
.max_by_key(|n| uct(parent_visit_count, n.score.get(), n.score.get()))
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue