mirror of
https://github.com/Noratrieb/monte-carlo-tree-search.git
synced 2026-01-14 15:25:09 +01:00
cool types
This commit is contained in:
parent
ae7c64c624
commit
da89f02007
1 changed file with 86 additions and 34 deletions
120
src/lib.rs
120
src/lib.rs
|
|
@ -4,12 +4,14 @@ mod basic_search;
|
|||
|
||||
pub use mcts::find_next_move;
|
||||
|
||||
pub trait GameState {
|
||||
fn points(&self) -> i32;
|
||||
pub trait GameState: Clone {
|
||||
type Player: Eq + Copy;
|
||||
|
||||
fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>>;
|
||||
|
||||
fn is_finished(&self) -> bool;
|
||||
fn player_won(&self) -> Option<Self::Player>;
|
||||
|
||||
fn next_random_play(&mut self);
|
||||
}
|
||||
|
||||
mod mcts {
|
||||
|
|
@ -21,8 +23,8 @@ mod mcts {
|
|||
#[derive(Clone)]
|
||||
struct Node<'tree, S> {
|
||||
state: S,
|
||||
visited: Cell<u64>,
|
||||
score: Cell<u64>,
|
||||
visited: Cell<u32>,
|
||||
score: Cell<i32>,
|
||||
parent: Option<&'tree Node<'tree, S>>,
|
||||
children: Cell<&'tree [Node<'tree, S>]>,
|
||||
}
|
||||
|
|
@ -55,7 +57,7 @@ mod mcts {
|
|||
|
||||
const MAX_TRIES: u64 = 10000;
|
||||
|
||||
pub fn find_next_move<S: GameState + Clone>(current_state: S) -> S {
|
||||
pub fn find_next_move<S: GameState>(current_state: S, opponent: S::Player) -> S {
|
||||
let alloc = Bump::new();
|
||||
|
||||
let root_node = alloc.alloc(Node::new(current_state, &alloc));
|
||||
|
|
@ -63,16 +65,16 @@ mod mcts {
|
|||
for _ in 0..MAX_TRIES {
|
||||
let promising_node = select_promising_node(root_node);
|
||||
|
||||
if !promising_node.state.is_finished() {
|
||||
if promising_node.state.player_won() == None {
|
||||
expand_node(&alloc, promising_node);
|
||||
}
|
||||
|
||||
if !promising_node.children.get().is_empty() {
|
||||
let child = promising_node.random_child();
|
||||
let playout_result = simulate_random_playout(child);
|
||||
let playout_result = simulate_random_playout(child, opponent);
|
||||
back_propagation(child, playout_result);
|
||||
} else {
|
||||
let playout_result = simulate_random_playout(promising_node);
|
||||
let playout_result = simulate_random_playout(promising_node, opponent);
|
||||
back_propagation(promising_node, playout_result);
|
||||
};
|
||||
}
|
||||
|
|
@ -109,48 +111,56 @@ mod mcts {
|
|||
node.children.set(children);
|
||||
}
|
||||
|
||||
fn back_propagation<S>(node: &Node<'_, S>, _playout_result: u64) {
|
||||
fn back_propagation<S: GameState>(node: &Node<'_, S>, player_won: S::Player) {
|
||||
let mut temp_node = Some(node);
|
||||
|
||||
while let Some(node) = temp_node {
|
||||
// todo increment visit
|
||||
// todo increment win count if we won
|
||||
node.visited.set(node.visited.get() + 1);
|
||||
|
||||
if node.state.player_won() == Some(player_won) {
|
||||
node.score.set(node.score.get() + 1);
|
||||
}
|
||||
|
||||
temp_node = node.parent;
|
||||
}
|
||||
}
|
||||
|
||||
fn simulate_random_playout<S>(_node: &Node<'_, S>) -> u64 {
|
||||
/*
|
||||
Node tempNode = new Node(node);
|
||||
State tempState = tempNode.getState();
|
||||
int boardStatus = tempState.getBoard().checkStatus();
|
||||
if (boardStatus == opponent) {
|
||||
tempNode.getParent().getState().setWinScore(Integer.MIN_VALUE);
|
||||
return boardStatus;
|
||||
fn simulate_random_playout<S: GameState>(node: &Node<'_, S>, opponent: S::Player) -> S::Player {
|
||||
let mut state = node.state.clone();
|
||||
|
||||
let mut board_status = state.player_won();
|
||||
|
||||
if board_status == Some(opponent) {
|
||||
if let Some(parent) = node.parent {
|
||||
parent.score.set(i32::MIN)
|
||||
}
|
||||
return opponent;
|
||||
}
|
||||
while (boardStatus == Board.IN_PROGRESS) {
|
||||
tempState.togglePlayer();
|
||||
tempState.randomPlay();
|
||||
boardStatus = tempState.getBoard().checkStatus();
|
||||
|
||||
loop {
|
||||
match board_status {
|
||||
None => {
|
||||
state.next_random_play();
|
||||
board_status = state.player_won();
|
||||
}
|
||||
Some(player) => return player,
|
||||
}
|
||||
}
|
||||
return boardStatus;
|
||||
*/
|
||||
todo!()
|
||||
}
|
||||
|
||||
mod uct {
|
||||
use crate::mcts::Node;
|
||||
|
||||
pub fn uct(total_visit: u64, win_score: u64, node_visit: u64) -> u64 {
|
||||
pub fn uct(total_visit: u32, win_score: i32, node_visit: i32) -> u32 {
|
||||
if node_visit == 0 {
|
||||
return u64::MAX;
|
||||
return u32::MAX;
|
||||
}
|
||||
|
||||
let num = (win_score / node_visit) as f64
|
||||
+ std::f64::consts::SQRT_2
|
||||
* f64::sqrt((total_visit as f64).ln() / node_visit as f64);
|
||||
|
||||
num as u64
|
||||
num as u32
|
||||
}
|
||||
|
||||
pub(super) fn find_best_node_with_uct<'tree, S>(
|
||||
|
|
@ -166,8 +176,50 @@ mod mcts {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
#[test]
|
||||
fn t() {}
|
||||
mod tic_tac_toe {
|
||||
use crate::GameState;
|
||||
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
enum Player {
|
||||
O,
|
||||
X,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
enum State {
|
||||
Empty,
|
||||
X,
|
||||
O,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
struct Board {
|
||||
active_player: Player,
|
||||
board: [State; 9],
|
||||
}
|
||||
|
||||
impl Board {
|
||||
pub fn new(starter: Player) -> Self {
|
||||
Self {
|
||||
active_player: starter,
|
||||
board: [State::Empty; 9],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GameState for Board {
|
||||
type Player = Player;
|
||||
|
||||
fn next_states(&self) -> Box<dyn ExactSizeIterator<Item = Self>> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn player_won(&self) -> Option<Player> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn next_random_play(&mut self) {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue