mirror of
https://github.com/Noratrieb/monte-carlo-tree-search.git
synced 2026-01-14 15:25:09 +01:00
pain
This commit is contained in:
parent
e9ad50b96e
commit
a9171b7724
1 changed files with 18 additions and 16 deletions
34
src/lib.rs
34
src/lib.rs
|
|
@ -6,7 +6,7 @@ use bumpalo::collections::Vec;
|
|||
use bumpalo::Bump;
|
||||
use std::cell::RefCell;
|
||||
|
||||
use rand::Rng;
|
||||
use rand::{random, Rng};
|
||||
|
||||
mod basic_search;
|
||||
|
||||
|
|
@ -45,7 +45,9 @@ impl<'tree, S> Node<'tree, S> {
|
|||
}
|
||||
|
||||
fn random_child(&self) -> &RefCell<Self> {
|
||||
&self.children[rand::thread_rng().gen_range(0..self.children.len())]
|
||||
let random_index = rand::thread_rng().gen_range(0..self.children.len());
|
||||
|
||||
&self.children[random_index]
|
||||
}
|
||||
|
||||
fn into_child_with_max_score(self) -> Option<RefCell<Self>> {
|
||||
|
|
@ -58,7 +60,7 @@ impl<'tree, S> Node<'tree, S> {
|
|||
mod mcts {
|
||||
use crate::{GameState, Node, RefNode};
|
||||
use bumpalo::Bump;
|
||||
use std::cell::RefCell;
|
||||
use std::cell::{Ref, RefCell};
|
||||
|
||||
const MAX_TRIES: u64 = 10000;
|
||||
|
||||
|
|
@ -68,15 +70,15 @@ mod mcts {
|
|||
let root_node = alloc.alloc(RefCell::new(Node::new(current_state, &alloc)));
|
||||
|
||||
for _ in 0..MAX_TRIES {
|
||||
let promising_node_cell = select_promising_node(&root_node);
|
||||
let promising_node = promising_node_cell.borrow();
|
||||
let promising_node = select_promising_node(&root_node);
|
||||
|
||||
if !promising_node.state.is_finished() {
|
||||
if !promising_node.borrow().state.is_finished() {
|
||||
expand_node(&promising_node);
|
||||
}
|
||||
|
||||
if !promising_node.children.is_empty() {
|
||||
let child = promising_node.random_child().borrow();
|
||||
if !promising_node.borrow().children.is_empty() {
|
||||
let promising_node = promising_node.borrow();
|
||||
let child = promising_node.random_child();
|
||||
let playout_result = simulate_random_playout(&child);
|
||||
back_propagation(&child, playout_result);
|
||||
} else {
|
||||
|
|
@ -91,14 +93,13 @@ mod mcts {
|
|||
state
|
||||
}
|
||||
|
||||
fn select_promising_node<'tree, S>(
|
||||
root_node: &'tree RefNode<'tree, S>,
|
||||
) -> &'tree RefNode<'tree, S> {
|
||||
fn select_promising_node<'cell, 'tree, S>(
|
||||
root_node: &'cell RefCell<Node<'tree, S>>,
|
||||
) -> &'cell RefNode<'tree, S> {
|
||||
let mut node = root_node;
|
||||
|
||||
let borrowed_node = node.borrow();
|
||||
while borrowed_node.children.len() != 0 {
|
||||
node = uct::find_best_node_with_uct(&borrowed_node).unwrap()
|
||||
while node.borrow().children.len() != 0 {
|
||||
node = uct::find_best_node_with_uct(&root_node).unwrap()
|
||||
}
|
||||
|
||||
node
|
||||
|
|
@ -119,6 +120,7 @@ mod mcts {
|
|||
mod uct {
|
||||
use crate::mcts::RefNode;
|
||||
use crate::Node;
|
||||
use std::cell::Ref;
|
||||
|
||||
pub fn uct(total_visit: u64, win_score: u64, node_visit: u64) -> u64 {
|
||||
if node_visit == 0 {
|
||||
|
|
@ -133,8 +135,8 @@ mod mcts {
|
|||
}
|
||||
|
||||
pub fn find_best_node_with_uct<'cell, 'tree, S>(
|
||||
node: &'tree RefNode<'tree, S>,
|
||||
) -> Option<&'tree RefNode<'tree, S>> {
|
||||
node: Ref<'cell, Node<'tree, S>>,
|
||||
) -> Option<&'cell RefNode<'tree, S>> {
|
||||
let parent_visit_count = node.visited;
|
||||
|
||||
node.children.iter().max_by_key(|n| {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue